Mercurial > hg > Lemuriformes
comparison lemuriformes/sql.py @ 15:0d1b8bb1d97b
SQL + data related functionality
author | Jeff Hammel <k0scist@gmail.com> |
---|---|
date | Sun, 10 Dec 2017 17:16:52 -0800 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
14:756dbd3e391e | 15:0d1b8bb1d97b |
---|---|
1 """ | |
2 abstract SQL functionality | |
3 """ | |
4 | |
5 | |
6 from abc import abstractmethod | |
7 | |
8 | |
class SQLConnection(object):
    """
    abstract base class for SQL connection

    Subclasses supply the driver-specific pieces (`__call__`, `tables`,
    `columns`, `create_table`, `pytype2sql`, `sqltype2py`); the concrete
    helpers below (`drop`, `placeholders`, `insert_sql`, `insert_row`,
    `count`) build SQL text on top of them.

    NOTE(review): this class does not use `abc.ABCMeta` as its metaclass, so
    the `@abstractmethod` decorators are documentary only and are NOT
    enforced when a subclass is instantiated.

    Table and column names are interpolated directly into the SQL text
    (identifiers cannot be bound as parameters), so they must come from a
    trusted source; only row values are passed as bound parameters.
    """

    placeholder = '?'  # VALUE placeholder for bound parameters

    @abstractmethod
    def __call__(self, sql, *args, **kwargs):
        """
        execute `sql` against connection cursor with values in `args`;
        `kwargs` should be passed to the connection
        """

    @abstractmethod
    def tables(self):
        """return list of tables in the database"""

    @abstractmethod
    def columns(self, table):
        """return the columns in `table`"""

    @abstractmethod
    def create_table(self, table, **columns):
        """
        add a table to the database for the specific SQL type
        """

    @abstractmethod
    def pytype2sql(self, pytype):
        """return the SQL type for the python type `pytype`"""

    @abstractmethod
    def sqltype2py(self, sqltype):
        """return the python type for the SQL type `sqltype`"""

    def drop(self, table):
        """
        drop `table` if it exists; do nothing when `table` is not one of
        `self.tables()`
        """

        # only issue DROP for a table that is actually present;
        # plain `DROP TABLE` errors on a missing table
        if table not in self.tables():
            return  # nothing to do

        sql = "DROP TABLE {table}"
        self(sql.format(table=table))

    def placeholders(self, number):
        """
        return placeholder string appropriate to INSERT SQL;
        `number` should be an integer or an iterable with a `len`

        e.g. with the default placeholder, `placeholders(3)` -> '?,?,?'
        """

        try:
            number = len(number)  # sized iterable: use its length
        except TypeError:
            pass  # assume integer

        return ','.join([self.placeholder] * number)

    def insert_sql(self, table, columns):
        """
        return insert SQL statement for `table` with the given `columns`;
        names are backtick-quoted and values are left as placeholders
        """
        sql = "INSERT INTO `{table}` ({columns}) VALUES ({values})"
        column_str = ', '.join(["`{}`".format(key) for key in columns])
        return sql.format(table=table,
                          columns=column_str,
                          values=self.placeholders(columns))

    def insert_row(self, table, **row):
        """
        insert one `row` into `table`;
        keyword names are the column names, keyword values are bound
        as parameters
        """

        # capture the column order once so the column list in the SQL and
        # the bound values are guaranteed to line up
        columns = list(row.keys())
        sql = "INSERT INTO {table} ({columns}) VALUES ({placeholders})".format(
            table=table,
            placeholders=self.placeholders(columns),
            columns=', '.join(columns))
        values = tuple([row[column] for column in columns])
        self(sql, *values)

    def count(self, table):
        """
        count + return number of rows in `table`

        Assumes `self(sql)` returns the fetched rows as a sequence of
        sequences (TODO confirm for each driver).

        Raises AssertionError if the result is not exactly one row with
        exactly one column.
        """
        # https://docs.microsoft.com/en-us/sql/t-sql/functions/count-transact-sql

        sql = "select count(*) from {table}".format(table=table)
        data = self(sql)
        # explicit raise rather than `assert`, which is stripped under `python -O`
        if len(data) != 1 or len(data[0]) != 1:
            raise AssertionError(
                "expected one row with one column from {!r}, got {!r}".format(
                    sql, data))
        return data[0][0]