# HG changeset patch # User Jeff Hammel # Date 1513549575 28800 # Node ID f1d1f2388fd69d0ba9b746f9a5d48c5d70413c91 # Parent d603ee579c3ebfd43364d585b25f7b14f03cce9f test linear regression diff -r d603ee579c3e -r f1d1f2388fd6 tests/test_linear_regression.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tests/test_linear_regression.py Sun Dec 17 14:26:15 2017 -0800 @@ -0,0 +1,85 @@ +""" +test linear regression +""" + +import csv +import math +import os +import numpy as np +import random +from tvii import linear_regression +from tvii.noise import add_noise + + +def test_linear_regression(): + """Make sure we can do `W*x + b = y` properly""" + + # training data: exact fit, W=-1, b=1 + x_train = [1,2,3,4] + y_train = [0,-1,-2,-3] + + # our guesses + W_guess = 0. # Why not? Be bold + b_guess = 0. + + # perform the regression + W, b, loss = linear_regression.linear_regression(x_train, + y_train, + W_guess=W_guess, + b_guess=b_guess) + # make sure we're close + W_exact = -1. + b_exact = 1. + + assert abs(W - W_exact) < 1e-5 + assert abs(b - b_exact) < 1e-5 + + +def test_linear_regression_noisy(): + """ + Make sure we can do `W*x + b = y` with some noise + """ + + # start simple + slope = 1.5 # rises 3 every 2 + intercept = random.random() * 5. + line = lambda x: slope*x + intercept + + # make range + # TODO: np.linspace(-10., 10, 100) + xspan = (-10., 10.) + npoints = 100 + dx = (xspan[-1] - xspan[0])/(npoints-1.) + xi = [xspan[0]+dx*i + for i in range(npoints)] + + # add some noise to it + x = add_noise(xi, fraction=0.01) + assert len(x) == len(xi) + assert x != xi + assert x == sorted(x) + + # calculate true y + truey = [line(xx) for xx in x] + + # add some noise to that + y = add_noise(truey, fraction=0.01) + assert len(y) == len(truey) + + # you're now all set up for your regression + W, b, loss = linear_regression.linear_regression(x, y) + + # Show us what you got! + # TODO: this gives nan for both `W` and `b` + # The lines loop okay so I'm guessing some sort of + # numerical instability + try: + assert W == slope # XXX shouldn't be exactly equal anyway + except AssertionError: + dumpfile = os.environ.get('NETTWERK_FAILURE') + if dumpfile: + # dump the points + with open(dumpfile, 'w') as f: + writer = csv.writer(f) + writer.writerows(zip(x, y)) + pass # XXX ignoring true negative :(