comparison tests/test_linear_regression.py @ 92:f1d1f2388fd6

test linear regression
author Jeff Hammel <k0scist@gmail.com>
date Sun, 17 Dec 2017 14:26:15 -0800
parents
children
comparison
equal deleted inserted replaced
91:d603ee579c3e 92:f1d1f2388fd6
1 """
2 test linear regression
3 """
4
5 import csv
6 import math
7 import os
8 import numpy as np
9 import random
10 from tvii import linear_regression
11 from tvii.noise import add_noise
12
13
14 def test_linear_regression():
15 """Make sure we can do `W*x + b = y` properly"""
16
17 # training data: exact fit, W=-1, b=1
18 x_train = [1,2,3,4]
19 y_train = [0,-1,-2,-3]
20
21 # our guesses
22 W_guess = 0. # Why not? Be bold
23 b_guess = 0.
24
25 # perform the regression
26 W, b, loss = linear_regression.linear_regression(x_train,
27 y_train,
28 W_guess=W_guess,
29 b_guess=b_guess)
30 # make sure we're close
31 W_exact = -1.
32 b_exact = 1.
33
34 assert abs(W - W_exact) < 1e-5
35 assert abs(b - b_exact) < 1e-5
36
37
38 def test_linear_regression_noisy():
39 """
40 Make sure we can do `W*x + b = y` with some noise
41 """
42
43 # start simple
44 slope = 1.5 # rises 3 every 2
45 intercept = random.random() * 5.
46 line = lambda x: slope*x + intercept
47
48 # make range
49 # TODO: np.linspace(-10., 10, 100)
50 xspan = (-10., 10.)
51 npoints = 100
52 dx = (xspan[-1] - xspan[0])/(npoints-1.)
53 xi = [xspan[0]+dx*i
54 for i in range(npoints)]
55
56 # add some noise to it
57 x = add_noise(xi, fraction=0.01)
58 assert len(x) == len(xi)
59 assert x != xi
60 assert x == sorted(x)
61
62 # calculate true y
63 truey = [line(xx) for xx in x]
64
65 # add some noise to that
66 y = add_noise(truey, fraction=0.01)
67 assert len(y) == len(truey)
68
69 # you're now all set up for your regression
70 W, b, loss = linear_regression.linear_regression(x, y)
71
72 # Show us what you got!
73 # TODO: this gives nan for both `W` and `b`
74 # The lines loop okay so I'm guessing some sort of
75 # numerical instability
76 try:
77 assert W == slope # XXX shouldn't be exactly equal anyway
78 except AssertionError:
79 dumpfile = os.environ.get('NETTWERK_FAILURE')
80 if dumpfile:
81 # dump the points
82 with open(dumpfile, 'w') as f:
83 writer = csv.writer(f)
84 writer.writerows(zip(x, y))
85 pass # XXX ignoring true negative :(