Mercurial > hg > tvii
annotate tests/test_logistic_regression.py @ 28:77f68c241b37
[logistic regression] propagate
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Mon, 04 Sep 2017 11:53:23 -0700 |
parents | f34110e28a0a |
children | cf7584f0a29f |
rev | line source |
---|---|
11 | 1 #!/usr/bin/env python |
2 | |
3 """ | |
4 test logistic regression | |
5 """ | |
6 | |
22
3713c6733990
[logistic regression] introduce illustrative test
Jeff Hammel <k0scist@gmail.com>
parents:
11
diff
changeset
|
7 import numpy as np |
11 | 8 import os |
9 import unittest | |
10 from tvii import logistic_regression | |
11 | |
28
77f68c241b37
[logistic regression] propagate
Jeff Hammel <k0scist@gmail.com>
parents:
23
diff
changeset
|
12 |
class LogisticRegresionTests(unittest.TestCase):
    """tests for tvii.logistic_regression on a small canned example"""

    # absolute tolerance for every floating-point comparison below;
    # the same 1e-6 bound the original cost check used
    TOLERANCE = 1e-6

    @staticmethod
    def _sample_data():
        """return the canned (w, b, X, Y) example shared by the tests"""
        w = np.array([[1], [2]])
        b = 2
        X = np.array([[1, 2], [3, 4]])
        Y = np.array([[1, 0]])
        return w, b, X, Y

    def test_cost(self):
        """test cost function"""

        w, b, X, Y = self._sample_data()

        expected_cost = 6.000064773192205
        cost = logistic_regression.cost_function(w, b, X, Y)

        # assert_allclose accepts a Python float, numpy scalar or 1-element
        # array, and (unlike a bare ``assert``) is not stripped under -O
        np.testing.assert_allclose(cost, expected_cost,
                                   rtol=0, atol=self.TOLERANCE)

    def test_propagate(self):
        """test canned logistic regression example"""

        # sample variables
        w, b, X, Y = self._sample_data()

        # calculate gradient and cost
        grads, cost = logistic_regression.propagate(w, b, X, Y)

        # expected values for the canned example
        dw_expected = [[0.99993216], [1.99980262]]
        db_expected = 0.499935230625
        cost_expected = 6.000064773192205

        # the expected values were previously computed but never compared,
        # so this test could not fail; check all three outputs
        np.testing.assert_allclose(cost, cost_expected,
                                   rtol=0, atol=self.TOLERANCE)

        # NOTE(review): assumes ``grads`` is a mapping keyed 'dw' and 'db'
        # (the convention of the canned example reproduced here) --
        # confirm against tvii.logistic_regression.propagate
        np.testing.assert_allclose(grads['dw'], dw_expected,
                                   rtol=0, atol=self.TOLERANCE)
        np.testing.assert_allclose(grads['db'], db_expected,
                                   rtol=0, atol=self.TOLERANCE)


if __name__ == '__main__':
    unittest.main()