view tests/test_logistic_regression.py @ 24:89f46435a9e2

[logistic regression] call cost function
author Jeff Hammel <k0scist@gmail.com>
date Mon, 04 Sep 2017 09:58:01 -0700
parents f34110e28a0a
children 77f68c241b37
line wrap: on
line source

#!/usr/bin/env python

"""
test logistic regression
"""

import numpy as np
import os
import unittest
from tvii import logistic_regression

class LogisticRegresionTests(unittest.TestCase):
    """Unit tests for ``tvii.logistic_regression``."""

    def test_cost(self):
        """test cost function

        Checks ``cost_function`` against a known value on a tiny example.
        With w = [[1], [2]], b = 2 and X = [[1, 2], [3, 4]], the linear
        activations w.T @ X + b are [[9, 12]]. Against labels Y = [[1, 0]],
        -(log(sigmoid(9)) + log(1 - sigmoid(12))) / 2 ~= 6.000064773, which
        is the expected value used below.
        """
        # presumably w is (n_features, 1), X is (n_features, m_examples) and
        # Y is (1, m_examples) -- TODO confirm against cost_function's contract
        w = np.array([[1], [2]])
        b = 2
        X = np.array([[1, 2], [3, 4]])
        Y = np.array([[1, 0]])

        expected_cost = 6.000064773192205
        cost = logistic_regression.cost_function(w, b, X, Y)

        # Use the TestCase assertion instead of a bare ``assert``: bare
        # asserts are stripped under ``python -O`` (the test would silently
        # pass), and this form reports both values when it fails.
        self.assertAlmostEqual(cost, expected_cost, delta=1e-6)

if __name__ == "__main__":
    # Allow running this module directly (``python test_logistic_regression.py``)
    # in addition to discovery by a test runner.
    unittest.main()