Mercurial > hg > tvii
comparison tests/test_logistic_regression.py @ 31:fa7a51df0d90
[logistic regression] test gradient descent
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Mon, 04 Sep 2017 12:37:45 -0700 |
parents | cf7584f0a29f |
children | 0f29b02f4806 |
comparison
equal
deleted
inserted
replaced
30:ae0c345ea09d | 31:fa7a51df0d90 |
---|---|
9 import unittest | 9 import unittest |
10 from tvii import logistic_regression | 10 from tvii import logistic_regression |
11 | 11 |
12 | 12 |
13 class LogisticRegresionTests(unittest.TestCase): | 13 class LogisticRegresionTests(unittest.TestCase): |
14 | |
15 def compare_arrays(self, a, b): | |
16 assert a.shape == b.shape | |
17 for x, y in zip(a.flatten(), | |
18 b.flatten()): | |
19 self.assertAlmostEqual(x, y) | |
20 | |
14 | 21 |
15 def test_cost(self): | 22 def test_cost(self): |
16 """test cost function""" | 23 """test cost function""" |
17 | 24 |
18 w, b, X, Y = (np.array([[1],[2]]), | 25 w, b, X, Y = (np.array([[1],[2]]), |
46 assert grads['dw'].shape == dw_expected.shape | 53 assert grads['dw'].shape == dw_expected.shape |
47 for a, b in zip(grads['dw'].flatten(), | 54 for a, b in zip(grads['dw'].flatten(), |
48 dw_expected.flatten()): | 55 dw_expected.flatten()): |
49 self.assertAlmostEqual(a, b) | 56 self.assertAlmostEqual(a, b) |
50 | 57 |
58 def test_optimize(self): | |
59 """test gradient descent method""" | |
60 | |
61 # test examples | |
62 w, b, X, Y = np.array([[1],[2]]), 2, np.array([[1,2],[3,4]]), np.array([[1,0]]) | |
63 | |
64 params, grads, costs = logistic_regression.optimize(w, b, X, Y, num_iterations= 100, learning_rate = 0.009, print_cost = False) | |
65 | |
66 # expected output | |
67 w_expected = np.array([[0.1124579 ], | |
68 [0.23106775]]) | |
69 dw_expected = np.array([[ 0.90158428], | |
70 [ 1.76250842]]) | |
71 b_expected = 1.55930492484 | |
72 db_expected = 0.430462071679 | |
73 | |
74 # compare output | |
75 self.assertAlmostEqual(params['b'], b_expected) | |
76 self.assertAlmostEqual(grads['db'], db_expected) | |
77 self.compare_arrays(w_expected, params['w']) | |
78 self.compare_arrays(dw_expected, grads['dw']) | |
79 | |
51 | 80 |
52 if __name__ == '__main__': | 81 if __name__ == '__main__': |
53 unittest.main() | 82 unittest.main() |