diff tests/test_kmeans.py @ 87:9d5a5e9f5c3b

add kmeans + dataset
author Jeff Hammel <k0scist@gmail.com>
date Sun, 17 Dec 2017 14:05:57 -0800
parents
children 596dac7f3e98
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/test_kmeans.py	Sun Dec 17 14:05:57 2017 -0800
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+
+"""
+tests K means algorithm
+"""
+
+import unittest
+from tvii import kmeans
+from nettwerk.dataset.circle import CircularRandom
+
+
+class TestKMeans(unittest.TestCase):
+
+    def test_dualing_gaussians(self):
+        """tests two gaussian distributions;  first, cut overlap"""
+        # TODO
+
+    def test_circles(self):
+        """test with two circles of points"""
+
+        # generate two non-overlapping circles
+        n_points = 10000   # per circle
+        p1 = CircularRandom((-1.5, 0), 1)(n_points)
+        p2 = CircularRandom((1.5, 0), 1)(n_points)
+
+        # run kmeans
+        classes, centroids = kmeans.kmeans(p1+p2, 2)
+
+        # sanity
+        assert len(centroids) == 2
+        assert len(classes) == 2
+
+        # the centroids should have opposite x values
+        xprod = centroids[0][0] * centroids[1][0]
+        assert xprod < 0.
+        assert abs(xprod + 2.25) < 0.1
+
+        # assert we're kinda close
+        for c in centroids:
+            c = [abs(i) for i in c]
+            assert abs(c[0]-1.5) < 0.1
+            assert abs(c[1]) < 0.1
+
+        # its a pretty clean break; our points should be exact, most likely
+        if centroids[0][0] < 0.:
+            left = 0
+            right = 1
+        else:
+            left = 1
+            right = 0
+        assert sorted(p1) == sorted(classes[left])
+        assert sorted(p2) == sorted(classes[right])
+
+    def test_help(self):
+        """smoketest for CLI"""
+
+        try:
+            kmeans.main(['--help'])
+        except SystemExit:
+            # this is expected
+            pass
+
+
+if __name__ == '__main__':
+    unittest.main()