Mercurial > hg > tvii
view tvii/kmeans.py @ 87:9d5a5e9f5c3b
add kmeans + dataset
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 17 Dec 2017 14:05:57 -0800 |
parents | |
children |
line wrap: on
line source
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
K-means unsupervised learning algorithm
"""

import csv
import os
import random
import sys
from .centroid import centroid
from .cli import CLIParser
from .distance import distance
from .read import read


def kmeans(x, k, max_iterations=None):
    """
    applies K-means algorithm to data set `x` to determine `k` classes
    of the problem

    Lloyd's iteration: `k` members of `x` are drawn at random as the
    initial centroids, then two steps alternate until no centroid moves
    (centroid lists compared with `==`):

    1. assignment -- every point joins the class of its nearest centroid
       (by `.distance.distance`); ties go to the lowest-indexed centroid;
    2. update -- each centroid moves to `.centroid.centroid` of its class.

    At least one assign/update round is always performed.

    Parameters:
    - x: the points to classify.  The initial centroids are drawn with
      `random.sample`, so `x` should be a sequence such as a list.
    - k: number of classes; must not exceed `len(x)`
    - max_iterations: optional upper bound on the number of assign/update
      rounds; `None` (the default) iterates until convergence

    Returns `(classes, centroids)`: `classes[i]` is the list of points of
    `x` assigned in the last round to the centroid that became
    `centroids[i]`.

    Raises ValueError if `k` is larger than the number of points (a
    negative `k` is likewise rejected by `random.sample` with ValueError).
    """
    if k > len(x):
        raise ValueError("cannot determine %d classes from %d points"
                         % (k, len(x)))

    # initialization: pick `k` arbitrary (distinct-position) members of `x`
    centroids = random.sample(x, k)

    iteration = 0
    while True:
        # assignment step: divide `x` into `k` classes based on distance;
        # list.index(min(...)) picks the first minimum, so ties resolve to
        # the lowest-indexed centroid
        classes = [[] for _ in range(k)]
        for point in x:
            distances = [distance(point, c) for c in centroids]
            classes[distances.index(min(distances))].append(point)

        # update step: move centroids to the center of their points.
        # A class can come out empty (e.g. `x` contains duplicate points
        # and two copies were sampled as initial centroids); a center of
        # zero points is undefined, so that centroid stays where it was.
        oldcentroids = centroids
        centroids = [centroid(*members) if members else previous
                     for members, previous in zip(classes, oldcentroids)]
        iteration += 1

        if centroids == oldcentroids:
            break  # converged: no centroid moved this round
        if max_iterations is not None and iteration >= max_iterations:
            break  # caller-imposed cap reached before convergence

    return (classes, centroids)


def main(args=None):
    """
    CLI entry point: read points, run k-means on them, and write the
    resulting centroids to stdout as CSV, one centroid per row.

    `args` is the argument list to parse; when omitted, the current
    command line (`sys.argv[1:]`) is used.
    """
    if args is None:
        # read sys.argv at call time rather than freezing it at import time
        args = sys.argv[1:]

    # parse command line
    parser = CLIParser(description=__doc__)
    parser.add_argument('points', type=read,
                        help="points to consider")
    parser.add_argument('--k', dest='k', type=int, default=2,
                        help="number of classes to discern [DEFAULT: %(default)s]")
    options = parser.parse_args(args)

    # run kmeans
    classes, centroids = kmeans(options.points, options.k)

    # output centroids
    # TODO: if an output flag is specified then output the different classes
    writer = csv.writer(sys.stdout)
    for c in centroids:
        writer.writerow(c)


if __name__ == '__main__':
    main()