Mercurial > hg > numerics
annotate numerics/interpolation.py @ 120:4e0c6887604e
cleanup and note dependencies
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 15 Mar 2015 20:58:15 -0700 |
parents | 7dd1b18c9f78 |
children |
rev | line source |
---|---|
4 | 1 #!/usr/bin/env python |
2 | |
3 """ | |
4 interpolation | |
5 """ | |
6 | |
66
7dd1b18c9f78
minor cleanup and better error
Jeff Hammel <k0scist@gmail.com>
parents:
37
diff
changeset
|
7 # imports |
36 | 8 import argparse |
37 | 9 import sys |
36 | 10 |
# public names exported by `import *`
__all__ = ['neighbors', 'linear_interpolation', 'InterpolateParser']
4 | 12 |
def neighbors(start, finish):
    """
    returns the neighbors in finish from start

    For every value `x` in `start`, find the pair of indices into
    `finish` that bracket it.

    start -- iterable of values to locate (any order)
    finish -- non-empty sequence of values, sorted ascending

    Returns a list with one `(left, right)` tuple per value of `start`,
    in the same order as `start`:

    - `(None, 0)` if `x` is below `finish[0]`
    - `(len(finish)-1, None)` if `x` is above `finish[-1]`
    - `(i, i+1)` with `finish[i] <= x <= finish[i+1]` otherwise; a value
      equal to an interior element of `finish` is paired with the
      interval on its left
    - `(0, 0)` if `finish` has a single element and `x` equals it

    Raises AssertionError if `finish` is empty.
    """
    from bisect import bisect_left

    # raised explicitly instead of via `assert` so the check is not
    # stripped under `python -O`; the exception type is unchanged
    if not finish:
        raise AssertionError("finish must not be empty")

    last = len(finish) - 1
    retval = []
    for x in start:
        # endpoints
        if x < finish[0]:
            retval.append((None, 0))
            continue
        if x > finish[last]:
            retval.append((last, None))
            continue

        # binary search; each lookup is independent of the previous one,
        # so the result does not depend on `start` being sorted
        left = max(bisect_left(finish, x) - 1, 0)
        # a single-element `finish` has no interval: both neighbors are
        # that one element (never a negative index)
        right = min(left + 1, last)
        retval.append((left, right))

    return retval
42 | |
43 | |
def linear_interpolation(data, points):
    """
    linearly interpolate data to points

    data -- iterable of 2-tuples (or equivalent) of `x,y`
    points -- `x`-values to interpolate to

    Both inputs are sorted ascending before use (data by `x`); the
    returned list holds one interpolated `y` per point, in ascending
    order of the points.

    Raises AssertionError if any point lies outside the `x`-range of
    the data: this is interpolation, not extrapolation.
    """

    # work on sorted copies
    data = sorted(data, key=lambda pair: pair[0])
    points = sorted(points)

    # bracket every point between two neighboring data `x` values
    abscissae = [pair[0] for pair in data]
    brackets = neighbors(points, abscissae)

    # a `None` on either side marks a point beyond the data range
    for lower, upper in brackets:
        if lower is None or upper is None:
            raise AssertionError("Bad neighbors: {}".format(brackets))

    interpolated = []
    for point, (lower, upper) in zip(points, brackets):
        x_lower = data[lower][0]
        x_upper = data[upper][0]
        # fraction of the way from the lower to the upper neighbor
        ratio = (point - x_lower)/float(x_upper - x_lower)
        interpolated.append(ratio*data[upper][1] + (1.-ratio)*data[lower][1])
    return interpolated
71 | |
class InterpolateParser(argparse.ArgumentParser):
    """
    CLI option parser

    Arguments:

    input -- optional positional; file opened for reading, stdin by default
    -o, --output -- file opened for writing, stdout by default
    --points, --print-points -- flag stored as `print_points`

    The most recently parsed options are kept on `self.options`
    (`None` until `parse_args` has been called).
    """

    def __init__(self, **kwargs):
        # fall back to the module docstring for the description
        kwargs.setdefault('description', __doc__)
        argparse.ArgumentParser.__init__(self, **kwargs)
        self.add_argument('input', nargs='?',
                          type=argparse.FileType('r'), default=sys.stdin,
                          help='input file, or read from stdin if omitted')
        self.add_argument('-o', '--output', dest='output',
                          type=argparse.FileType('w'), default=sys.stdout,
                          help="output file, or stdout if omitted")
        self.add_argument('--points', '--print-points', dest='print_points',
                          action='store_true', default=False,
                          help="print the points to interpolate to and exit")
        self.options = None

    def parse_args(self, *args, **kw):
        """parse the arguments, run `validate` on them and remember them"""
        options = argparse.ArgumentParser.parse_args(self, *args, **kw)
        self.validate(options)
        self.options = options
        return options

    def validate(self, options):
        """validate options"""
        # hook called by `parse_args`; currently performs no checks
97 | |
def main(args=sys.argv[1:]):
    """
    CLI

    Reads a numeric CSV whose first column is `x` and whose remaining
    columns are `y` series, linearly interpolates every series to the
    integer `x` values within the data range and writes the result as
    CSV (one row per point: the point, then one value per series).

    The points are the integers strictly greater than the smallest `x`
    up to and including the largest integer not above the largest `x`.
    With `--points` they are printed to stdout and nothing else is done.

    Raises AssertionError if the rows differ in length (or there are
    none) or if there are fewer than two columns.
    """
    # imported here because this function is their only user
    import csv
    import math

    # parse command line options
    parser = InterpolateParser()
    options = parser.parse_args(args)

    # read the CSV; the csv module yields `[]` for blank lines, skip them
    reader = csv.reader(options.input)
    data = [[float(col) for col in row] for row in reader if row]
    ncols = set([len(row) for row in data])
    # explicit raises (not `assert`) so the checks survive `python -O`
    if len(ncols) != 1:
        raise AssertionError(
            "expected rows of one common length; got lengths: {}".format(
                sorted(ncols)))
    ncols = ncols.pop()
    if ncols <= 1:
        raise AssertionError(
            "expected at least 2 columns; got {}".format(ncols))

    # get `x` values
    data = sorted(data, key=lambda x: x[0])
    x = [row[0] for row in data]
    # floor, not int(): int() truncates toward zero, which for negative
    # fractional values skips a valid point at the low end and produces
    # an out-of-range point at the high end
    xmin = int(math.floor(x[0])) + 1
    xmax = int(math.floor(x[-1]))
    points = range(xmin, xmax+1)
    if options.print_points:
        print ('\n'.join([str(point) for point in points]))
        return

    # make into x,y series
    series = [[(row[0], row[col]) for row in data]
              for col in range(1,ncols)]

    # interpolate
    interpolated = [linear_interpolation(s, points) for s in series]

    # output interpolated data
    writer = csv.writer(options.output)
    for row in zip(points, *interpolated):
        writer.writerow(row)
134 | |
# run the CLI only when executed as a script, not when imported
if __name__ == '__main__':
    main()