comparison tests/test_kmeans.py @ 87:9d5a5e9f5c3b

add kmeans + dataset
author Jeff Hammel <k0scist@gmail.com>
date Sun, 17 Dec 2017 14:05:57 -0800
parents
children 596dac7f3e98
comparison
equal deleted inserted replaced
86:b56d329c238d 87:9d5a5e9f5c3b
1 #!/usr/bin/env python
2
3 """
4 tests K means algorithm
5 """
6
7 import unittest
8 from tvii import kmeans
9 from nettwerk.dataset.circle import CircularRandom
10
11
class TestKMeans(unittest.TestCase):
    """tests for the ``tvii.kmeans`` module and its command-line entry point"""

    def test_dualing_gaussians(self):
        """tests two gaussian distributions; first, cut overlap"""
        # TODO: generate two (dueling) gaussian clouds, first with their
        # overlap cut away, and check that kmeans separates them.
        # Skip explicitly so this unwritten test is reported as skipped
        # instead of silently counting as a pass.
        self.skipTest("TODO: dueling gaussians test not yet implemented")

    def test_circles(self):
        """test with two circles of points

        Two non-overlapping circles of radius 1, centred at (-1.5, 0) and
        (1.5, 0), are clustered with k=2.  The centroids should land near
        the circle centres and each class should hold exactly one circle.
        """

        # generate two non-overlapping circles
        # NOTE(review): points are presumably random and unseeded, so this
        # test may not be strictly reproducible -- confirm
        n_points = 10000  # per circle
        p1 = CircularRandom((-1.5, 0), 1)(n_points)
        p2 = CircularRandom((1.5, 0), 1)(n_points)

        # run kmeans on the combined points
        # (assumes CircularRandom returns lists, so ``+`` concatenates)
        classes, centroids = kmeans.kmeans(p1 + p2, 2)

        # sanity: one centroid and one class per requested cluster
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(classes), 2)

        # the centroids should have opposite x values near -1.5 and +1.5,
        # so their product should be close to -1.5 * 1.5 == -2.25
        xprod = centroids[0][0] * centroids[1][0]
        self.assertLess(xprod, 0.)
        self.assertLess(abs(xprod + 2.25), 0.1)

        # assert we're kinda close to the circle centres (+/-1.5, 0)
        for centroid in centroids:
            self.assertLess(abs(abs(centroid[0]) - 1.5), 0.1)
            self.assertLess(abs(centroid[1]), 0.1)

        # it's a pretty clean break; our points should be exact, most likely
        left, right = (0, 1) if centroids[0][0] < 0. else (1, 0)
        for points, index in ((p1, left), (p2, right)):
            expected = sorted(points)
            actual = sorted(classes[index])
            self.assertEqual(len(actual), len(expected),
                             "class %d has the wrong number of points" % index)
            # compare as a boolean rather than with assertEqual: a difflib
            # diff of two 10000-element lists is slow and unreadable
            self.assertTrue(actual == expected,
                            "class %d does not match its circle" % index)

    def test_help(self):
        """smoketest for CLI"""

        try:
            kmeans.main(['--help'])
        except SystemExit as exc:
            # printing help and exiting is expected, but the exit status
            # should signal success (0, or None for a bare sys.exit())
            self.assertIn(exc.code, (0, None))
62
63
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()