diff tests/test_logistic_regression.py @ 29:cf7584f0a29f

test linear regression
author Jeff Hammel <k0scist@gmail.com>
date Mon, 04 Sep 2017 12:01:57 -0700
parents 77f68c241b37
children fa7a51df0d90
line wrap: on
line diff
--- a/tests/test_logistic_regression.py	Mon Sep 04 11:53:23 2017 -0700
+++ b/tests/test_logistic_regression.py	Mon Sep 04 12:01:57 2017 -0700
@@ -37,10 +37,17 @@
         grads, cost = logistic_regression.propagate(w, b, X, Y)
 
         # compare to expected,
-        dw_expected = [[ 0.99993216], [ 1.99980262]]
+        dw_expected = np.array([[ 0.99993216], [ 1.99980262]])
         db_expected = 0.499935230625
         cost_expected = 6.000064773192205
 
+        self.assertAlmostEqual(cost_expected, cost)
+        self.assertAlmostEqual(grads['db'], db_expected)
+        assert grads['dw'].shape == dw_expected.shape
+        for a, b in zip(grads['dw'].flatten(),
+                           dw_expected.flatten()):
+            self.assertAlmostEqual(a, b)
+
 
 if __name__ == '__main__':
     unittest.main()