87
|
1 #!/usr/bin/env python
|
|
2 # -*- coding: utf-8 -*-
|
|
3
|
|
4 """
|
|
5 K-means unsupervised learning algorithm
|
|
6 """
|
|
7
|
|
8 import csv
|
|
9 import os
|
|
10 import random
|
|
11 import sys
|
|
12 from .centroid import centroid
|
|
13 from .cli import CLIParser
|
|
14 from .distance import distance
|
|
15 from .read import read
|
|
16
|
|
17
|
|
def kmeans(x, k):
    """
    applies K-means algorithm to data set `x`
    to determine `k` classes of the problem

    Arguments:
    - x : sequence of points; every point is handed to `distance`
          (together with a centroid) and to `centroid` (unpacked
          together with the other points of its class)
    - k : number of classes to partition `x` into

    Returns a 2-tuple `(classes, centroids)`: `classes` is a list of
    `k` lists of points and `centroids` is the list of the `k`
    centroids, in matching order.

    Raises `ValueError` if `k` is larger than the number of points.
    """

    # initialization:
    # pick `k` arbitrary centroids
    # (explicit check instead of `assert`, which is stripped by `python -O`)
    if k > len(x):
        raise ValueError(
            "k (%s) must not exceed the number of points (%s)" % (k, len(x)))
    centroids = random.sample(x, k)
    oldcentroids = None

    # iterate until a pass leaves every centroid where it was
    while centroids != oldcentroids:

        # - divide `x` into `k` classes based on distance;
        #   ties go to the centroid with the lowest index
        classes = [[] for _ in range(k)]
        for point in x:
            distances = [distance(point, c) for c in centroids]
            closest = distances.index(min(distances))
            classes[closest].append(point)

        # - move centroids to the center of the points
        oldcentroids = centroids
        # a class can end up empty (e.g. duplicate points picked as
        # initial centroids); keep its previous centroid in that case.
        # NOTE(review): assumes `centroid` cannot be computed from zero
        # points -- confirm against the `centroid` module
        centroids = [centroid(*pts) if pts else previous
                     for pts, previous in zip(classes, oldcentroids)]

    return (classes, centroids)
|
|
46
|
|
47
|
|
48
|
|
def main(args=sys.argv[1:]):
    """command line entry point: cluster the given points, print centroids as CSV"""

    # build the command line interface and parse `args`
    cli = CLIParser(description=__doc__)
    cli.add_argument('points', type=read, help="points to consider")
    cli.add_argument('--k', dest='k', type=int, default=2,
                     help="number of classes to discern [DEFAULT: %(default)s]")
    parsed = cli.parse_args(args)

    # cluster the points; only the centroids are reported for now
    _, centers = kmeans(parsed.points, parsed.k)

    # emit one CSV row per centroid on stdout
    # TODO: if an output flag is specified then output the different classes
    csv.writer(sys.stdout).writerows(centers)
|
|
69
|
|
# run the command line interface when executed as a script
if __name__ == "__main__":
    main()
|