--- /dev/null
+"""defines ANSI SQL operations."""
+
+import sqlalchemy.schema as schema
+
+from sqlalchemy.schema import *
+import sqlalchemy.sql as sql
+import sqlalchemy.engine
+from sqlalchemy.sql import *
+from sqlalchemy.util import *
+import string
+
+def engine(**params):
+ return ANSISQLEngine(**params)
+
+class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
+
+ def tableimpl(self, table):
+ return ANSISQLTableImpl(table)
+
+ def schemagenerator(self, proxy, **params):
+ return ANSISchemaGenerator(proxy, **params)
+
+ def schemadropper(self, proxy, **params):
+ return ANSISchemaDropper(proxy, **params)
+
+ def connect_args(self):
+ return ([],{})
+
+ def dbapi(self):
+ return object()
+
+ def compile(self, statement, bindparams):
+ compiler = ANSICompiler(statement, bindparams)
+
+ statement.accept_visitor(compiler)
+ return compiler
+
+class ANSICompiler(sql.Compiled):
+ def __init__(self, parent, bindparams):
+ self.binds = {}
+ self.bindparams = bindparams
+ self.parent = parent
+ self.froms = {}
+ self.wheres = {}
+ self.strings = {}
+
+ def get_from_text(self, obj):
+ return self.froms[obj]
+
+ def get_str(self, obj):
+ return self.strings[obj]
+
+ def get_whereclause(self, obj):
+ return self.wheres.get(obj, None)
+
+ def get_params(self, **params):
+ d = {}
+ for key, value in params.iteritems():
+ try:
+ b = self.binds[key]
+ except KeyError:
+ raise "No such bind param in statement '%s': %s" % (str(self), key)
+ d[b.key] = value
+
+ for b in self.binds.values():
+ if not d.has_key(b.key):
+ d[b.key] = b.value
+
+ return d
+
+ def visit_column(self, column):
+ if column.table.name is None:
+ self.strings[column] = column.name
+ else:
+ self.strings[column] = "%s.%s" % (column.table.name, column.name)
+
+ def visit_fromclause(self, fromclause):
+ self.froms[fromclause] = fromclause.from_name
+
+ def visit_textclause(self, textclause):
+ if textclause.parens and len(textclause.text):
+ self.strings[textclause] = "(" + textclause.text + ")"
+ else:
+ self.strings[textclause] = textclause.text
+
+ def visit_compound(self, compound):
+ if compound.operator is None:
+ sep = " "
+ else:
+ sep = " " + compound.operator + " "
+
+ if compound.parens:
+ self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")"
+ else:
+ self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep)
+
+ def visit_clauselist(self, list):
+ self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
+
+ def visit_binary(self, binary):
+
+ if binary.parens:
+ self.strings[binary] = "(" + self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + ")"
+ else:
+ self.strings[binary] = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right)
+
+ def visit_bindparam(self, bindparam):
+ self.binds[bindparam.shortname] = bindparam
+
+ count = 1
+ key = bindparam.key
+
+ while self.binds.setdefault(key, bindparam) is not bindparam:
+ key = "%s_%d" % (bindparam.key, count)
+ count += 1
+
+ self.strings[bindparam] = ":" + key
+
+ def visit_alias(self, alias):
+ self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name
+
+ def visit_select(self, select):
+ inner_columns = []
+
+ for c in select._raw_columns:
+ for co in c.columns:
+ inner_columns.append(co)
+
+ if select.use_labels:
+ collist = string.join(["%s AS %s" % (c.fullname, c.label) for c in inner_columns], ', ')
+ else:
+ collist = string.join([c.fullname for c in inner_columns], ', ')
+
+ text = "SELECT " + collist + " FROM "
+
+ whereclause = select.whereclause
+
+ froms = []
+ for f in select.froms.values():
+
+ # special thingy used by oracle to redefine a join
+ w = self.get_whereclause(f)
+ if w is not None:
+ # TODO: move this more into the oracle module
+ whereclause = sql.and_(w, whereclause)
+ self.visit_compound(whereclause)
+
+ t = self.get_from_text(f)
+ if t is not None:
+ froms.append(t)
+
+ text += string.join(froms, ', ')
+
+ if whereclause is not None:
+ t = self.get_str(whereclause)
+ if t:
+ text += " WHERE " + t
+
+ for tup in select._clauses:
+ text += " " + tup[0] + " " + self.get_str(tup[1])
+
+ self.strings[select] = text
+ self.froms[select] = "(" + text + ")"
+
+
+ def visit_table(self, table):
+ self.froms[table] = table.name
+
+ def visit_join(self, join):
+ if join.isouter:
+ self.froms[join] = ("(" + self.get_from_text(join.left) + " LEFT OUTER JOIN " + self.get_from_text(join.right) +
+ " ON " + self.get_str(join.onclause) + ")")
+ else:
+ self.froms[join] = ("(" + self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) +
+ " ON " + self.get_str(join.onclause) + ")")
+
+ def visit_insert(self, insert_stmt):
+ colparams = insert_stmt.get_colparams(self.bindparams)
+
+ for c in colparams:
+ b = c[1]
+ self.binds[b.key] = b
+ self.binds[b.shortname] = b
+
+ text = ("INSERT INTO " + insert_stmt.table.name + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
+ " VALUES (" + string.join([":" + c[1].key for c in colparams], ', ') + ")")
+
+ self.strings[insert_stmt] = text
+
+ def visit_update(self, update_stmt):
+ colparams = update_stmt.get_colparams(self.bindparams)
+
+ for c in colparams:
+ b = c[1]
+ self.binds[b.key] = b
+ self.binds[b.shortname] = b
+
+ text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=:%s" % (c[0].name, c[1].key) for c in colparams], ', ')
+
+ if update_stmt.whereclause:
+ text += " WHERE " + self.get_str(update_stmt.whereclause)
+
+ self.strings[update_stmt] = text
+
+ def __str__(self):
+ return self.get_str(self.parent)
+
+
+
+class ANSISQLTableImpl(sql.TableImpl):
+ """Selectable implementation that gets attached to a schema.Table object."""
+
+ def __init__(self, table):
+ sql.TableImpl.__init__(self)
+ self.table = table
+ self.id = self.table.name
+
+ def get_from_text(self):
+ return self.table.name
+
+class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
+
+ def visit_table(self, table):
+ self.append("\nCREATE TABLE " + table.name + "(")
+
+ separator = "\n"
+
+ for column in table.columns:
+ self.append(separator)
+ separator = ", \n"
+ self.append("\t" + column._get_specification())
+
+ self.append("\n)\n\n")
+ self.execute()
+
+ def visit_column(self, column):
+ pass
+
+class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
+ def visit_table(self, table):
+ self.append("\nDROP TABLE " + table.name)
+ self.execute()
+
+
--- /dev/null
+import sys, StringIO, string, types
+
+import sqlalchemy.sql as sql
+import sqlalchemy.engine as engine
+import sqlalchemy.schema as schema
+import sqlalchemy.ansisql as ansisql
+from sqlalchemy.ansisql import *
+
+from pysqlite2 import dbapi2 as sqlite
+
+colspecs = {
+ schema.INT : "INTEGER",
+ schema.CHAR : "CHAR(%(length)s)",
+ schema.VARCHAR : "VARCHAR(%(length)s)",
+ schema.TEXT : "TEXT",
+ schema.FLOAT : "NUMERIC(%(precision)s, %(length)s)",
+ schema.DECIMAL : "NUMERIC(%(precision)s, %(length)s)",
+ schema.TIMESTAMP : "TIMESTAMP",
+ schema.DATETIME : "TIMESTAMP",
+ schema.CLOB : "TEXT",
+ schema.BLOB : "BLOB",
+ schema.BOOLEAN : "BOOLEAN",
+}
+
+def engine(filename, **params):
+ return SQLiteSQLEngine(filename, **params)
+
+class SQLiteSQLEngine(ansisql.ANSISQLEngine):
+ def __init__(self, filename, **params):
+ self.filename = filename
+ ansisql.ANSISQLEngine.__init__(self, **params)
+
+ def connect_args(self):
+ return ([self.filename], {})
+
+ def dbapi(self):
+ return sqlite
+
+ def columnimpl(self, column):
+ return SQLiteColumnImpl(column)
+
+class SQLiteColumnImpl(sql.ColumnSelectable):
+ def _get_specification(self):
+ coltype = self.column.type
+ if type(coltype) == types.ClassType:
+ key = coltype
+ else:
+ key = coltype.__class__
+
+ return self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)}
+
+
+
--- /dev/null
+"""builds upon the schema and sql packages to provide a central object for tying schema objects
+and sql constructs to database-specific query compilation and execution"""
+
+import sqlalchemy.schema as schema
+import sqlalchemy.pool
+import sqlalchemy.util as util
+import sqlalchemy.sql as sql
+import StringIO
+
+class SchemaIterator(schema.SchemaVisitor):
+ """a visitor that can gather text into a buffer and execute the contents of the buffer."""
+
+ def __init__(self, sqlproxy, **params):
+ self.sqlproxy = sqlproxy
+ self.buffer = StringIO.StringIO()
+
+ def run(self):
+ raise NotImplementedError()
+
+ def append(self, s):
+ self.buffer.write(s)
+
+ def execute(self):
+ try:
+ return self.sqlproxy(self.buffer.getvalue())
+ finally:
+ self.buffer.truncate(0)
+
+class SQLEngine(schema.SchemaEngine):
+ """base class for a series of database-specific engines. serves as an abstract factory for
+ implementation objects as well as database connections, transactions, SQL generators, etc."""
+
+ def __init__(self, pool = None, echo = False, **params):
+ # get a handle on the connection pool via the connect arguments
+ # this insures the SQLEngine instance integrates with the pool referenced
+ # by direct usage of pool.manager(<module>).connect(*args, **params)
+ (cargs, cparams) = self.connect_args()
+ self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams)
+ self._echo = echo
+ self.context = util.ThreadLocal()
+
+ def schemagenerator(self, proxy, **params):
+ raise NotImplementedError()
+
+ def schemadropper(self, proxy, **params):
+ raise NotImplementedError()
+
+ def columnimpl(self, column):
+ return sql.ColumnSelectable(column)
+
+ def connect_args(self):
+ raise NotImplementedError()
+
+ def dbapi(self):
+ raise NotImplementedError()
+
+ def compile(self, statement):
+ raise NotImplementedError()
+
+ def proxy(self):
+ return lambda s, p = None: self.execute(s, p)
+
+ def connection(self):
+ return self._pool.connect()
+
+ def transaction(self, func):
+ self.begin()
+ try:
+ func()
+ except:
+ self.rollback()
+ raise
+ self.commit()
+
+ def begin(self):
+ if getattr(self.context, 'transaction', None) is None:
+ conn = self.connection()
+ self.context.transaction = conn
+ self.context.tcount = 1
+ else:
+ self.context.tcount += 1
+
+ def rollback(self):
+ if self.context.transaction is not None:
+ self.context.transaction.rollback()
+ self.context.transaction = None
+ self.context.tcount = None
+
+ def commit(self):
+ if self.context.transaction is not None:
+ count = self.context.tcount - 1
+ self.context.tcount = count
+ if count == 0:
+ self.context.transaction.commit()
+ self.context.transaction = None
+ self.context.tcount = None
+
+ def execute(self, statement, parameters, connection = None, **params):
+ if parameters is None:
+ parameters = {}
+
+ if self._echo:
+ self.log(statement)
+ self.log(repr(parameters))
+
+ if connection is None:
+ poolconn = self.connection()
+ c = poolconn.cursor()
+ c.execute(statement, parameters)
+ return c
+ else:
+ c = connection.cursor()
+ c.execute(statement, parameters)
+ return c
+
+ def log(self, msg):
+ print msg