Mercurial > hg > Lemuriformes
view lemuriformes/sql.py @ 18:56596902e9ae default tip
add some setup + tests
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 10 Dec 2017 17:57:03 -0800 |
parents | 0d1b8bb1d97b |
children |
line wrap: on
line source
""" abstract SQL functionality """

from abc import abstractmethod


class SQLConnection(object):
    """
    abstract base class for SQL connection

    Subclasses provide the database-specific pieces (``__call__``,
    ``tables``, ``columns``, ``create_table``, ``pytype2sql``,
    ``sqltype2py``); the concrete helpers below (``drop``,
    ``placeholders``, ``insert_sql``, ``insert_row``, ``count``) are built
    on top of ``__call__`` and ``tables``.

    Table and column names are interpolated directly into the SQL text
    with ``str.format`` (only *values* are passed as parameters), so they
    must come from trusted input.

    NOTE(review): this class does not use ``abc.ABCMeta`` as its metaclass,
    so ``@abstractmethod`` is not enforced: an incomplete subclass can still
    be instantiated, and calling an unimplemented method runs the stub
    below, which silently returns ``None``.
    """

    placeholder = '?'  # VALUE placeholder used in parameterized statements

    @abstractmethod
    def __call__(self, sql, *args, **kwargs):
        """
        execute `sql` against connection cursor with values in `args`;
        `kwargs` should be passed to the connection
        """

    @abstractmethod
    def tables(self):
        """return list of tables in the database"""

    @abstractmethod
    def columns(self, table):
        """return the columns in `table`"""

    @abstractmethod
    def create_table(self, table, **columns):
        """
        add a table to the database for the specific SQL type
        """

    @abstractmethod
    def pytype2sql(self, pytype):
        """return the SQL type for the python type `pytype`"""

    @abstractmethod
    def sqltype2py(self, sqltype):
        """return the python type for the SQL type `sqltype`"""

    def drop(self, table):
        """
        drop `table` if exists

        The table list is consulted first; if `table` is not present this
        is a no-op and no DROP statement is issued.
        """
        if table not in self.tables():
            return  # nothing to do
        sql = "DROP TABLE {table}"
        self(sql.format(table=table))

    def placeholders(self, number):
        """
        return placeholder string appropriate to INSERT SQL;
        `number` should be an integer or an iterable with a `len`

        e.g. with the default placeholder, ``placeholders(3)`` and
        ``placeholders(['a', 'b', 'c'])`` both return ``'?,?,?'``.
        """
        try:
            number = len(number)  # sized iterable
        except TypeError:
            pass  # assume integer
        return ','.join([self.placeholder] * number)

    def insert_sql(self, table, columns):
        """
        return insert SQL statement

        `columns` may be any iterable of column names (including a
        generator); table and column identifiers are backtick-quoted and
        one VALUE placeholder is emitted per column.
        """
        # materialize once: `columns` is used both for the name list and
        # for the placeholder count, which would exhaust a generator
        columns = list(columns)
        sql = "INSERT INTO `{table}` ({columns}) VALUES ({values})"
        column_str = ', '.join(["`{}`".format(key) for key in columns])
        return sql.format(table=table,
                          columns=column_str,
                          values=self.placeholders(columns))

    def insert_row(self, table, **row):
        """
        insert one `row` into `table`

        Keyword names are the column names and keyword values are passed
        to ``__call__`` as parameters, in the same order as the columns.
        NOTE(review): unlike ``insert_sql``, identifiers here are not
        backtick-quoted.
        """
        columns = list(row)  # fixed order shared by the SQL and the values
        sql = "INSERT INTO {table} ({columns}) VALUES ({placeholders})".format(
            table=table,
            placeholders=self.placeholders(columns),
            columns=', '.join(columns))
        values = tuple([row[column] for column in columns])
        self(sql, *values)

    def count(self, table):
        """
        count + return number of rows in `table`

        Raises AssertionError if the query result is not exactly one row
        containing exactly one column.
        """
        # https://docs.microsoft.com/en-us/sql/t-sql/functions/count-transact-sql
        sql = "select count(*) from {table}".format(table=table)
        data = self(sql)
        # explicit raise (not `assert`) so the check survives `python -O`
        if len(data) != 1 or len(data[0]) != 1:
            raise AssertionError(
                "expected a single one-column row from {!r}, got {!r}".format(
                    sql, data))
        return data[0][0]