view tvii/kmeans.py @ 87:9d5a5e9f5c3b

add kmeans + dataset
author Jeff Hammel <k0scist@gmail.com>
date Sun, 17 Dec 2017 14:05:57 -0800
parents
children
line wrap: on
line source

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
K-means unsupervised learning algorithm
"""

import csv
import os
import random
import sys
from .centroid import centroid
from .cli import CLIParser
from .distance import distance
from .read import read


def kmeans(x, k):
    """
    applies K-means algorithm to data set `x`
    to determine `k` classes of the problem

    Arguments:
    - `x`: sequence of points; each point is handed unchanged to
      `distance` and `centroid`
    - `k`: number of classes to find; at least 1 and at most `len(x)`
      (`k == 0` is accepted only when `x` is empty)

    Returns a `(classes, centroids)` tuple:  `classes` is a list of `k`
    lists that together hold every point of `x`, and `centroids[i]` is
    the centroid belonging to `classes[i]`.

    Raises `ValueError` if `k` classes cannot be formed from `x`.

    The starting centroids are drawn with `random.sample`, so the result
    can differ between runs unless the `random` module is seeded.
    """

    # validate explicitly rather than with `assert`, which is stripped
    # when python runs with -O
    npoints = len(x)
    if k < 0 or k > npoints or (k == 0 and npoints):
        raise ValueError(
            "cannot form %s classes from %s points" % (k, npoints))

    # initialization:
    # pick `k` arbitrary points of `x` as the starting centroids
    centroids = random.sample(x, k)
    oldcentroids = None

    # convergence: stop once an update leaves every centroid unchanged
    while centroids != oldcentroids:

        # - divide `x` into `k` classes based on distance;
        #   ties go to the lowest-numbered centroid
        classes = [[] for _ in range(k)]
        for point in x:
            distances = [distance(point, c) for c in centroids]
            closest = min(range(k), key=distances.__getitem__)
            classes[closest].append(point)

        # - move centroids to the center of the points;
        #   a class can end up empty (e.g. when duplicate points were
        #   picked as starting centroids), in which case its centroid
        #   stays where it was instead of calling `centroid` with no points
        oldcentroids = centroids
        centroids = [centroid(*pts) if pts else previous
                     for pts, previous in zip(classes, oldcentroids)]

    return (classes, centroids)



def main(args=sys.argv[1:]):
    """
    command line entry point:
    run K-means on the given points and write the centroids
    to stdout as CSV, one row per centroid
    """

    # build the command line interface and parse `args`
    cli = CLIParser(description=__doc__)
    cli.add_argument('points', type=read, help="points to consider")
    cli.add_argument('--k', dest='k', type=int, default=2,
                     help="number of classes to discern [DEFAULT: %(default)s]")
    parsed = cli.parse_args(args)

    # cluster the points; only the centroids are reported for now
    # TODO:  if an output flag is specified then output the different classes
    _, centers = kmeans(parsed.points, parsed.k)

    # emit the centroids
    csv.writer(sys.stdout).writerows(centers)

# run the command line interface only when executed as a script,
# not when this module is imported
if __name__ == '__main__':
    main()