diff tests/test_logistic_regression.py @ 31:fa7a51df0d90

[logistic regression] test gradient descent
author Jeff Hammel <k0scist@gmail.com>
date Mon, 04 Sep 2017 12:37:45 -0700
parents cf7584f0a29f
children 0f29b02f4806
line wrap: on
line diff
--- a/tests/test_logistic_regression.py	Mon Sep 04 12:04:58 2017 -0700
+++ b/tests/test_logistic_regression.py	Mon Sep 04 12:37:45 2017 -0700
@@ -12,6 +12,13 @@
 
 class LogisticRegresionTests(unittest.TestCase):
 
+    def compare_arrays(self, a, b):
+        assert a.shape == b.shape
+        for x, y in zip(a.flatten(),
+                        b.flatten()):
+            self.assertAlmostEqual(x, y)
+
+
     def test_cost(self):
         """test cost function"""
 
@@ -48,6 +55,28 @@
                            dw_expected.flatten()):
             self.assertAlmostEqual(a, b)
 
+    def test_optimize(self):
+        """test gradient descent method"""
+
+        # test examples
+        w, b, X, Y = np.array([[1],[2]]), 2, np.array([[1,2],[3,4]]), np.array([[1,0]])
+
+        params, grads, costs = logistic_regression.optimize(w, b, X, Y, num_iterations= 100, learning_rate = 0.009, print_cost = False)
+
+        # expected output
+        w_expected = np.array([[0.1124579 ],
+                               [0.23106775]])
+        dw_expected = np.array([[ 0.90158428],
+                                [ 1.76250842]])
+        b_expected = 1.55930492484
+        db_expected = 0.430462071679
+
+        # compare output
+        self.assertAlmostEqual(params['b'], b_expected)
+        self.assertAlmostEqual(grads['db'], db_expected)
+        self.compare_arrays(w_expected, params['w'])
+        self.compare_arrays(dw_expected, grads['dw'])
+
 
 if __name__ == '__main__':
     unittest.main()