Mercurial > hg > schema
comparison schema/sql.py @ 0:f7edadebb1de
initial commit
author | Jeff Hammel <jhammel@mozilla.com> |
---|---|
date | Fri, 17 Feb 2012 12:20:51 -0800 |
parents | |
children | 5baa23a8d32f |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:f7edadebb1de |
---|---|
1 import sqlite3 | |
2 import sys | |
3 import tempfile | |
4 | |
class SQL(object):
    """
    Thin convenience wrapper around a SQLite database file.

    Every statement is executed on a fresh connection which is committed
    (or rolled back on error) and always closed before the call returns.
    """

    # Python type -> SQLite column type, used by create()
    converters = {int: 'INT',
                  str: 'TEXT'}
    try:
        converters[unicode] = 'TEXT'  # Python 2 has a separate text type
    except NameError:
        pass  # Python 3: str already covers text

    def __init__(self, database=None):
        """
        - database: path of the SQLite file to use; when not given (or empty)
          a new temporary file is created and used instead
        """
        if not database:
            # mkstemp creates the file atomically (mktemp is race-prone and
            # deprecated); SQLite treats a zero-length file as an empty database
            import os
            fd, database = tempfile.mkstemp()
            os.close(fd)
        self.database = database

    def __call__(self, statement, *parameters):
        """
        execute a single SQL statement and return all fetched rows
        - statement: SQL text, optionally with ? placeholders
        - parameters: at most one sequence/mapping of bound values, forwarded
          unchanged to cursor.execute
        On sqlite3.Error the message is written to stderr, the transaction is
        rolled back and the error re-raised. The connection is always closed.
        """
        con = None
        try:
            con = sqlite3.connect(self.database)
            cursor = con.cursor()
            cursor.execute(statement, *parameters)
            data = cursor.fetchall()
            con.commit()
        except sqlite3.Error as e:
            sys.stderr.write("Error %s:\n" % (e.args[0] if e.args else e))
            if con is not None:
                con.rollback()
            raise
        finally:
            # previously skipped whenever an error was re-raised
            if con is not None:
                con.close()
        return data

    def update(self, table, **where):
        """not implemented yet: currently does nothing and returns None"""
        pass

    def select(self, table, id, **where):
        """
        select rows from a table
        - table: table name (interpolated verbatim; must be trusted)
        - id: column expression, or a list/tuple of column names;
          None selects all columns (*)
        - where: column=value equality filters, AND-ed together; values
          are bound as parameters, never spliced into the SQL text
        Returns a list of row tuples.
        """
        if id is None:
            id = '*'
        elif isinstance(id, (list, tuple)):
            id = ', '.join(id)
        # identifiers cannot be bound with ?, only values can
        statement = "SELECT %s FROM %s" % (id, table)
        if not where:
            return self(statement)
        keys = sorted(where)  # deterministic clause order
        clause = ' AND '.join(['%s=?' % key for key in keys])
        return self("%s WHERE %s" % (statement, clause),
                    [where[key] for key in keys])

    def tables(self):
        """return the tables available in the db"""
        # XXX sqlite specific
        return set([row[0] for row in self("SELECT name FROM sqlite_master WHERE type='table'")])

    def create(self, name, *values):
        """
        create a new table, replacing any existing table of the same name
        - name: name of the table
        - values: 2-tuples of (column name, type); type must be a key of
          `converters`
        Raises ValueError for malformed values, before anything is dropped.
        """
        # sanity checks -- must run before DROP so a bad call can't destroy data
        if not values:
            raise ValueError("At least one column is required")
        if [i for i in values if len(i) != 2]:
            raise ValueError("Values should be 2-tuples")
        missing = set([i[1] for i in values]).difference(self.converters.keys())
        if missing:
            raise ValueError("Unknown types found: %s" % missing)

        columns = ', '.join(["%s %s" % (column, self.converters[type_])
                             for column, type_ in values])
        self("DROP TABLE IF EXISTS %s" % name)
        self("CREATE TABLE %s (%s)" % (name, columns))

    def columns(self, table):
        """
        return the column names in a table
        """
        # field 1 of each PRAGMA table_info row is the column name
        info = self("PRAGMA table_info(%s)" % table)
        return [row[1] for row in info]
67 | |
if __name__ == '__main__':
    import os

    # self-test against a throwaway temporary database
    db = SQL()
    try:
        # there should be no tables
        assert not db.tables()

        # add a table
        db.create('foo', ('bar', int), ('blarg', str))
        db.create('fleem', ('baz', int))
        assert db.tables() == set(['foo', 'fleem'])
        columns = db.columns('foo')
        assert columns == ['bar', 'blarg']
    finally:
        # the default database is a temp file; don't leave it behind
        if os.path.exists(db.database):
            os.remove(db.database)