From: Mike Bayer Date: Fri, 1 Jul 2005 02:43:15 +0000 (+0000) Subject: Initial revision X-Git-Tag: rel_0_1_0~927 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b2f0d64fa8c06b5662ce6831dc3fe1588397c76b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Initial revision --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py new file mode 100644 index 0000000000..bd63aa9d7e --- /dev/null +++ b/lib/sqlalchemy/ansisql.py @@ -0,0 +1,244 @@ +"""defines ANSI SQL operations.""" + +import sqlalchemy.schema as schema + +from sqlalchemy.schema import * +import sqlalchemy.sql as sql +import sqlalchemy.engine +from sqlalchemy.sql import * +from sqlalchemy.util import * +import string + +def engine(**params): + return ANSISQLEngine(**params) + +class ANSISQLEngine(sqlalchemy.engine.SQLEngine): + + def tableimpl(self, table): + return ANSISQLTableImpl(table) + + def schemagenerator(self, proxy, **params): + return ANSISchemaGenerator(proxy, **params) + + def schemadropper(self, proxy, **params): + return ANSISchemaDropper(proxy, **params) + + def connect_args(self): + return ([],{}) + + def dbapi(self): + return object() + + def compile(self, statement, bindparams): + compiler = ANSICompiler(statement, bindparams) + + statement.accept_visitor(compiler) + return compiler + +class ANSICompiler(sql.Compiled): + def __init__(self, parent, bindparams): + self.binds = {} + self.bindparams = bindparams + self.parent = parent + self.froms = {} + self.wheres = {} + self.strings = {} + + def get_from_text(self, obj): + return self.froms[obj] + + def get_str(self, obj): + return self.strings[obj] + + def get_whereclause(self, obj): + return self.wheres.get(obj, None) + + def get_params(self, **params): + d = {} + for key, value in params.iteritems(): + try: + b = self.binds[key] + except KeyError: + raise "No such bind param in statement '%s': %s" % (str(self), key) + d[b.key] = value + + for b in self.binds.values(): + if not d.has_key(b.key): + d[b.key] = b.value + + return d + + def visit_column(self, column): + if column.table.name is None: + self.strings[column] = column.name + else: + self.strings[column] = "%s.%s" % (column.table.name, column.name) + + def visit_fromclause(self, fromclause): + self.froms[fromclause] = fromclause.from_name + + def visit_textclause(self, textclause): + if textclause.parens and len(textclause.text): + self.strings[textclause] = "(" + textclause.text + ")" + else: + self.strings[textclause] = textclause.text + + def visit_compound(self, compound): + if compound.operator is None: + sep = " " + else: + sep = " " + compound.operator + " " + + if compound.parens: + self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")" + else: + self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep) + + def visit_clauselist(self, list): + self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') + + def visit_binary(self, binary): + + if binary.parens: + self.strings[binary] = "(" + self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + ")" + else: + self.strings[binary] = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + + def visit_bindparam(self, bindparam): + self.binds[bindparam.shortname] = bindparam + + count = 1 + key = bindparam.key + + while self.binds.setdefault(key, bindparam) is not bindparam: + key = "%s_%d" % (bindparam.key, count) + count += 1 + + self.strings[bindparam] = ":" + key + + def visit_alias(self, alias): + self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name + + def visit_select(self, select): + inner_columns = [] + + for c in select._raw_columns: + for co in c.columns: + inner_columns.append(co) + + if select.use_labels: + collist = string.join(["%s AS %s" % (c.fullname, c.label) for c in inner_columns], ', ') + else: + collist = string.join([c.fullname for c in inner_columns], ', ') + + text = "SELECT " + collist + " FROM " + + whereclause = select.whereclause + + froms = [] + for f in select.froms.values(): + + # special thingy used by oracle to redefine a join + w = self.get_whereclause(f) + if w is not None: + # TODO: move this more into the oracle module + whereclause = sql.and_(w, whereclause) + self.visit_compound(whereclause) + + t = self.get_from_text(f) + if t is not None: + froms.append(t) + + text += string.join(froms, ', ') + + if whereclause is not None: + t = self.get_str(whereclause) + if t: + text += " WHERE " + t + + for tup in select._clauses: + text += " " + tup[0] + " " + self.get_str(tup[1]) + + self.strings[select] = text + self.froms[select] = "(" + text + ")" + + + def visit_table(self, table): + self.froms[table] = table.name + + def visit_join(self, join): + if join.isouter: + self.froms[join] = ("(" + self.get_from_text(join.left) + " LEFT OUTER JOIN " + self.get_from_text(join.right) + + " ON " + self.get_str(join.onclause) + ")") + else: + self.froms[join] = ("(" + self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) + + " ON " + self.get_str(join.onclause) + ")") + + def visit_insert(self, insert_stmt): + colparams = insert_stmt.get_colparams(self.bindparams) + + for c in colparams: + b = c[1] + self.binds[b.key] = b + self.binds[b.shortname] = b + + text = ("INSERT INTO " + insert_stmt.table.name + " (" + string.join([c[0].name for c in colparams], ', ') + ")" + + " VALUES (" + string.join([":" + c[1].key for c in colparams], ', ') + ")") + + self.strings[insert_stmt] = text + + def visit_update(self, update_stmt): + colparams = update_stmt.get_colparams(self.bindparams) + + for c in colparams: + b = c[1] + self.binds[b.key] = b + self.binds[b.shortname] = b + + text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=:%s" % (c[0].name, c[1].key) for c in colparams], ', ') + + if update_stmt.whereclause: + text += " WHERE " + self.get_str(update_stmt.whereclause) + + self.strings[update_stmt] = text + + def __str__(self): + return self.get_str(self.parent) + + + +class ANSISQLTableImpl(sql.TableImpl): + """Selectable implementation that gets attached to a schema.Table object.""" + + def __init__(self, table): + sql.TableImpl.__init__(self) + self.table = table + self.id = self.table.name + + def get_from_text(self): + return self.table.name + +class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): + + def visit_table(self, table): + self.append("\nCREATE TABLE " + table.name + "(") + + separator = "\n" + + for column in table.columns: + self.append(separator) + separator = ", \n" + self.append("\t" + column._get_specification()) + + self.append("\n)\n\n") + self.execute() + + def visit_column(self, column): + pass + +class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator): + def visit_table(self, table): + self.append("\nDROP TABLE " + table.name) + self.execute() + + diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py new file mode 100644 index 0000000000..2208923407 --- /dev/null +++ b/lib/sqlalchemy/databases/sqlite.py @@ -0,0 +1,53 @@ +import sys, StringIO, string, types + +import sqlalchemy.sql as sql +import sqlalchemy.engine as engine +import sqlalchemy.schema as schema +import sqlalchemy.ansisql as ansisql +from sqlalchemy.ansisql import * + +from pysqlite2 import dbapi2 as sqlite + +colspecs = { + schema.INT : "INTEGER", + schema.CHAR : "CHAR(%(length)s)", + schema.VARCHAR : "VARCHAR(%(length)s)", + schema.TEXT : "TEXT", + schema.FLOAT : "NUMERIC(%(precision)s, %(length)s)", + schema.DECIMAL : "NUMERIC(%(precision)s, %(length)s)", + schema.TIMESTAMP : "TIMESTAMP", + schema.DATETIME : "TIMESTAMP", + schema.CLOB : "TEXT", + schema.BLOB : "BLOB", + schema.BOOLEAN : "BOOLEAN", +} + +def engine(filename, **params): + return SQLiteSQLEngine(filename, **params) + +class SQLiteSQLEngine(ansisql.ANSISQLEngine): + def __init__(self, filename, **params): + self.filename = filename + ansisql.ANSISQLEngine.__init__(self, **params) + + def connect_args(self): + return ([self.filename], {}) + + def dbapi(self): + return sqlite + + def columnimpl(self, column): + return SQLiteColumnImpl(column) + +class SQLiteColumnImpl(sql.ColumnSelectable): + def _get_specification(self): + coltype = self.column.type + if type(coltype) == types.ClassType: + key = coltype + else: + key = coltype.__class__ + + return self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)} + + + diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py new file mode 100644 index 0000000000..811fa94331 --- /dev/null +++ b/lib/sqlalchemy/engine.py @@ -0,0 +1,117 @@ +"""builds upon the schema and sql packages to provide a central object for tying schema objects +and sql constructs to database-specific query compilation and execution""" + +import sqlalchemy.schema as schema +import sqlalchemy.pool +import sqlalchemy.util as util +import sqlalchemy.sql as sql +import StringIO + +class SchemaIterator(schema.SchemaVisitor): + """a visitor that can gather text into a buffer and execute the contents of the buffer.""" + + def __init__(self, sqlproxy, **params): + self.sqlproxy = sqlproxy + self.buffer = StringIO.StringIO() + + def run(self): + raise NotImplementedError() + + def append(self, s): + self.buffer.write(s) + + def execute(self): + try: + return self.sqlproxy(self.buffer.getvalue()) + finally: + self.buffer.truncate(0) + +class SQLEngine(schema.SchemaEngine): + """base class for a series of database-specific engines. serves as an abstract factory for + implementation objects as well as database connections, transactions, SQL generators, etc.""" + + def __init__(self, pool = None, echo = False, **params): + # get a handle on the connection pool via the connect arguments + # this insures the SQLEngine instance integrates with the pool referenced + # by direct usage of pool.manager().connect(*args, **params) + (cargs, cparams) = self.connect_args() + self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams) + self._echo = echo + self.context = util.ThreadLocal() + + def schemagenerator(self, proxy, **params): + raise NotImplementedError() + + def schemadropper(self, proxy, **params): + raise NotImplementedError() + + def columnimpl(self, column): + return sql.ColumnSelectable(column) + + def connect_args(self): + raise NotImplementedError() + + def dbapi(self): + raise NotImplementedError() + + def compile(self, statement): + raise NotImplementedError() + + def proxy(self): + return lambda s, p = None: self.execute(s, p) + + def connection(self): + return self._pool.connect() + + def transaction(self, func): + self.begin() + try: + func() + except: + self.rollback() + raise + self.commit() + + def begin(self): + if getattr(self.context, 'transaction', None) is None: + conn = self.connection() + self.context.transaction = conn + self.context.tcount = 1 + else: + self.context.tcount += 1 + + def rollback(self): + if self.context.transaction is not None: + self.context.transaction.rollback() + self.context.transaction = None + self.context.tcount = None + + def commit(self): + if self.context.transaction is not None: + count = self.context.tcount - 1 + self.context.tcount = count + if count == 0: + self.context.transaction.commit() + self.context.transaction = None + self.context.tcount = None + + def execute(self, statement, parameters, connection = None, **params): + if parameters is None: + parameters = {} + + if self._echo: + self.log(statement) + self.log(repr(parameters)) + + if connection is None: + poolconn = self.connection() + c = poolconn.cursor() + c.execute(statement, parameters) + return c + else: + c = connection.cursor() + c.execute(statement, parameters) + return c + + def log(self, msg): + print msg