diff lemuriformes/sql.py @ 15:0d1b8bb1d97b

SQL + data related functionality
author Jeff Hammel <k0scist@gmail.com>
date Sun, 10 Dec 2017 17:16:52 -0800
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/lemuriformes/sql.py	Sun Dec 10 17:16:52 2017 -0800
@@ -0,0 +1,92 @@
+"""
+abstract SQL functionality
+"""
+
+
+from abc import abstractmethod
+
+
+class SQLConnection(object):
+    """abstract base class for SQL connection"""
+
+    placeholder = '?'  # VALUE placeholder
+
+    @abstractmethod
+    def __call__(self, sql, *args, **kwargs):
+        """
+        execute `sql` against connection cursor with values in `args`;
+        `kwargs` should be passed to the connection
+        """
+
+    @abstractmethod
+    def tables(self):
+        """return list of tables in the database"""
+
+    @abstractmethod
+    def columns(self, table):
+        """return the columns in `table`"""
+
+    @abstractmethod
+    def create_table(self, table, **columns):
+        """
+        add a table to the database for the specific SQL type
+        """
+
+    @abstractmethod
+    def pytype2sql(self, pytype):
+        """return the SQL type for the python type `pytype`"""
+
+    @abstractmethod
+    def sqltype2py(self, sqltype):
+        """return the python type for the SQL type `sqltype`"""
+
+    def drop(self, table):
+        """drop `table` if exists"""
+
+        if table in self.tables():
+            return  # nothing to do
+
+        sql = "DROP TABLE {table}"
+        self(sql.format(table=table))
+
+    def placeholders(self, number):
+        """
+        return placeholder string appropriate to INSERT SQL;
+        `number` should be an integer or an iterable with a `len`
+        """
+
+        try:
+            number = len(number)  # iterable
+        except TypeError:
+            pass  # assume integer
+
+        return ','.join([self.placeholder for placeholder in range(number)])
+
+    def insert_sql(self, table, columns):
+        """return insert SQL statement"""
+        sql = "INSERT INTO `{table}` ({columns}) VALUES ({values})"
+        column_str = ', '.join(["`{}`".format(key) for key in columns])
+        return sql.format(table=table,
+                          columns=column_str,
+                          values=self.placeholders(columns))
+
+    def insert_row(self, table, **row):
+        """insert one `row` into `table`"""
+
+        columns = row.keys()
+        sql = "INSERT INTO {table} ({columns}) VALUES ({placeholders})".format(table=table,
+                                                                               placeholders=self.placeholders(columns),
+                                                                               columns=', '.join(columns))
+        values = tuple([row[column] for column in columns])
+        self(sql, *values)
+
+    def count(self, table):
+        """count + return number of rows in `table`"""
+        # https://docs.microsoft.com/en-us/sql/t-sql/functions/count-transact-sql
+
+        sql = "select count(*) from {table}".format(table=table)
+        data = self(sql)
+        assert len(data) == 1
+        if len(data[0]) != 1:
+            raise AssertionError
+        return data[0][0]