view lemuriformes/sql.py @ 18:56596902e9ae default tip

add some setup + tests
author Jeff Hammel <k0scist@gmail.com>
date Sun, 10 Dec 2017 17:57:03 -0800
parents 0d1b8bb1d97b
children
line wrap: on
line source

"""
abstract SQL functionality
"""


from abc import abstractmethod


class SQLConnection(object):
    """
    abstract base class for SQL connection

    Subclasses supply the database-specific pieces (`__call__`, `tables`,
    `columns`, `create_table`, and the type converters); the concrete
    helpers defined here (`drop`, `placeholders`, `insert_sql`,
    `insert_row`, `count`) are built on top of them.

    NOTE: this class does not use `abc.ABCMeta`, so `@abstractmethod` is
    not enforced -- a subclass that leaves one of these methods
    unimplemented can still be instantiated, and the method returns None.

    NOTE: table and column names are interpolated directly into the SQL
    text (identifiers cannot be bound as parameters); only pass trusted
    names.  Row *values* are sent separately, bound to `placeholder`s.
    """

    # VALUE placeholder used by `placeholders`; override in a subclass whose
    # driver expects a different parameter marker
    placeholder = '?'

    @abstractmethod
    def __call__(self, sql, *args, **kwargs):
        """
        execute `sql` against connection cursor with values in `args`;
        `kwargs` should be passed to the connection

        `count` indexes the return value as a sequence of rows, each a
        sequence of column values.
        """

    @abstractmethod
    def tables(self):
        """return list of tables in the database"""

    @abstractmethod
    def columns(self, table):
        """return the columns in `table`"""

    @abstractmethod
    def create_table(self, table, **columns):
        """
        add a table to the database for the specific SQL type
        """

    @abstractmethod
    def pytype2sql(self, pytype):
        """return the SQL type for the python type `pytype`"""

    @abstractmethod
    def sqltype2py(self, sqltype):
        """return the python type for the SQL type `sqltype`"""

    def drop(self, table):
        """drop `table` if it exists; do nothing if it does not"""

        if table not in self.tables():
            return  # nothing to do: the table is absent

        sql = "DROP TABLE {table}"
        self(sql.format(table=table))

    def placeholders(self, number):
        """
        return placeholder string appropriate to INSERT SQL;
        `number` should be an integer or an iterable with a `len`

        With the default placeholder, ``placeholders(3)`` gives ``'?,?,?'``.
        """

        try:
            number = len(number)  # sized container, e.g. a list of columns
        except TypeError:
            pass  # assume integer

        return ','.join([self.placeholder] * number)

    def insert_sql(self, table, columns):
        """
        return an INSERT statement for `table` with one placeholder per
        entry of `columns`; table and column names are backtick-quoted
        """
        sql = "INSERT INTO `{table}` ({columns}) VALUES ({values})"
        column_str = ', '.join(["`{}`".format(key) for key in columns])
        return sql.format(table=table,
                          columns=column_str,
                          values=self.placeholders(columns))

    def insert_row(self, table, **row):
        """
        insert one `row` into `table`

        Keyword names are the column names and keyword values the row
        values.  Unlike `insert_sql`, names are not quoted here.
        """

        # freeze one column order so the SQL text and the values agree
        columns = list(row.keys())
        sql = "INSERT INTO {table} ({columns}) VALUES ({placeholders})".format(
            table=table,
            placeholders=self.placeholders(columns),
            columns=', '.join(columns))
        values = tuple(row[column] for column in columns)
        self(sql, *values)

    def count(self, table):
        """
        count + return number of rows in `table`

        Raises AssertionError if the query result is not exactly one row
        containing exactly one column.
        """
        # https://docs.microsoft.com/en-us/sql/t-sql/functions/count-transact-sql

        sql = "select count(*) from {table}".format(table=table)
        data = self(sql)
        # explicit raises instead of `assert`, which is stripped under `python -O`
        if len(data) != 1:
            raise AssertionError(
                "expected 1 row from count(*), got {}".format(len(data)))
        if len(data[0]) != 1:
            raise AssertionError(
                "expected 1 column from count(*), got {}".format(len(data[0])))
        return data[0][0]