Mercurial > hg > tvii
comparison tests/test_kmeans.py @ 87:9d5a5e9f5c3b
add kmeans + dataset
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 17 Dec 2017 14:05:57 -0800 |
parents | |
children | 596dac7f3e98 |
comparison
equal
deleted
inserted
replaced
86:b56d329c238d | 87:9d5a5e9f5c3b |
---|---|
1 #!/usr/bin/env python | |
2 | |
3 """ | |
4 tests K means algorithm | |
5 """ | |
6 | |
7 import unittest | |
8 from tvii import kmeans | |
9 from nettwerk.dataset.circle import CircularRandom | |
10 | |
11 | |
class TestKMeans(unittest.TestCase):
    """Tests for ``tvii.kmeans.kmeans`` and the ``tvii.kmeans.main`` CLI."""

    def test_dualing_gaussians(self):
        """tests two gaussian distributions; first, cut overlap"""
        # TODO: not implemented yet.  Skip explicitly so the runner reports
        # this test as pending instead of as a (vacuous) pass.
        self.skipTest("TODO: two-gaussian k-means test not implemented")

    def test_circles(self):
        """test with two circles of points

        Two non-overlapping circles of random points are generated, centred
        at ``(-center_x, 0)`` and ``(center_x, 0)``.  k-means with k=2 should
        place one centroid near each centre and partition the points exactly
        by the circle they came from.
        """

        # Geometry of the data set.  With center_x > radius the circles are
        # disjoint (gap of 2 * (center_x - radius) along x).
        center_x = 1.5
        radius = 1
        tolerance = 0.1
        n_points = 10000  # per circle

        # generate two non-overlapping circles
        p1 = CircularRandom((-center_x, 0), radius)(n_points)
        p2 = CircularRandom((center_x, 0), radius)(n_points)

        # run kmeans on the concatenation of both point sets
        classes, centroids = kmeans.kmeans(p1 + p2, 2)

        # sanity
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(classes), 2)

        # The centroids should sit on opposite sides of the y axis at about
        # +/- center_x, so the product of their x values is ~ -center_x**2.
        xprod = centroids[0][0] * centroids[1][0]
        self.assertLess(xprod, 0.)
        self.assertLess(abs(xprod + center_x ** 2), tolerance)

        # assert each centroid is kinda close to (+/- center_x, 0)
        for centroid in centroids:
            self.assertLess(abs(abs(centroid[0]) - center_x), tolerance)
            self.assertLess(abs(centroid[1]), tolerance)

        # It's a pretty clean break, so the partition should (most likely)
        # match the generating circles exactly.  Pick which class index
        # corresponds to the left (negative-x) circle.
        left, right = (0, 1) if centroids[0][0] < 0. else (1, 0)
        self.assertEqual(sorted(p1), sorted(classes[left]))
        self.assertEqual(sorted(p2), sorted(classes[right]))

    def test_help(self):
        """smoketest for CLI: ``--help`` must not fail"""

        try:
            kmeans.main(['--help'])
        except SystemExit as exc:
            # Exiting after printing help is expected.  Only a successful
            # exit (code 0 or None) passes.  A nonzero code means the CLI
            # reported an error and must fail the test, not be swallowed.
            self.assertIn(exc.code, (0, None))
62 | |
63 | |
if __name__ == '__main__':
    # When this file is executed directly rather than collected by a test
    # runner, hand off to unittest's command-line entry point, which finds
    # the TestCase classes in this (the __main__) module and runs them.
    unittest.main(module='__main__')