0
|
1 import sqlite3
|
|
2 import sys
|
|
3 import tempfile
|
|
4
|
|
class SQL(object):
    """Thin convenience wrapper around a SQLite database file.

    Every call opens a fresh connection, runs one statement, commits and
    closes the connection again, so an instance holds no open handle
    between calls.
    """

    # Python type -> SQLite column type, used by create().
    # type(u'') is `unicode` on Python 2 and `str` on Python 3; spelling it
    # this way keeps the class importable on both interpreters.
    converters = {int: 'INT',
                  str: 'TEXT',
                  type(u''): 'TEXT'}

    def __init__(self, database=None):
        """
        - database: path of the SQLite file.  When omitted (or falsy) a new,
          empty temporary file is created and used instead.  Nothing in this
          class deletes that file -- remove self.database when done with it.
        """
        if not database:
            # tempfile.mktemp() only returns a name, leaving a window in which
            # another process can create that path first; NamedTemporaryFile
            # creates the file securely.  SQLite opens a zero-length file as
            # an empty database.
            handle = tempfile.NamedTemporaryFile(delete=False)
            handle.close()
            database = handle.name
        self.database = database

    def __call__(self, statement, *parameters):
        """
        execute a single SQL statement and return all result rows

        - statement: SQL text, optionally containing placeholders
        - parameters: at most one sequence/mapping of placeholder values,
          forwarded unchanged to cursor.execute()

        On success the transaction is committed and the rows from
        fetchall() are returned.  On sqlite3.Error the message is written
        to stderr, the transaction is rolled back and the error re-raised.
        The connection is closed in every case.
        """
        con = None
        try:
            con = sqlite3.connect(self.database)
            cursor = con.cursor()
            cursor.execute(statement, *parameters)
            data = cursor.fetchall()
            con.commit()
        except sqlite3.Error as e:
            detail = e.args[0] if e.args else e
            sys.stderr.write("Error %s:\n" % detail)
            if con is not None:
                con.rollback()
            raise
        finally:
            # previously the connection was only closed on success and leaked
            # whenever an error was re-raised
            if con is not None:
                con.close()
        return data

    def update(self, table, **where):
        """not implemented: currently does nothing and returns None"""
        # TODO: implement.  Note the signature only accepts `where` keywords,
        # so there is no way yet to pass the new column values separately.
        pass

    def select(self, table, id=None, **where):
        """
        return the rows of a table that match all the given filters

        - table: table name; interpolated verbatim, so it must be trusted
        - id: columns to return -- None for all columns ('*'), a column
          expression string, or a list/tuple of column names (interpolated
          verbatim as well)
        - where: column=value equality filters, combined with AND; the values
          are passed as bound parameters, never spliced into the SQL text
        """
        if id is None:
            columns = '*'
        elif isinstance(id, (list, tuple)):
            columns = ', '.join(id)
        else:
            columns = id
        statement = "SELECT %s FROM %s" % (columns, table)

        # materialize once so names and values stay paired in the same order
        filters = list(where.items())
        if not filters:
            # no filters: an empty WHERE clause would be a syntax error
            return self(statement)
        clause = ' AND '.join(['%s=?' % name for name, _ in filters])
        values = [value for _, value in filters]
        return self(statement + " WHERE " + clause, values)

    def tables(self):
        """return the names of the tables available in the db, as a set"""
        # XXX sqlite specific: reads the sqlite_master catalog
        rows = self("SELECT name FROM sqlite_master WHERE type='table'")
        return set([row[0] for row in rows])

    def create(self, name, *values):
        """
        create a new table, first dropping any table with the same name
        - name: name of the table (interpolated verbatim, must be trusted)
        - values: 2-tuples of (column name, type); each type must be a key
          of `converters`
        raises AssertionError if `values` is malformed
        """
        # sanity checks -- raised explicitly (not via `assert`) so they still
        # run under `python -O`, while keeping the same exception type
        if [i for i in values if len(i) != 2]:
            raise AssertionError("Values should be 2-tuples")
        missing = set([i[1] for i in values]).difference(self.converters.keys())
        if missing:
            raise AssertionError("Unknown types found: %s" % missing)

        columns = ', '.join(["%s %s" % (column, self.converters[kind])
                             for column, kind in values])
        self("DROP TABLE IF EXISTS %s" % name)
        self("CREATE TABLE %s (%s)" % (name, columns))

    def columns(self, table):
        """
        return the column names in a table
        """
        # PRAGMA table_info yields one row per column; field 1 is its name
        info = self("PRAGMA table_info(%s)" % table)
        return [i[1] for i in info]
|
|
67
|
|
if __name__ == '__main__':
    db = SQL()

    # a freshly created database starts out with no tables
    assert not db.tables()

    # create two tables and check that both are reported
    db.create('foo', ('bar', int), ('blarg', str))
    db.create('fleem', ('baz', int))
    expected_tables = set(['fleem', 'foo'])
    assert db.tables() == expected_tables

    # the columns of 'foo' come back as declared
    columns = db.columns('foo')
    assert columns == ['bar', 'blarg']
|