92
|
1 """
|
|
2 test linear regression
|
|
3 """
|
|
4
|
|
5 import csv
|
|
6 import math
|
|
7 import os
|
|
8 import numpy as np
|
|
9 import random
|
|
10 from tvii import linear_regression
|
|
11 from tvii.noise import add_noise
|
|
12
|
|
13
|
|
14 def test_linear_regression():
|
|
15 """Make sure we can do `W*x + b = y` properly"""
|
|
16
|
|
17 # training data: exact fit, W=-1, b=1
|
|
18 x_train = [1,2,3,4]
|
|
19 y_train = [0,-1,-2,-3]
|
|
20
|
|
21 # our guesses
|
|
22 W_guess = 0. # Why not? Be bold
|
|
23 b_guess = 0.
|
|
24
|
|
25 # perform the regression
|
|
26 W, b, loss = linear_regression.linear_regression(x_train,
|
|
27 y_train,
|
|
28 W_guess=W_guess,
|
|
29 b_guess=b_guess)
|
|
30 # make sure we're close
|
|
31 W_exact = -1.
|
|
32 b_exact = 1.
|
|
33
|
|
34 assert abs(W - W_exact) < 1e-5
|
|
35 assert abs(b - b_exact) < 1e-5
|
|
36
|
|
37
|
|
38 def test_linear_regression_noisy():
|
|
39 """
|
|
40 Make sure we can do `W*x + b = y` with some noise
|
|
41 """
|
|
42
|
|
43 # start simple
|
|
44 slope = 1.5 # rises 3 every 2
|
|
45 intercept = random.random() * 5.
|
|
46 line = lambda x: slope*x + intercept
|
|
47
|
|
48 # make range
|
|
49 # TODO: np.linspace(-10., 10, 100)
|
|
50 xspan = (-10., 10.)
|
|
51 npoints = 100
|
|
52 dx = (xspan[-1] - xspan[0])/(npoints-1.)
|
|
53 xi = [xspan[0]+dx*i
|
|
54 for i in range(npoints)]
|
|
55
|
|
56 # add some noise to it
|
|
57 x = add_noise(xi, fraction=0.01)
|
|
58 assert len(x) == len(xi)
|
|
59 assert x != xi
|
|
60 assert x == sorted(x)
|
|
61
|
|
62 # calculate true y
|
|
63 truey = [line(xx) for xx in x]
|
|
64
|
|
65 # add some noise to that
|
|
66 y = add_noise(truey, fraction=0.01)
|
|
67 assert len(y) == len(truey)
|
|
68
|
|
69 # you're now all set up for your regression
|
|
70 W, b, loss = linear_regression.linear_regression(x, y)
|
|
71
|
|
72 # Show us what you got!
|
|
73 # TODO: this gives nan for both `W` and `b`
|
|
74 # The lines loop okay so I'm guessing some sort of
|
|
75 # numerical instability
|
|
76 try:
|
|
77 assert W == slope # XXX shouldn't be exactly equal anyway
|
|
78 except AssertionError:
|
|
79 dumpfile = os.environ.get('NETTWERK_FAILURE')
|
|
80 if dumpfile:
|
|
81 # dump the points
|
|
82 with open(dumpfile, 'w') as f:
|
|
83 writer = csv.writer(f)
|
|
84 writer.writerows(zip(x, y))
|
|
85 pass # XXX ignoring true negative :(
|