view tests/test_grid.py @ 25:991bce6b6881 default tip

[knn] placeholder for planning session
author Jeff Hammel <k0scist@gmail.com>
date Sun, 17 Sep 2017 14:35:50 -0700
parents d1b99c695511
children
line wrap: on
line source

#!/usr/bin/env python

"""
test that we can grid a solution
"""

import os
import unittest
from common import datafile
from globalneighbors.grid import LatLonGrid
from globalneighbors.locations import locations
from globalneighbors.read import read_cities
from globalneighbors.read import read_city_list


class TestGrid(unittest.TestCase):
    """test gridding functionality: LatLonGrid construction, insertion,
    neighbor lookup, and bulk gridding of city locations from data files"""

    ### test functions

    def test_dimensions(self):
        """a 90 x 180 grid has 2 degree cells and 90 rows of 180 cells"""

        # make a 2 degree grid: 180 degrees of latitude over 90 rows and
        # 360 degrees of longitude over 180 columns
        grid = LatLonGrid(90, 180)
        assert grid.n == (90, 180)
        assert grid.d == (2., 2.)
        assert len(grid.grid) == 90
        for row in grid.grid:
            assert len(row) == 180

    def test_insertion(self):
        """an added item lands in the cell reported by index(), and nowhere else"""

        coord = (-23., 122.)
        grid = LatLonGrid(3, 4)
        grid.add(1234, *coord)
        i, j = grid.index(*coord)

        # on a 3 x 4 grid, (-23, 122) falls in the middle row and last column
        assert i == 1
        assert j == 3
        assert grid[(i, j)] == {1234}

        # the single insertion must not be duplicated into any other cell
        assert self.count_items(grid) == 1

    def test_sample(self):
        """grid every city in the small sample data file"""

        samplefile = datafile('sample.tsv')
        assert os.path.exists(samplefile)
        city_locations = locations(read_city_list(samplefile))
        self.grid_locations(city_locations)

    def test_10000(self):
        """test 10000 cities"""

        filename = datafile('10000cities.tsv')
        assert os.path.exists(filename)
        with open(filename) as f:
            city_locations = locations(read_cities(f))
        self.grid_locations(city_locations)

    def test_neighbors(self):
        """test grid neighbor indexing"""

        grid = LatLonGrid(9, 9)

        # (5, 5) is an interior cell, so all eight surrounding cells
        # are expected as neighbors
        # NOTE(review): edge/corner cells (and any longitude wrap-around)
        # are not exercised here
        neighbors = grid.neighbors(5, 5)
        expected = [(4, 4), (4, 5), (4, 6),
                    (5, 4), (5, 6),
                    (6, 4), (6, 5), (6, 6)]
        assert sorted(neighbors) == sorted(expected)

    ### generic (utility) functions

    def count_items(self, grid):
        """return the total number of items stored across every cell of `grid`"""

        return sum(len(grid[(i, j)])
                   for i in range(grid.n[0])
                   for j in range(grid.n[1]))

    def grid_locations(self, locations, n_lat=8, n_lon=8):
        """grid locations + test created grid

        locations -- mapping of geoid -> (lat, lon)
        n_lat, n_lon -- grid dimensions passed to LatLonGrid (default 8 x 8)

        Asserts that every geoid is stored in the cell that grid.index()
        reports for its coordinates, and that the grid holds exactly one
        entry per location (nothing lost, nothing duplicated).
        Returns the populated grid.
        """

        # create a grid
        grid = LatLonGrid(n_lat, n_lon)

        # add the items to it
        for geoid, (lat, lon) in locations.items():
            grid.add(geoid, lat, lon)

        # each item must be in the cell its coordinates index to
        for geoid, (lat, lon) in locations.items():
            i, j = grid.index(lat, lon)
            assert geoid in grid[(i, j)]

        # and the total across all cells must match the input exactly
        assert self.count_items(grid) == len(locations)

        return grid

if __name__ == '__main__':
    # run this module's tests when the file is executed directly as a script
    unittest.main()