view schema/sql.py @ 0:f7edadebb1de

initial commit
author Jeff Hammel <jhammel@mozilla.com>
date Fri, 17 Feb 2012 12:20:51 -0800
parents
children 5baa23a8d32f
line wrap: on
line source

import os
import sqlite3
import sys
import tempfile

class SQL(object):
    """
    Thin convenience wrapper around a SQLite database file.

    Every call opens a fresh connection, executes one statement, commits,
    and closes the connection again; no connection is kept between calls.
    """

    # Python type -> SQLite column type, used by create()
    converters = {int: 'INT',
                  str: 'TEXT'}
    try:
        converters[unicode] = 'TEXT'
    except NameError:
        # Python 3 has no separate ``unicode`` type; ``str`` already maps to TEXT
        pass

    def __init__(self, database=None):
        """
        - database: path of the SQLite database file; when empty/None a new
          temporary file is created (it is not removed automatically)
        """
        if not database:
            # mkstemp creates the file atomically, avoiding the name race of
            # the deprecated tempfile.mktemp(); SQLite treats an empty file
            # as a new, empty database
            fd, database = tempfile.mkstemp(suffix='.db')
            os.close(fd)
        self.database = database

    def __call__(self, statement, *parameters):
        """
        execute a single SQL statement and return all resulting rows
        - statement: SQL text, optionally containing ``?`` placeholders
        - parameters: passed straight through to ``cursor.execute``, i.e.
          normally one sequence (or mapping) of placeholder values
        On success the transaction is committed; on ``sqlite3.Error`` the
        error is reported on stderr, the transaction is rolled back and the
        exception re-raised. The connection is closed in every case.
        """
        con = None
        try:
            con = sqlite3.connect(self.database)
            cursor = con.cursor()
            cursor.execute(statement, *parameters)
            data = cursor.fetchall()
            con.commit()
        except sqlite3.Error as e:
            message = e.args[0] if e.args else e
            sys.stderr.write("Error %s:\n" % (message,))
            if con is not None:
                con.rollback()
            raise
        finally:
            # always release the connection, including on the error path
            if con is not None:
                con.close()
        return data

    def update(self, table, **where):
        """not implemented yet: currently does nothing and returns None"""
        # TODO: implement
        pass

    def select(self, table, id, **where):
        """
        select rows from a table
        - table: table name (interpolated into the SQL, so it must be trusted)
        - id: column expression to select (e.g. a column name); None selects ``*``
        - where: ``column=value`` equality filters, combined with AND; the
          values are bound as ``?`` parameters, never inlined
        Returns the list of result row tuples.
        """
        if id is None:
            id = '*'
        # SQLite cannot bind identifiers with ``?``, so the table and column
        # names are interpolated; only the filter values are parameterized
        statement = "SELECT %s FROM %s" % (id, table)
        if not where:
            return self(statement)
        items = list(where.items())
        statement += " WHERE %s" % ' AND '.join(['%s=?' % column
                                                 for column, value in items])
        return self(statement, [value for column, value in items])

    def tables(self):
        """return the set of table names available in the db"""
        # XXX sqlite specific
        return set([row[0] for row in self("SELECT name FROM sqlite_master WHERE type='table'")])

    def create(self, name, *values):
        """
        create a new table, dropping any existing table of the same name
        - name: name of the table (interpolated into the SQL)
        - values: 2-tuples of (column name, Python type); every type must be
          a key of ``converters``
        Raises AssertionError when the values are malformed.
        """
        # sanity checks
        assert not [i for i in values if len(i) != 2], "Values should be 2-tuples"
        missing = set([i[1] for i in values]).difference(self.converters.keys())
        assert not missing, "Unknown types found: %s" % missing

        self("DROP TABLE IF EXISTS %s" % name)
        self("CREATE TABLE %s (%s)" % (name, ', '.join(["%s %s" % (i, self.converters[j]) for i, j in values])))

    def columns(self, table):
        """
        return the column names of a table, in declaration order
        """
        # PRAGMA table_info yields one row per column; index 1 is the name
        info = self("PRAGMA table_info(%s)" % table)
        return [i[1] for i in info]

if __name__ == '__main__':
    # smoke test against a throwaway temporary database
    db = SQL()
    try:
        # there should be no tables
        assert not db.tables()

        # add tables
        db.create('foo', ('bar', int), ('blarg', str))
        db.create('fleem', ('baz', int))
        assert db.tables() == set(['foo', 'fleem'])
        columns = db.columns('foo')
        assert columns == ['bar', 'blarg']
    finally:
        # remove the temporary database file so repeated runs don't litter
        # the temp directory
        if os.path.exists(db.database):
            os.remove(db.database)