Mercurial > hg > tvii
view tests/test_kmeans.py @ 90:3ff05538259c
mnist example
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 17 Dec 2017 14:23:35 -0800 |
parents | 596dac7f3e98 |
children |
line wrap: on
line source
#!/usr/bin/env python

"""
tests K means algorithm
"""

import unittest

from tvii import kmeans
from tvii.dataset.circle import CircularRandom


class TestKMeans(unittest.TestCase):
    """Tests for ``tvii.kmeans``.

    Checks use ``self.assert*`` methods rather than bare ``assert``
    statements so they are still evaluated under ``python -O``, which
    strips ``assert`` statements.
    """

    # x-offset of each circle's center; the circles sit at
    # (-CENTER_X, 0) and (+CENTER_X, 0)
    CENTER_X = 1.5
    # radius passed to CircularRandom for both circles
    RADIUS = 1
    # number of points drawn per circle
    N_POINTS = 10000
    # absolute tolerance for the centroid-location checks
    TOLERANCE = 0.1

    @unittest.skip("TODO: not yet implemented")
    def test_dualing_gaussians(self):
        """tests two gaussian distributions; first, cut overlap"""
        # TODO: implement; reported as skipped (rather than silently
        # passing) until then

    def test_circles(self):
        """test with two circles of points

        Two well-separated circles of points should be split cleanly into
        two clusters whose centroids lie near the circle centers.
        """
        # generate two non-overlapping circles
        # (centers 2 * CENTER_X apart, each with radius RADIUS)
        p1 = CircularRandom((-self.CENTER_X, 0), self.RADIUS)(self.N_POINTS)
        p2 = CircularRandom((self.CENTER_X, 0), self.RADIUS)(self.N_POINTS)

        # run kmeans, asking for two clusters
        classes, centroids = kmeans.kmeans(p1 + p2, 2)

        # sanity
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(classes), 2)

        # the centroids should have opposite x values, so the product of
        # their x coordinates should be close to -CENTER_X**2
        xprod = centroids[0][0] * centroids[1][0]
        self.assertLess(xprod, 0.)
        self.assertLess(abs(xprod + self.CENTER_X ** 2), self.TOLERANCE)

        # assert we're kinda close: each centroid near (+/-CENTER_X, 0)
        for centroid in centroids:
            self.assertLess(abs(abs(centroid[0]) - self.CENTER_X),
                            self.TOLERANCE)
            self.assertLess(abs(centroid[1]), self.TOLERANCE)

        # it's a pretty clean break; our points should be exact, most likely.
        # This test assumes classes[i] holds the points assigned to
        # centroids[i].
        left, right = (0, 1) if centroids[0][0] < 0. else (1, 0)
        self.assertEqual(sorted(p1), sorted(classes[left]))
        self.assertEqual(sorted(p2), sorted(classes[right]))

    def test_help(self):
        """smoketest for CLI: ``--help`` should exit cleanly"""
        # --help is expected to raise SystemExit; fail if main() returns
        # normally instead
        with self.assertRaises(SystemExit) as ctx:
            kmeans.main(['--help'])
        # a successful exit carries code 0 (or None, which sys.exit
        # treats as success)
        self.assertIn(ctx.exception.code, (None, 0))


if __name__ == '__main__':
    unittest.main()