view tests/test_logistic_regression.py @ 29:cf7584f0a29f

test linear regression
author Jeff Hammel <k0scist@gmail.com>
date Mon, 04 Sep 2017 12:01:57 -0700
parents 77f68c241b37
children fa7a51df0d90
line wrap: on
line source

#!/usr/bin/env python

"""
test logistic regression
"""

import numpy as np
import os
import unittest
from tvii import logistic_regression


class LogisticRegresionTests(unittest.TestCase):

    def test_cost(self):
        """test cost function"""

        w, b, X, Y = (np.array([[1],[2]]),
                      2,
                      np.array([[1,2],[3,4]]),
                      np.array([[1,0]]))

        expected_cost = 6.000064773192205
        cost = logistic_regression.cost_function(w, b, X, Y)
        assert abs(cost - expected_cost) < 1e-6

    def test_propagate(self):
        """test canned logistic regression example"""

        # sample variables
        w = np.array([[1],[2]])
        b = 2
        X = np.array([[1,2],[3,4]])
        Y = np.array([[1,0]])

        # calculate gradient and cost
        grads, cost = logistic_regression.propagate(w, b, X, Y)

        # compare to expected,
        dw_expected = np.array([[ 0.99993216], [ 1.99980262]])
        db_expected = 0.499935230625
        cost_expected = 6.000064773192205

        self.assertAlmostEqual(cost_expected, cost)
        self.assertAlmostEqual(grads['db'], db_expected)
        assert grads['dw'].shape == dw_expected.shape
        for a, b in zip(grads['dw'].flatten(),
                           dw_expected.flatten()):
            self.assertAlmostEqual(a, b)


if __name__ == '__main__':
    # Run as a script: load and run the TestCase classes defined in this
    # module.  ``module='__main__'`` is unittest.main's default, written out
    # here only to make the target explicit.
    unittest.main(module='__main__')