Mercurial > hg > tvii
comparison tvii/kmeans.py @ 87:9d5a5e9f5c3b
add kmeans + dataset
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 17 Dec 2017 14:05:57 -0800 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
86:b56d329c238d | 87:9d5a5e9f5c3b |
---|---|
1 #!/usr/bin/env python | |
2 # -*- coding: utf-8 -*- | |
3 | |
4 """ | |
5 K-means unsupervised learning algorithm | |
6 """ | |
7 | |
8 import csv | |
9 import os | |
10 import random | |
11 import sys | |
12 from .centroid import centroid | |
13 from .cli import CLIParser | |
14 from .distance import distance | |
15 from .read import read | |
16 | |
17 | |
def kmeans(x, k, max_iterations=None):
    """
    applies K-means algorithm to data set `x`
    to determine `k` classes of the problem

    Lloyd-style iteration: start from `k` points of `x` chosen at random
    as centroids, then repeatedly
      - assign every point to its nearest centroid (via `distance`), and
      - move each centroid to the center of its assigned points
        (via `centroid`),
    until the centroids stop changing or `max_iterations` rounds have run.

    Parameters:
    - x: iterable of points; each point must be acceptable to the
      project's `distance` and `centroid` helpers
    - k: number of classes; must satisfy 1 <= k <= number of points
    - max_iterations: optional upper bound on the number of
      assign/update rounds (must be >= 1); None (the default) iterates
      until the centroids are unchanged between rounds

    Returns `(classes, centroids)`:
    - `classes[i]` is the list of points from the last assignment round
      that were closest to centroid `i`
    - `centroids[i]` is the center of `classes[i]`; if a class came out
      empty, its previous centroid is kept

    Raises ValueError if `k` or `max_iterations` is out of range.
    """

    # materialize once: we need len(), random.sample() (which requires a
    # sequence) and repeated iteration over the points
    points = list(x)
    if not 1 <= k <= len(points):
        raise ValueError("k must be between 1 and the number of points (%d); got %r"
                         % (len(points), k))
    if max_iterations is not None and max_iterations < 1:
        raise ValueError("max_iterations must be >= 1 or None; got %r"
                         % (max_iterations,))

    # initialization: pick `k` arbitrary points as the starting centroids
    centroids = random.sample(points, k)

    iterations = 0
    while True:
        # - divide the points into `k` classes based on distance;
        #   on ties the lowest-indexed centroid wins
        classes = [[] for _ in range(k)]
        for point in points:
            distances = [distance(point, c) for c in centroids]
            classes[distances.index(min(distances))].append(point)

        # - move centroids to the center of their points; an empty class
        #   has no center, so it keeps its previous centroid instead of
        #   calling `centroid` with no points
        oldcentroids = centroids
        centroids = [centroid(*pts) if pts else previous
                     for pts, previous in zip(classes, oldcentroids)]
        iterations += 1

        if centroids == oldcentroids:
            break  # converged: another round would change nothing
        if max_iterations is not None and iterations >= max_iterations:
            break  # caller-imposed bound, guards against oscillation

    return (classes, centroids)
46 | |
47 | |
48 | |
def main(args=None):
    """CLI: cluster the given points and write the centroids to stdout as CSV.

    `args` is the argument list to parse. When it is None, `sys.argv[1:]` is
    read at call time, not captured once at import time as a default value
    would be.
    """
    if args is None:
        args = sys.argv[1:]

    # parse command line
    parser = CLIParser(description=__doc__)
    parser.add_argument('points', type=read,
                        help="points to consider")
    parser.add_argument('--k', dest='k',
                        type=int, default=2,
                        help="number of classes to discern [DEFAULT: %(default)s]")
    options = parser.parse_args(args)

    # report an impossible class count as a usage error rather than a
    # traceback from inside kmeans()
    # NOTE(review): assumes CLIParser derives from argparse.ArgumentParser
    # (it uses argparse's %(default)s help formatting), so .error() prints
    # usage and exits
    npoints = len(options.points)
    if not 1 <= options.k <= npoints:
        parser.error("--k must be between 1 and the number of points (%d); got %d"
                     % (npoints, options.k))

    # run kmeans
    _, centroids = kmeans(options.points, options.k)

    # output centroids, one CSV row each
    # TODO: if an output flag is specified then output the different classes
    writer = csv.writer(sys.stdout)
    writer.writerows(centroids)


if __name__ == '__main__':
    main()