87
|
1 #!/usr/bin/env python
|
|
2
|
|
3 """
|
|
4 tests K means algorithm
|
|
5 """
|
|
6
|
|
7 import unittest
|
|
8 from tvii import kmeans
|
88
|
9 from tvii.dataset.circle import CircularRandom
|
87
|
10
|
|
11
|
|
class TestKMeans(unittest.TestCase):
    """Tests for ``kmeans.kmeans`` on synthetic point sets, plus a CLI smoketest.

    Checks use ``unittest`` assertion methods rather than bare ``assert``
    statements: bare asserts are removed when Python runs with ``-O``, which
    would let every test pass without checking anything.
    """

    def test_dualing_gaussians(self):
        """tests two gaussian distributions; first, cut overlap"""
        # TODO

    def test_circles(self):
        """test with two circles of points"""

        # generate two non-overlapping circles
        # (presumably the arguments are center and radius -- confirm against
        # tvii.dataset.circle.CircularRandom)
        n_points = 10000  # per circle
        p1 = CircularRandom((-1.5, 0), 1)(n_points)
        p2 = CircularRandom((1.5, 0), 1)(n_points)

        # run kmeans
        classes, centroids = kmeans.kmeans(p1+p2, 2)

        # sanity
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(classes), 2)

        # the centroids should have opposite x values; with x coordinates
        # near -1.5 and +1.5 their product should be near -2.25
        xprod = centroids[0][0] * centroids[1][0]
        self.assertLess(xprod, 0.)
        self.assertLess(abs(xprod + 2.25), 0.1)

        # assert we're kinda close: each centroid should sit near (+/-1.5, 0)
        for c in centroids:
            c = [abs(i) for i in c]
            self.assertLess(abs(c[0] - 1.5), 0.1)
            self.assertLess(abs(c[1]), 0.1)

        # it's a pretty clean break; our points should be exact, most likely.
        # kmeans does not promise which cluster comes first, so pick the
        # left/right class indices from the sign of the first centroid's x.
        left, right = (0, 1) if centroids[0][0] < 0. else (1, 0)
        self.assertEqual(sorted(p1), sorted(classes[left]))
        self.assertEqual(sorted(p2), sorted(classes[right]))

    def test_help(self):
        """smoketest for CLI"""

        try:
            kmeans.main(['--help'])
        except SystemExit:
            # this is expected; any other exception propagates and fails
            # the test
            pass
|
|
62
|
|
63
|
|
if __name__ == '__main__':
    # Executed directly as a script: hand control to the unittest runner.
    unittest.main()
|