Mercurial > hg > tvii
diff tests/test_kmeans.py @ 87:9d5a5e9f5c3b
add kmeans + dataset
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 17 Dec 2017 14:05:57 -0800 |
parents | |
children | 596dac7f3e98 |
line wrap: on
line diff
#!/usr/bin/env python

"""
tests K means algorithm
"""

import unittest
from tvii import kmeans
from nettwerk.dataset.circle import CircularRandom


class TestKMeans(unittest.TestCase):
    """Tests for `tvii.kmeans`: clustering of synthetic points and a CLI smoketest."""

    def test_dualing_gaussians(self):
        """tests two gaussian distributions; first, cut overlap"""
        # TODO

    def test_circles(self):
        """test with two circles of points"""

        # generate two non-overlapping circles
        # NOTE(review): presumably `CircularRandom(center, radius)(n)` yields
        # `n` points inside that circle -- confirm against nettwerk.
        n_points = 10000  # per circle
        p1 = CircularRandom((-1.5, 0), 1)(n_points)
        p2 = CircularRandom((1.5, 0), 1)(n_points)

        # run kmeans
        # NOTE(review): `p1+p2` is assumed to concatenate the two point
        # collections (i.e. they are lists) -- confirm.
        classes, centroids = kmeans.kmeans(p1+p2, 2)

        # sanity
        # `self.assert*` is used instead of bare `assert` so the checks are
        # not stripped under `python -O` and failures report the values
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(classes), 2)

        # the centroids should have opposite x values;
        # with centers at x = -1.5 and x = 1.5 the product is about -2.25
        xprod = centroids[0][0] * centroids[1][0]
        self.assertLess(xprod, 0.)
        self.assertLess(abs(xprod + 2.25), 0.1)

        # assert we're kinda close: each centroid should sit near (+/-1.5, 0)
        for centroid in centroids:
            magnitudes = [abs(coordinate) for coordinate in centroid]
            self.assertLess(abs(magnitudes[0] - 1.5), 0.1)
            self.assertLess(abs(magnitudes[1]), 0.1)

        # it's a pretty clean break; our points should be exact, most likely
        # work out which returned class is the left-hand circle (negative x)
        if centroids[0][0] < 0.:
            left, right = 0, 1
        else:
            left, right = 1, 0
        self.assertEqual(sorted(p1), sorted(classes[left]))
        self.assertEqual(sorted(p2), sorted(classes[right]))

    def test_help(self):
        """smoketest for CLI"""

        try:
            kmeans.main(['--help'])
        except SystemExit:
            # this is expected
            pass


if __name__ == '__main__':
    unittest.main()