view globalneighbors/distance.py @ 3:49aae0c0293b

improved test coverage
author Jeff Hammel <k0scist@gmail.com>
date Sat, 24 Jun 2017 14:48:31 -0700
parents 1b94f3bf97e5
children 7e27e874655b
line wrap: on
line source

"""
distance functionality
"""

import argparse
import json
import sys
import time
from math import asin, sin, cos, sqrt, pi, fabs
from .cli import CitiesParser
from .constants import Rearth
from .locations import locations
from .read import read_cities
from .schema import fields

# multiply an angle in degrees by this factor to obtain radians
DEGREESTORADIANS = pi / 180.0


def haversine(lat1, lon1, lat2, lon2, r=1):
    """
    great-circle distance between two points on a sphere of radius `r`

    see
    https://en.wikipedia.org/wiki/Haversine_formula

    lat1, lon1 -- latitude and longitude of the first point, in radians
    lat2, lon2 -- latitude and longitude of the second point, in radians
    r -- radius of the sphere; the result is in the same units as `r`
    """
    # haversine of the central angle between the two points
    hav = (sin(0.5*(lat2-lat1))**2
           +cos(lat1)*cos(lat2)*(sin(0.5*(lon2-lon1))**2))

    # for (nearly) antipodal points rounding error can push the root
    # marginally above 1, which would make `asin` raise a domain error
    return 2*r*asin(min(1., sqrt(hav)))


def deg_to_rad(degrees):
    """convert an angle from degrees to radians"""
    radians = degrees*DEGREESTORADIANS
    return radians


def calculate_distances(locations, r=Rearth):
    """
    yield `((id_a, id_b), distance)` for every unordered pair of locations

    locations -- dict of `id: (lat, lon)`; the coordinates are converted
                 from degrees to radians before use
    r -- sphere radius handed on to `haversine`

    The two ids in each key are ordered so that the smaller one is first.

    WARNING! This is an N-squared approach
    """

    # convert all coordinates to radians once, up front
    in_radians = []
    for ident, latlon in locations.items():
        in_radians.append((ident,
                           tuple(deg_to_rad(value) for value in latlon)))

    # haversine distance for each pair, visiting every pair exactly once
    count = len(in_radians)
    for position, (id_a, (lat_a, lon_a)) in enumerate(in_radians):
        for other in range(position + 1, count):
            id_b, (lat_b, lon_b) = in_radians[other]
            if id_b > id_a:
                pair = (id_a, id_b)
            else:
                pair = (id_b, id_a)
            yield (pair,
                   haversine(lat_a, lon_a, lat_b, lon_b, r=r))


def calculate_neighbors(locations,
                        k=10,
                        lat_tol=1.,
                        lon_tol=1.,
                        output=None,
                        neighbors=None):
    """
    calculate `k` nearest neighbors for each location

    locations -- dict of `geoid: (lat, lon)`, coordinates in degrees
    k -- maximum number of neighbors kept per location
    lat_tol -- pairs whose latitudes differ by more than this many
               degrees are skipped without computing a distance
    lon_tol -- same as `lat_tol`, for longitude
    output -- if truthy, print a progress line every `output` locations:
              `index,remaining,total,seconds since previous line`
    neighbors -- optional existing result dict; it is updated in place

    Returns a dict of `geoid: [(neighbor_geoid, distance), ...]` where
    each list is sorted by ascending distance and holds at most `k`
    entries.  A location with no other location inside its lat/lon box
    gets no key at all.

    NOTE: the longitude filter is a plain difference; it does not wrap
    around at +/-180 degrees.
    """
    if neighbors is None:
        neighbors = {}
    if k == 0:
        # no neighbor can ever be stored
        return neighbors

    # real list copy: a dict view has no `.pop()`
    items = list(locations.items())
    index = 0
    n_items = len(items)
    start = int(time.time())
    while items:
        index += 1
        if output and not index % output:
            # output status counter
            now = int(time.time())
            duration = now - start
            start = now
            print('{},{},{},{}'.format(index,
                                       len(items),
                                       n_items,
                                       duration))

        # popping guarantees each pair is considered exactly once
        id1, (lat1, lon1) = items.pop()
        for id2, (lat2, lon2) in items:

            # filter out locations based on latlon boxing
            if fabs(lat2 - lat1) > lat_tol:
                continue
            if fabs(lon2 - lon1) > lon_tol:
                continue

            # determine distance
            args = [deg_to_rad(i) for i in
                    (lat1, lon1, lat2, lon2)]
            new_distance = haversine(*args, r=Rearth)

            # record each end of the pair as a neighbor of the other
            for geoid, neighbor in ((id1, id2), (id2, id1)):
                distances = neighbors.setdefault(geoid, [])
                if len(distances) == k and new_distance >= distances[-1][-1]:
                    # list is full and this is no closer than the farthest
                    continue

                # insert in order
                # TODO: Binary Search Tree
                for _index, (_, old_distance) in enumerate(distances):
                    if new_distance < old_distance:
                        distances.insert(_index, (neighbor, new_distance))
                        if len(distances) == k+1:
                            # keep only the `k` closest
                            distances.pop()
                        break
                else:
                    # farther than everything stored so far
                    distances.append((neighbor, new_distance))

    return neighbors


def main(args=sys.argv[1:]):
    """
    command line entry point

    Parses `args`, reads the cities, calculates the nearest neighbors of
    their locations and writes the result as JSON to the `output` file.
    """

    # build the command line interface
    parser = CitiesParser(description="""write nearest neighborfiles""")
    parser.add_argument('output', type=argparse.FileType('w'),
                        help="output file to dump JSON to")
    parser.add_argument('--counter', '--output-counter',
                        dest='output_counter',
                        type=int, default=100,
                        help="how often to output progress updates [DEFAULT: %(default)s]")
    parser.add_argument('-k', dest='k',
                        type=int, default=50,
                        help="number of neighbors to determine [DEFAULT: %(default)s]")
    opts = parser.parse_args(args)

    # read the cities; `opts.cities` is presumably supplied by
    # `CitiesParser` -- it is not declared here
    city_records = list(read_cities(opts.cities, fields=fields))

    # locations of those cities
    city_locations = locations(city_records)

    # nearest neighbors, with progress output every `output_counter` cities
    nearest = calculate_neighbors(city_locations,
                                  k=opts.k,
                                  output=opts.output_counter)

    # serialize, then write in a single call
    serialized = json.dumps(nearest, indent=2)
    opts.output.write(serialized)

# run the command line interface when executed rather than imported
if __name__ == "__main__":
    main()