comparison tvii/kmeans.py @ 87:9d5a5e9f5c3b

add kmeans + dataset
author Jeff Hammel <k0scist@gmail.com>
date Sun, 17 Dec 2017 14:05:57 -0800
parents
children
comparison
equal deleted inserted replaced
86:b56d329c238d 87:9d5a5e9f5c3b
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 """
5 K-means unsupervised learning algorithm
6 """
7
8 import csv
9 import os
10 import random
11 import sys
12 from .centroid import centroid
13 from .cli import CLIParser
14 from .distance import distance
15 from .read import read
16
17
def kmeans(x, k, max_iterations=None):
    """
    applies K-means (Lloyd's algorithm) to data set `x`
    to determine `k` classes of the problem

    Parameters:
    - x: sequence of points (must support `len()` and `random.sample`)
    - k: number of classes; must satisfy 1 <= k <= len(x)
    - max_iterations: optional cap on the number of assign/update rounds;
      `None` (the default) iterates until the centroids stop changing

    Returns a 2-tuple `(classes, centroids)`:
    - `classes[i]` is the list of points assigned to centroid `i` in the
      final assignment round;
    - `centroids[i]` is the centroid of `classes[i]`, or the previous
      centroid when `classes[i]` is empty.

    Raises ValueError if `k` or `max_iterations` is out of range.
    """

    # validate explicitly: an `assert` would vanish under `python -O`
    if not 1 <= k <= len(x):
        raise ValueError("k must satisfy 1 <= k <= len(x); got k=%s, len(x)=%s"
                         % (k, len(x)))
    if max_iterations is not None and max_iterations < 1:
        raise ValueError("max_iterations must be >= 1 or None; got %s"
                         % (max_iterations,))

    # initialization:
    # pick `k` arbitrary data points as the starting centroids
    centroids = random.sample(x, k)
    oldcentroids = None
    rounds = 0

    # iterate until an update leaves every centroid where it was
    # (or until the optional `max_iterations` cap is reached)
    while centroids != oldcentroids:
        if max_iterations is not None and rounds >= max_iterations:
            break
        rounds += 1

        # - divide `x` into `k` classes based on distance;
        #   `min` returns the first index on ties, so assignment is stable
        classes = [[] for _ in range(k)]
        for point in x:
            closest = min(range(k), key=lambda i: distance(point, centroids[i]))
            classes[closest].append(point)

        # - move centroids to the center of their points; a class with no
        #   points keeps its previous centroid, because the centroid of an
        #   empty set is undefined
        oldcentroids = centroids
        centroids = [centroid(*pts) if pts else previous
                     for pts, previous in zip(classes, oldcentroids)]

    return (classes, centroids)
46
47
48
def main(args=None):
    """
    CLI entry point: cluster the points named on the command line with
    K-means and write the resulting centroids to stdout as CSV rows.

    `args` is the list of command-line arguments to parse; when omitted,
    `sys.argv[1:]` is read at call time (not frozen at import time).
    """

    # parse command line
    if args is None:
        args = sys.argv[1:]
    parser = CLIParser(description=__doc__)
    parser.add_argument('points', type=read,
                        help="points to consider")
    parser.add_argument('--k', dest='k',
                        type=int, default=2,
                        help="number of classes to discern [DEFAULT: %(default)s]")
    options = parser.parse_args(args)

    # reject an impossible class count as a usage error instead of
    # letting kmeans() fail with a traceback
    npoints = len(options.points)
    if not 1 <= options.k <= npoints:
        parser.error("--k must be between 1 and the number of points (%d); got %d"
                     % (npoints, options.k))

    # run kmeans
    classes, centroids = kmeans(options.points, options.k)

    # output centroids
    # TODO: if an output flag is specified then output the different classes
    writer = csv.writer(sys.stdout)
    for c in centroids:
        writer.writerow(c)


if __name__ == '__main__':
    main()