Mercurial > hg > tvii
diff tests/test_logistic_regression.py @ 28:77f68c241b37
[logistic regression] propagate
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Mon, 04 Sep 2017 11:53:23 -0700 |
parents | f34110e28a0a |
children | cf7584f0a29f |
line wrap: on
line diff
--- a/tests/test_logistic_regression.py Mon Sep 04 11:38:46 2017 -0700 +++ b/tests/test_logistic_regression.py Mon Sep 04 11:53:23 2017 -0700 @@ -9,6 +9,7 @@ import unittest from tvii import logistic_regression + class LogisticRegresionTests(unittest.TestCase): def test_cost(self): @@ -23,5 +24,23 @@ cost = logistic_regression.cost_function(w, b, X, Y) assert abs(cost - expected_cost) < 1e-6 + def test_propagate(self): + """test canned logistic regression example""" + + # sample variables + w = np.array([[1],[2]]) + b = 2 + X = np.array([[1,2],[3,4]]) + Y = np.array([[1,0]]) + + # calculate gradient and cost + grads, cost = logistic_regression.propagate(w, b, X, Y) + + # compare to expected, + dw_expected = [[ 0.99993216], [ 1.99980262]] + db_expected = 0.499935230625 + cost_expected = 6.000064773192205 + + if __name__ == '__main__': unittest.main()