diff schema/sql.py @ 0:f7edadebb1de

initial commit
author Jeff Hammel <jhammel@mozilla.com>
date Fri, 17 Feb 2012 12:20:51 -0800
parents
children 5baa23a8d32f
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/schema/sql.py	Fri Feb 17 12:20:51 2012 -0800
@@ -0,0 +1,79 @@
+import sqlite3
+import sys
+import tempfile
+
+class SQL(object):
+    converters = {int: 'INT',
+                  str: 'TEXT',
+                  unicode: 'TEXT'}
+
+    def __init__(self, database=None):
+        self.database = database or tempfile.mktemp()
+
+    def __call__(self, statement, *parameters):
+        con = None
+        e = None
+        try:
+            con = sqlite3.connect(self.database)
+            cursor = con.cursor()
+            cursor.execute(statement, *parameters)
+            data = cursor.fetchall()
+            con.commit()
+        except sqlite3.Error, e:
+            print >> sys.stderr, "Error %s:" % e.args[0]
+            if con:
+                con.rollback()
+            raise
+        if con:
+            con.close()
+        if e:
+            raise
+        return data
+
+    def update(self, table, **where):
+        pass
+
+    def select(self, table, id, **where):
+        if id is None:
+            id = '*'
+        return self("SELECT ? FROM ? WHERE %s" % 'AND '.join(['%s=%s' % (i, repr(j)) for i, j in where.items()]))
+
+    def tables(self):
+        """return the tables available in the db"""
+        # XXX sqlite specific
+        return set([i[0] for i in self("SELECT name FROM sqlite_master WHERE type='table'")])
+
+    def create(self, name, *values):
+        """
+        create a new table
+        - name: name of the table
+        - values: 2-tuples of (column name, type)
+        """
+        # sanity checks
+        assert not [i for i in values if len(i) != 2], "Values should be 2-tuples"
+        missing = set([i[1] for i in values]).difference(self.converters.keys())
+        assert not missing, "Unknown types found: %s" % missing
+
+        self("DROP TABLE IF EXISTS %s" % name)
+        self("CREATE TABLE %s (%s)" % (name, ', '.join(["%s %s" % (i, self.converters[j]) for i, j in values])))
+
+    def columns(self, table):
+        """
+        return the column names in a table
+        """
+
+        info = self("PRAGMA table_info(%s)" % table)
+        return [i[1] for i in info]
+
+if __name__ == '__main__':
+    db = SQL()
+
+    # there should be no tables
+    assert not db.tables()
+
+    # add a table
+    db.create('foo', ('bar', int), ('blarg', str))
+    db.create('fleem', ('baz', int))
+    assert db.tables() == set(['foo', 'fleem'])
+    columns = db.columns('foo')
+    assert columns == ['bar', 'blarg']