Mercurial > hg > Lemuriformes
diff lemuriformes/sql.py @ 15:0d1b8bb1d97b
SQL + data related functionality
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 10 Dec 2017 17:16:52 -0800 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/lemuriformes/sql.py Sun Dec 10 17:16:52 2017 -0800 @@ -0,0 +1,92 @@ +""" +abstract SQL functionality +""" + + +from abc import abstractmethod + + +class SQLConnection(object): + """abstract base class for SQL connection""" + + placeholder = '?' # VALUE placeholder + + @abstractmethod + def __call__(self, sql, *args, **kwargs): + """ + execute `sql` against connection cursor with values in `args`; + `kwargs` should be passed to the connection + """ + + @abstractmethod + def tables(self): + """return list of tables in the database""" + + @abstractmethod + def columns(self, table): + """return the columns in `table`""" + + @abstractmethod + def create_table(self, table, **columns): + """ + add a table to the database for the specific SQL type + """ + + @abstractmethod + def pytype2sql(self, pytype): + """return the SQL type for the python type `pytype`""" + + @abstractmethod + def sqltype2py(self, sqltype): + """return the python type for the SQL type `sqltype`""" + + def drop(self, table): + """drop `table` if exists""" + + if table in self.tables(): + return # nothing to do + + sql = "DROP TABLE {table}" + self(sql.format(table=table)) + + def placeholders(self, number): + """ + return placeholder string appropriate to INSERT SQL; + `number` should be an integer or an iterable with a `len` + """ + + try: + number = len(number) # iterable + except TypeError: + pass # assume integer + + return ','.join([self.placeholder for placeholder in range(number)]) + + def insert_sql(self, table, columns): + """return insert SQL statement""" + sql = "INSERT INTO `{table}` ({columns}) VALUES ({values})" + column_str = ', '.join(["`{}`".format(key) for key in columns]) + return sql.format(table=table, + columns=column_str, + values=self.placeholders(columns)) + + def insert_row(self, table, **row): + """insert one `row` into `table`""" + + columns = row.keys() + sql = "INSERT INTO {table} ({columns}) VALUES ({placeholders})".format(table=table, + placeholders=self.placeholders(columns), + columns=', '.join(columns)) + values = tuple([row[column] for column in columns]) + self(sql, *values) + + def count(self, table): + """count + return number of rows in `table`""" + # https://docs.microsoft.com/en-us/sql/t-sql/functions/count-transact-sql + + sql = "select count(*) from {table}".format(table=table) + data = self(sql) + assert len(data) == 1 + if len(data[0]) != 1: + raise AssertionError + return data[0][0]