changeset 8:e3d6919130ca

insert via bisect (binary search on the sorted distance list)
author Jeff Hammel <k0scist@gmail.com>
date Sun, 25 Jun 2017 11:21:28 -0700
parents 254195d0bac2
children 638fad06e556
files globalneighbors/distance.py tests/test_bisect.py
diffstat 2 files changed, 49 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/globalneighbors/distance.py	Sun Jun 25 09:13:48 2017 -0700
+++ b/globalneighbors/distance.py	Sun Jun 25 11:21:28 2017 -0700
@@ -3,6 +3,7 @@
 """
 
 import argparse
+import bisect
 import json
 import sys
 import time
@@ -51,8 +52,19 @@
     else:
         distances.append((i, new_distance))
 
+class KeyWrapper:
+    def __init__(self, iterable, key):
+        self.it = iterable
+        self.key = key
+
+    def __getitem__(self, i):
+        return self.key(self.it[i])
+
+    def __len__(self):
+        return len(self.it)
 
 def insert_distance_bisect(distances, i, new_distance, k):
+
     if not distances:
         distances.append((i, new_distance))
         return
@@ -60,10 +72,13 @@
         if len(distances) < k:
             distances.append((i, new_distance))
         return
-    indices = [0, len(distances)]
-    while True:
-        midpoint = int((indices[-1] - indices[0])/2)
 
+    point = bisect.bisect_left(KeyWrapper(distances,
+                                          key=lambda x: x[-1]),
+                               new_distance)
+    distances.insert(point, (i, new_distance))
+    if len(distances) == k+1:
+        distances.pop()
 
 def calculate_distances(locations, r=Rearth):
     """
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/test_bisect.py	Sun Jun 25 11:21:28 2017 -0700
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+"""
+test bisection insert
+"""
+
+import random
+import unittest
+from globalneighbors import distance
+
+class TestBisectInsert(unittest.TestCase):
+
+    def test_bisect_insert(self):
+        """ensure our inserted points are in order"""
+
+        values = [(random.random(), i)
+                  for i in range(15000)]
+        for k in (10, 100, 1000, 10000):
+            _distances = []
+            for value, i in values:
+                distance.insert_distance_bisect(_distances,
+                                                i,
+                                                value,
+                                                k)
+            # every k is < 15000 (the number of values inserted), so the list must be capped at exactly k
+            assert len(_distances) == k
+            ordered = [value[-1] for value in _distances]
+            assert sorted(ordered) == ordered
+
+
+if __name__ == '__main__':
+    unittest.main()