view lemuriformes/mysql.py @ 18:56596902e9ae default tip

add some setup + tests
author Jeff Hammel <k0scist@gmail.com>
date Sun, 10 Dec 2017 17:57:03 -0800
parents 0d1b8bb1d97b
children
line wrap: on
line source

"""
MySQL database connection class + helpers
"""

import pymysql  # https://github.com/PyMySQL/PyMySQL
import pymysql.cursors
import sys
from .cli import ConfigurationParser
from .sql import SQLConnection


class MySQLConnection(SQLConnection):
    """connection to a MySQL database"""

    placeholder = '%s'  # VALUE placeholder
    # keyword arguments forwarded, by name, to `pymysql.connect`
    connect_data_keys = ['host', 'user', 'password', 'db', 'port', 'charset']

    def __init__(self, host, user, password, db, port=3306, charset='utf8mb4'):
        """
        store the connection parameters; no connection is opened here --
        `connect` opens a new one from `self.connect_data` on each use
        """
        arguments = dict(host=host, user=user, password=password,
                         db=db, port=port, charset=charset)
        self.connect_data = {}
        for key in self.connect_data_keys:
            self.connect_data[key] = arguments[key]

    def connect(self):
        """open and return a new `pymysql` connection"""
        return pymysql.connect(**self.connect_data)

    def _with_cursor(self, operation):
        """
        open a fresh connection, call `operation(cursor)` on one of its
        cursors, commit that same connection, and always close it.

        Returns whatever `operation` returns. If `operation` raises, the
        commit is skipped and the exception propagates after closing.
        """
        connection = self.connect()
        try:
            # use an explicit cursor: `Connection.__enter__` returns a
            # cursor only on old PyMySQL (and never closes the connection),
            # while PyMySQL >= 1.0 returns the Connection itself
            with connection.cursor() as cursor:
                result = operation(cursor)
            connection.commit()
            return result
        finally:
            connection.close()

    def __call__(self, sql, *args):
        """
        execute `sql` with `args` bound to its placeholders, commit, and
        return all fetched rows (None if the rows could not be fetched)
        """

        def operation(cursor):
            try:
                cursor.execute(sql, args)
            except TypeError:
                # e.g. placeholder/argument mismatch: show the offending
                # query on stderr before re-raising
                sys.stderr.write(repr((sql, args)) + '\n')
                raise
            try:
                return cursor.fetchall()
            except pymysql.Error:
                return None

        return self._with_cursor(operation)

    def tables(self):
        """return tables"""

        data = self("show tables")
        # `self(...)` returns None when rows could not be fetched
        return [item[0] for item in (data or ())]

    def drop(self, table):
        """drop `table` if it is one of the existing tables"""

        if table not in self.tables():
            return
        # identifiers cannot be bound as parameters; `table` is known
        # to be an existing table name at this point
        self("drop table if exists {table}".format(table=table))

    def create(self, table, *columns):
        """
        columns -- each column should be a 2-tuple

        NOTE: table and column names/types are interpolated verbatim
        into the SQL, so they must come from trusted input.
        """

        sql = "CREATE TABLE {table} ({columns}) DEFAULT CHARSET=utf8mb4"

        # format columns
        _columns = ', '.join(["{0} {1}".format(column, _type)
                              for column, _type in columns])

        # execute query
        self(sql.format(table=table, columns=_columns))

    def insert(self, table, **row):
        """insert a `row` into `table`"""

        assert row
        # a concrete list keeps column and value order aligned
        keys = list(row)
        values = [row[key] for key in keys]
        self(self.insert_sql(table=table,
                             columns=keys),
             *values)

    def insert_many(self, table, columns, values):
        """
        insert many rows into `table`
        columns -- list of columns to insert
        values -- sequence of rows, each matching `columns`
        """

        # https://stackoverflow.com/questions/13020908/sql-multiple-inserts-with-python
        # It may be more efficient to flatten the string
        # instead of using `.executemany`; see
        # https://stackoverflow.com/questions/14011160/how-to-use-python-mysqldb-to-insert-many-rows-at-once

        sql = self.insert_sql(table=table, columns=columns)
        self._with_cursor(lambda cursor: cursor.executemany(sql, values))


class MySQLParser(ConfigurationParser):
    """command line parser for MySQL"""
    # TODO: obsolete!

    def add_arguments(self):
        """register the MySQL connection command line options"""

        # positional: where to connect and which database to open
        for name, description in (('host', "SQL host"),
                                  ('db', "database to use")):
            self.add_argument(name, help=description)

        # optional credentials
        self.add_argument('-u', '--user',
                          dest='user',
                          default='root',
                          help="MySQL user [DEFAULT: %(default)s]")
        self.add_argument('-p', '--password',
                          dest='password',
                          help="MySQL password [DEFAULT: %(default)s]")

    def connection(self):
        """build a MySQLConnection from the parsed command line options"""

        options = self.options
        if options is None:
            raise Exception("parse_args not called successfully!")

        settings = {key: getattr(options, key)
                    for key in ('host', 'user', 'password', 'db')}
        return MySQLConnection(**settings)