view tests/test_kmeans.py @ 87:9d5a5e9f5c3b

add kmeans + dataset
author Jeff Hammel <k0scist@gmail.com>
date Sun, 17 Dec 2017 14:05:57 -0800
parents
children 596dac7f3e98
line wrap: on
line source

#!/usr/bin/env python

"""
tests K means algorithm
"""

import unittest
from tvii import kmeans
from nettwerk.dataset.circle import CircularRandom


class TestKMeans(unittest.TestCase):
    """Tests for ``tvii.kmeans`` on synthetic 2-D point sets."""

    def test_dualing_gaussians(self):
        """Test two Gaussian distributions; first, cut the overlap.

        Not implemented yet.  The test is reported as skipped so the gap in
        coverage is visible in test output instead of silently passing.
        """
        # TODO: implement (method name kept so existing test selectors work)
        self.skipTest("not yet implemented")

    def test_circles(self):
        """Test k-means with two non-overlapping circles of points.

        Two radius-1 circles are centred at (-1.5, 0) and (1.5, 0), leaving
        a gap around x=0.  k-means with k=2 should place the centroids near
        (+/-1.5, 0) and put each circle's points into a single class.
        """
        # generate two non-overlapping circles
        n_points = 10000  # per circle
        p1 = CircularRandom((-1.5, 0), 1)(n_points)
        p2 = CircularRandom((1.5, 0), 1)(n_points)

        # run kmeans on all points, asking for two clusters
        classes, centroids = kmeans.kmeans(p1 + p2, 2)

        # sanity: one centroid and one class per requested cluster
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(classes), 2)

        # the centroids should lie on opposite sides of x=0 ...
        xprod = centroids[0][0] * centroids[1][0]
        self.assertLess(xprod, 0.)
        # ... with an x product near -1.5 * 1.5 = -2.25
        self.assertLess(abs(xprod + 2.25), 0.1)

        # each centroid should be close to (+/-1.5, 0)
        for centroid in centroids:
            self.assertLess(abs(abs(centroid[0]) - 1.5), 0.1)
            self.assertLess(abs(centroid[1]), 0.1)

        # it's a pretty clean break, so the classes should reproduce the
        # generated circles exactly.  Assumes classes[i] holds the points
        # assigned to centroids[i]; the centroid with negative x is p1's.
        left, right = (0, 1) if centroids[0][0] < 0. else (1, 0)
        self.assertEqual(sorted(p1), sorted(classes[left]))
        self.assertEqual(sorted(p2), sorted(classes[right]))

    def test_help(self):
        """Smoketest for the CLI: ``--help`` must not fail."""
        try:
            kmeans.main(['--help'])
        except SystemExit as exc:
            # exiting after printing help is expected, but only a successful
            # status (0 or None) is acceptable; a non-zero code is an error
            self.assertIn(exc.code, (0, None))


if __name__ == "__main__":
    # when run as a script, discover and run every TestCase in this module
    unittest.main(module="__main__")