changeset 9:834b920ae345 default tip

allow output of headers in csv
author Jeff Hammel <k0scist@gmail.com>
date Sat, 01 Apr 2017 15:11:34 -0700
parents adf056d67c01
children
files sqlex/main.py sqlex/model.py
diffstat 2 files changed, 14 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/sqlex/main.py	Sat Apr 01 13:01:57 2017 -0700
+++ b/sqlex/main.py	Sat Apr 01 15:11:34 2017 -0700
@@ -42,6 +42,9 @@
                           help="list columns in `table` and exit")
         self.add_argument('-o', '--output',
                           help="output to directory (if `table` not given), or filename or stdout by default")
+        self.add_argument('--header', dest='header',
+                          action='store_true', default=False,
+                          help="export header as first row")
         self.options = None
 
     def parse_args(self, *args, **kw):
@@ -102,9 +105,9 @@
 
         if options.output:
             with open(options.output, 'w') as f:
-                db.table2csv(options.table, f)
+                db.table2csv(options.table, f, header=options.header)
         else:
-            db.table2csv(options.table, sys.stdout)
+            db.table2csv(options.table, sys.stdout, header=options.header)
             sys.stdout.flush()
     else:
         # output entire db to CSV files in directory
@@ -116,7 +119,7 @@
             # export each table
             path = os.path.join(options.output, '{}.csv'.format(table))
             with open(path, 'w') as f:
-                db.table2csv(table, f)
+                db.table2csv(table, f, header=options.header)
 
 
 if __name__ == '__main__':
--- a/sqlex/model.py	Sat Apr 01 13:01:57 2017 -0700
+++ b/sqlex/model.py	Sat Apr 01 15:11:34 2017 -0700
@@ -67,7 +67,7 @@
                             for row in data])
 
 
-    def table2csv(self, table, fp):
+    def table2csv(self, table, fp, header=False):
         """
         export `table` to `fp` file object in CSV format
         """
@@ -80,6 +80,13 @@
         sql = 'select * from {table}'.format(table=table)
         rows = self(sql)
 
+        if header:
+            # export header as first row, if specified
+            _header = self.columns(table).keys()
+            if _header:
+                _header[0] = '#{}'.format(_header[0])
+            rows.insert(0, _header)
+
         # decode unicde because the CSV module won't
         # http://stackoverflow.com/questions/22733642/how-to-write-a-unicode-csv-in-python-2-7
         rows = [[unicode(s).encode("utf-8") for s in row]