0.3.7
+- engines
+ - SA default loglevel is now "WARN"; a few warnings are issued
+ for conditions that should be visible by default.
+ - cleanup of DBAPI import strategies across all engines
+ [ticket:480]
+ - refactoring of engine internals which reduces complexity and
+ the number of codepaths; places more state inside of
+ ExecutionContext to allow more dialect control of cursor
+ handling and result sets. ResultProxy has been totally
+ refactored, and also has two "buffered" versions of result
+ sets used for different purposes.
+ - server side cursor support fully functional in postgres
+ [ticket:514].
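
For example (a minimal sketch; the DSN is hypothetical), the postgres
server-side cursor support is enabled with a dialect argument:

    from sqlalchemy import create_engine
    # server_side_cursors is consumed by PGDialect; named cursors are
    # then used for plain SELECT statements
    engine = create_engine('postgres://scott:tiger@localhost/test',
                           server_side_cursors=True)
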
- sql:
+ - the Unicode type is now a direct subclass of String, which now
+ contains all the "convert_unicode" logic. This allows the
+ variety of unicode situations that occur in databases such as
+ MS-SQL to be handled better, and allows subclassing of the
+ Unicode datatype (see the example following this list).
+ [ticket:522]
- column labels are now generated in the compilation phase, which
means their lengths are dialect-dependent. So on oracle a label
that gets truncated to 30 chars will go out to 63 characters
on postgres.
- column label and bind param "truncation" also generates
deterministic names now, based on ordering within the full
statement being compiled. this means the same statement will
produce the same string across application restarts, allowing
DB query plan caching to work better.
- - preliminary support for unicode table and column names added.
+ - preliminary support for unicode table names, column names and
+ SQL statements added, for databases which can support them.
- fix for fetchmany() "size" argument being positional in most
dbapis [ticket:505]
- sending None as an argument to func.<something> will produce
an argument of NULL
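
Two quick illustrations of the above (sketches; the bind parameter
name shown is approximate, as these names are dialect-generated):

    from sqlalchemy import Table, Column, Integer, Unicode, MetaData, func

    # Unicode is now String plus built-in "convert_unicode" logic, so
    # dialects can adapt or subclass it (e.g. NVARCHAR on MS-SQL)
    t = Table('docs', MetaData(),
              Column('id', Integer, primary_key=True),
              Column('title', Unicode(100)))

    print func.coalesce(None, 'fallback')
    # renders roughly: coalesce(NULL, :coalesce)
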
def create_connect_args(self):
return ([],{})
- def dbapi(self):
- return None
+ def schemagenerator(self, *args, **kwargs):
+ return ANSISchemaGenerator(self, *args, **kwargs)
- def schemagenerator(self, *args, **params):
- return ANSISchemaGenerator(*args, **params)
-
- def schemadropper(self, *args, **params):
- return ANSISchemaDropper(*args, **params)
+ def schemadropper(self, *args, **kwargs):
+ return ANSISchemaDropper(self, *args, **kwargs)
def compiler(self, statement, parameters, **kwargs):
return ANSICompiler(self, statement, parameters, **kwargs)
sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+ # whether we are an INSERT or UPDATE; set to True when we visit the corresponding statement
+ self.isinsert = self.isupdate = False
+
# a dictionary of bind parameter keys to _BindParamClause instances.
self.binds = {}
return alterables
class ANSISchemaGenerator(ANSISchemaBase):
- def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(ANSISchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables and util.Set(tables) or None
- self.connection = connection
- self.preparer = self.engine.dialect.preparer()
- self.dialect = self.engine.dialect
+ self.preparer = dialect.preparer()
+ self.dialect = dialect
def get_column_specification(self, column, first_pk=False):
raise NotImplementedError()
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
table.accept_visitor(self)
- if self.supports_alter():
+ if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
def _compile(self, tocompile, parameters):
"""compile the given string/parameters using this SchemaGenerator's dialect."""
- compiler = self.engine.dialect.compiler(tocompile, parameters)
+ compiler = self.dialect.compiler(tocompile, parameters)
compiler.compile()
return compiler
self.append("PRIMARY KEY ")
self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
- def supports_alter(self):
- return True
-
def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and self.supports_alter():
+ if constraint.use_alter and self.dialect.supports_alter():
return
self.append(", \n\t ")
self.define_foreign_key(constraint)
self.execute()
class ANSISchemaDropper(ANSISchemaBase):
- def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(ANSISchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
- self.connection = connection
- self.preparer = self.engine.dialect.preparer()
- self.dialect = self.engine.dialect
+ self.preparer = dialect.preparer()
+ self.dialect = dialect
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))]
- if self.supports_alter():
+ if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.drop_foreignkey(alterable)
for table in collection:
table.accept_visitor(self)
- def supports_alter(self):
- return True
-
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
self.execute()
"""Prepare a quoted column name with table name."""
return self.format_column(column, use_table=True, name=column_name)
+
+dialect = ANSIDialect
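
As a sketch of the revised construction flow (the `engine` and
`metadata` names are assumed to already exist), schema visitors are now
built by the dialect from just a Connection, rather than an
engine/proxy pair:

    conn = engine.contextual_connect()
    gen = engine.dialect.schemagenerator(conn, checkfirst=True)
    metadata.accept_visitor(gen)
    conn.close()
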
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
-try:
+def dbapi():
import kinterbasdb
-except:
- kinterbasdb = None
-
-dbmodule = kinterbasdb
+ return kinterbasdb
_initialized_kb = False
return "NUMERIC(%(precision)s, %(length)s)" % { 'precision': self.precision,
'length' : self.length }
-
class FBInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
class FBDialect(ansisql.ANSIDialect):
- def __init__(self, module = None, **params):
- global _initialized_kb
- self.module = module or dbmodule
- self.opts = {}
-
- if not _initialized_kb:
- _initialized_kb = True
- type_conv = params.get('type_conv', 200) or 200
- if isinstance(type_conv, types.StringTypes):
- type_conv = int(type_conv)
-
- concurrency_level = params.get('concurrency_level', 1) or 1
- if isinstance(concurrency_level, types.StringTypes):
- concurrency_level = int(concurrency_level)
+ def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
+ ansisql.ANSIDialect.__init__(self, **kwargs)
- if kinterbasdb is not None:
- kinterbasdb.init(type_conv=type_conv, concurrency_level=concurrency_level)
- ansisql.ANSIDialect.__init__(self, **params)
+ self.type_conv = type_conv
+ self.concurrency_level = concurrency_level
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
- # pop arguments that we took at the module level
- opts.pop('type_conv', None)
- opts.pop('concurrency_level', None)
- self.opts = opts
- return ([], self.opts)
+ type_conv = opts.pop('type_conv', self.type_conv)
+ concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
+ global _initialized_kb
+ if not _initialized_kb and self.dbapi is not None:
+ _initialized_kb = True
+ self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
+ return ([], opts)
- def create_execution_context(self):
- return FBExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return FBExecutionContext(self, *args, **kwargs)
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
return FBCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return FBSchemaGenerator(*args, **kwargs)
+ return FBSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return FBSchemaDropper(*args, **kwargs)
+ return FBSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return FBDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection):
+ return FBDefaultRunner(connection)
def preparer(self):
return FBIdentifierPreparer(self)
for name,value in fks.iteritems():
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
- def last_inserted_ids(self):
- return self.context.last_inserted_ids
-
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters or [])
def do_commit(self, connection):
connection.commit(True)
- def connection(self):
- """Returns a managed DBAPI connection from this SQLEngine's connection pool."""
- c = self._pool.connect()
- c.supportsTransactions = 0
- return c
-
- def dbapi(self):
- return self.module
-
class FBCompiler(ansisql.ANSICompiler):
"""Firebird specific idiosincrasies"""
class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
- colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
class FBDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["rdb$database"], engine=self.engine).compile()
- return self.proxy(str(c), c.get_params()).fetchone()[0]
+ c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.engine)
+ return self.connection.execute_compiled(c).scalar()
def visit_sequence(self, seq):
- return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0]
+ return self.connection.execute_text("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").scalar()
RESERVED_WORDS = util.Set(
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
-
+def dbapi(module_name=None):
+ if module_name:
+ try:
+ dialect_cls = dialect_mapping[module_name]
+ return dialect_cls.import_dbapi()
+ except KeyError:
+ raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+ else:
+ for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]:
+ try:
+ return dialect_cls.import_dbapi()
+ except ImportError, e:
+ pass
+ else:
+ raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
+
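
For instance (a sketch; assumes at least one of the listed DBAPIs is
installed), a specific MSSQL driver module can now be requested by name
at the module level:

    from sqlalchemy.databases import mssql
    module = mssql.dbapi('pyodbc')  # explicit choice
    module = mssql.dbapi()          # probes adodbapi, pymssql, pyodbc in order
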
class MSNumeric(sqltypes.Numeric):
def convert_result_value(self, value, dialect):
return value
return "VARCHAR(%(length)s)" % {'length' : self.length}
class MSNVarchar(MSString):
- """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. """
- impl = sqltypes.Unicode
-
def get_col_spec(self):
if self.length:
return "NVARCHAR(%(length)s)" % {'length' : self.length}
return "NTEXT"
class AdoMSNVarchar(MSNVarchar):
- def convert_bind_param(self, value, dialect):
- return value
-
- def convert_result_value(self, value, dialect):
- return value
-
-class MSUnicode(sqltypes.Unicode):
- """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl."""
- impl = MSNVarchar
-
-class AdoMSUnicode(MSUnicode):
- impl = AdoMSNVarchar
-
+ """overrides bindparam/result processing to not convert any unicode strings"""
def convert_bind_param(self, value, dialect):
return value
]}
class MSSQLExecutionContext(default.DefaultExecutionContext):
- def __init__(self, dialect):
+ def __init__(self, *args, **kwargs):
self.IINSERT = self.HASIDENT = False
- super(MSSQLExecutionContext, self).__init__(dialect)
+ super(MSSQLExecutionContext, self).__init__(*args, **kwargs)
def _has_implicit_sequence(self, column):
if column.primary_key and column.autoincrement:
return True
return False
- def pre_exec(self, engine, proxy, compiled, parameters, **kwargs):
+ def pre_exec(self):
"""MS-SQL has a special mode for inserting non-NULL values
into IDENTITY columns.
Activate it if the feature is turned on and needed.
"""
- if getattr(compiled, "isinsert", False):
- tbl = compiled.statement.table
+ if self.compiled.isinsert:
+ tbl = self.compiled.statement.table
if not hasattr(tbl, 'has_sequence'):
tbl.has_sequence = None
for column in tbl.c:
break
self.HASIDENT = bool(tbl.has_sequence)
- if engine.dialect.auto_identity_insert and self.HASIDENT:
- if isinstance(parameters, list):
- self.IINSERT = tbl.has_sequence.key in parameters[0]
+ if self.dialect.auto_identity_insert and self.HASIDENT:
+ if isinstance(self.compiled_parameters, list):
+ self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0]
else:
- self.IINSERT = tbl.has_sequence.key in parameters
+ self.IINSERT = tbl.has_sequence.key in self.compiled_parameters
else:
self.IINSERT = False
if self.IINSERT:
- proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
+ # TODO: quoting rules for table name here ?
+ self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name)
- super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs)
+ super(MSSQLExecutionContext, self).pre_exec()
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
+ def post_exec(self):
"""Turn off the INDENTITY_INSERT mode if it's been activated,
and fetch recently inserted IDENTIFY values (works only for
one column).
"""
- if getattr(compiled, "isinsert", False):
+ if self.compiled.isinsert:
if self.IINSERT:
- proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name)
+ # TODO: quoting rules for table name here ?
+ self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name)
self.IINSERT = False
elif self.HASIDENT:
- cursor = proxy("SELECT @@IDENTITY AS lastrowid")
- row = cursor.fetchone()
+ self.cursor.execute("SELECT @@IDENTITY AS lastrowid")
+ row = self.cursor.fetchone()
self._last_inserted_ids = [int(row[0])]
# print "LAST ROW ID", self._last_inserted_ids
self.HASIDENT = False
+ super(MSSQLExecutionContext, self).post_exec()
class MSSQLDialect(ansisql.ANSIDialect):
colspecs = {
+ sqltypes.Unicode : MSNVarchar,
sqltypes.Integer : MSInteger,
sqltypes.Smallinteger: MSSmallInteger,
sqltypes.Numeric : MSNumeric,
sqltypes.DateTime : MSDateTime,
sqltypes.Date : MSDate,
sqltypes.String : MSString,
- sqltypes.Unicode : MSUnicode,
sqltypes.Binary : MSBinary,
sqltypes.Boolean : MSBoolean,
sqltypes.TEXT : MSText,
'smallint' : MSSmallInteger,
'tinyint' : MSTinyInteger,
'varchar' : MSString,
- 'nvarchar' : MSUnicode,
+ 'nvarchar' : MSNVarchar,
'char' : MSChar,
'nchar' : MSNChar,
'text' : MSText,
'image' : MSBinary
}
- def __new__(cls, module_name=None, *args, **kwargs):
- module = kwargs.get('module', None)
+ def __new__(cls, dbapi=None, *args, **kwargs):
if cls != MSSQLDialect:
return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
- if module_name:
- dialect = dialect_mapping.get(module_name)
- if not dialect:
- raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name)
- if not hasattr(dialect, 'module'):
- raise dialect.saved_import_error
+ if dbapi:
+ dialect = dialect_mapping.get(dbapi.__name__)
return dialect(*args, **kwargs)
- elif module:
- return object.__new__(cls, *args, **kwargs)
else:
- for dialect in dialect_preference:
- if hasattr(dialect, 'module'):
- return dialect(*args, **kwargs)
- #raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
- else:
- return object.__new__(cls, *args, **kwargs)
+ return object.__new__(cls, *args, **kwargs)
- def __init__(self, module_name=None, module=None, auto_identity_insert=True, **params):
- if not hasattr(self, 'module'):
- self.module = module
+ def __init__(self, auto_identity_insert=True, **params):
super(MSSQLDialect, self).__init__(**params)
self.auto_identity_insert = auto_identity_insert
self.text_as_varchar = False
self.text_as_varchar = bool(opts.pop('text_as_varchar'))
return self.make_connect_string(opts)
- def create_execution_context(self):
- return MSSQLExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return MSSQLExecutionContext(self, *args, **kwargs)
def type_descriptor(self, typeobj):
newobj = sqltypes.adapt_type(typeobj, self.colspecs)
return MSSQLCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return MSSQLSchemaGenerator(*args, **kwargs)
+ return MSSQLSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return MSSQLSchemaDropper(*args, **kwargs)
+ return MSSQLSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return MSSQLDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection, **kwargs):
+ return MSSQLDefaultRunner(connection, **kwargs)
def preparer(self):
return MSSQLIdentifierPreparer(self)
def raw_connection(self, connection):
"""Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
try:
+ # TODO: probably want to move this to individual dialect subclasses to
+ # save on the exception throw + simplify
return connection.connection.__dict__['_pymssqlCnx__cnx']
except:
return connection.connection.adoConn
- def connection(self):
- """returns a managed DBAPI connection from this SQLEngine's connection pool."""
- c = self._pool.connect()
- c.supportsTransactions = 0
- return c
-
- def dbapi(self):
- return self.module
-
def uppercase_table(self, t):
# convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-sensitive DBs, and won't matter for case-insensitive ones
t.name = t.name.upper()
table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
class MSSQLDialect_pymssql(MSSQLDialect):
- try:
+ def import_dbapi(cls):
import pymssql as module
# pymssql doesn't have a Binary method. we use string
+ # TODO: monkeypatching here is less than ideal
module.Binary = lambda st: str(st)
- except ImportError, e:
- saved_import_error = e
-
+ return module
+ import_dbapi = classmethod(import_dbapi)
+
def supports_sane_rowcount(self):
return True
def create_connect_args(self, url):
r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
if hasattr(self, 'query_timeout'):
- self.module._mssql.set_query_timeout(self.query_timeout)
+ self.dbapi._mssql.set_query_timeout(self.query_timeout)
return r
def make_connect_string(self, keys):
## r.fetch_array()
class MSSQLDialect_pyodbc(MSSQLDialect):
- try:
+
+ def import_dbapi(cls):
import pyodbc as module
- except ImportError, e:
- saved_import_error = e
-
+ return module
+ import_dbapi = classmethod(import_dbapi)
+
colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.Unicode] = AdoMSUnicode
+ colspecs[sqltypes.Unicode] = AdoMSNVarchar
ischema_names = MSSQLDialect.ischema_names.copy()
- ischema_names['nvarchar'] = AdoMSUnicode
+ ischema_names['nvarchar'] = AdoMSNVarchar
def supports_sane_rowcount(self):
return False
class MSSQLDialect_adodbapi(MSSQLDialect):
- try:
+ def import_dbapi(cls):
import adodbapi as module
- except ImportError, e:
- saved_import_error = e
+ return module
+ import_dbapi = classmethod(import_dbapi)
colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.Unicode] = AdoMSUnicode
+ colspecs[sqltypes.Unicode] = AdoMSNVarchar
ischema_names = MSSQLDialect.ischema_names.copy()
- ischema_names['nvarchar'] = AdoMSUnicode
+ ischema_names['nvarchar'] = AdoMSNVarchar
def supports_sane_rowcount(self):
return True
connectors.append("Integrated Security=SSPI")
return [[";".join (connectors)], {}]
-
dialect_mapping = {
'pymssql': MSSQLDialect_pymssql,
'pyodbc': MSSQLDialect_pyodbc,
'adodbapi': MSSQLDialect_adodbapi
}
-dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]
class MSSQLCompiler(ansisql.ANSICompiler):
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
self.execute()
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
+ # TODO: does ms-sql have standalone sequences ?
pass
class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
import sqlalchemy.exceptions as exceptions
from array import array
-try:
+def dbapi():
import MySQLdb as mysql
- import MySQLdb.constants.CLIENT as CLIENT_FLAGS
-except:
- mysql = None
- CLIENT_FLAGS = None
+ return mysql
def kw_colspec(self, spec):
if self.unsigned:
return "LONGTEXT"
class MSString(sqltypes.String):
- def __init__(self, length=None, *extra, **kwargs):
- sqltypes.String.__init__(self, length=length)
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
]}
class MySQLExecutionContext(default.DefaultExecutionContext):
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- if getattr(compiled, "isinsert", False):
- self._last_inserted_ids = [proxy().lastrowid]
+ def post_exec(self):
+ if self.compiled.isinsert:
+ self._last_inserted_ids = [self.cursor.lastrowid]
class MySQLDialect(ansisql.ANSIDialect):
- def __init__(self, module = None, **kwargs):
- if module is None:
- self.module = mysql
- else:
- self.module = module
+ def __init__(self, **kwargs):
ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
def create_connect_args(self, url):
# TODO: what about options like "ssl", "cursorclass" and "conv" ?
client_flag = opts.get('client_flag', 0)
- if CLIENT_FLAGS is not None:
- client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ if self.dbapi is not None:
+ try:
+ import MySQLdb.constants.CLIENT as CLIENT_FLAGS
+ client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ except:
+ pass
opts['client_flag'] = client_flag
return [[], opts]
- def create_execution_context(self):
- return MySQLExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return MySQLExecutionContext(self, *args, **kwargs)
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
return MySQLCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return MySQLSchemaGenerator(*args, **kwargs)
+ return MySQLSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return MySQLSchemaDropper(*args, **kwargs)
+ return MySQLSchemaDropper(self, *args, **kwargs)
def preparer(self):
return MySQLIdentifierPreparer(self)
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
- except mysql.OperationalError, o:
+ except self.dbapi.OperationalError, o:
if o.args[0] == 2006 or o.args[0] == 2014:
cursor.invalidate()
raise o
def do_execute(self, cursor, statement, parameters, **kwargs):
try:
cursor.execute(statement, parameters)
- except mysql.OperationalError, o:
+ except self.dbapi.OperationalError, o:
if o.args[0] == 2006 or o.args[0] == 2014:
cursor.invalidate()
raise o
self._default_schema_name = text("select database()", self).scalar()
return self._default_schema_name
- def dbapi(self):
- return self.module
-
def has_table(self, connection, table_name, schema=None):
- cursor = connection.execute("show table status like '" + table_name + "'")
+ cursor = connection.execute("show table status like %s", [table_name])
+ print "CURSOR", cursor, "ROWCOUNT", cursor.rowcount, "REAL RC", cursor.cursor.rowcount
return bool( not not cursor.rowcount )
def reflecttable(self, connection, table):
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
- t = column.type.engine_impl(self.engine)
- colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
import sys, StringIO, string, re
from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
-import sqlalchemy.engine.default as default
+from sqlalchemy.engine import default, base
import sqlalchemy.types as sqltypes
-try:
+def dbapi():
import cx_Oracle
-except:
- cx_Oracle = None
+ return cx_Oracle
-ORACLE_BINARY_TYPES = [getattr(cx_Oracle, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(cx_Oracle, k)]
class OracleNumeric(sqltypes.Numeric):
def get_col_spec(self):
]}
class OracleExecutionContext(default.DefaultExecutionContext):
- def pre_exec(self, engine, proxy, compiled, parameters):
- super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters)
+ def pre_exec(self):
+ super(OracleExecutionContext, self).pre_exec()
if self.dialect.auto_setinputsizes:
- self.set_input_sizes(proxy(), parameters)
+ self.set_input_sizes()
+
+ def get_result_proxy(self):
+ if self.cursor.description is not None:
+ for column in self.cursor.description:
+ type_code = column[1]
+ if type_code in self.dialect.ORACLE_BINARY_TYPES:
+ return base.BufferedColumnResultProxy(self)
+
+ return base.ResultProxy(self)
class OracleDialect(ansisql.ANSIDialect):
- def __init__(self, use_ansi=True, auto_setinputsizes=True, module=None, threaded=True, **kwargs):
+ def __init__(self, use_ansi=True, auto_setinputsizes=True, threaded=True, **kwargs):
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs)
self.use_ansi = use_ansi
self.threaded = threaded
- if module is None:
- self.module = cx_Oracle
- else:
- self.module = module
- self.supports_timestamp = hasattr(self.module, 'TIMESTAMP' )
+ self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
- ansisql.ANSIDialect.__init__(self, **kwargs)
-
- def dbapi(self):
- return self.module
-
+ if self.dbapi is not None:
+ self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
+ else:
+ self.ORACLE_BINARY_TYPES = []
+
def create_connect_args(self, url):
if url.database:
# if we have a database, then we have a remote host
port = int(port)
else:
port = 1521
- dsn = self.module.makedsn(url.host,port,url.database)
+ dsn = self.dbapi.makedsn(url.host,port,url.database)
else:
# we have a local tnsname
dsn = url.host
else:
return "rowid"
- def create_execution_context(self):
- return OracleExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return OracleExecutionContext(self, *args, **kwargs)
def compiler(self, statement, bindparams, **kwargs):
return OracleCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return OracleSchemaGenerator(*args, **kwargs)
+ return OracleSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return OracleSchemaDropper(*args, **kwargs)
+ return OracleSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return OracleDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection, **kwargs):
+ return OracleDefaultRunner(connection, **kwargs)
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()})
if context is not None:
context._rowcount = rowcount
- def create_result_proxy_args(self, connection, cursor):
- args = super(OracleDialect, self).create_result_proxy_args(connection, cursor)
- if cursor and cursor.description:
- for column in cursor.description:
- type_code = column[1]
- if type_code in ORACLE_BINARY_TYPES:
- args['should_prefetch'] = True
- break
- return args
OracleDialect.logger = logging.class_logger(OracleDialect)
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
- colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
return colspec
def visit_sequence(self, sequence):
- if not self.engine.dialect.has_sequence(self.connection, sequence.name):
+ if not self.dialect.has_sequence(self.connection, sequence.name):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class OracleSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if self.engine.dialect.has_sequence(self.connection, sequence.name):
+ if self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile()
- return self.proxy(str(c), c.get_params()).fetchone()[0]
+ return self.connection.execute_compiled(c).scalar()
def visit_sequence(self, seq):
- return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0]
+ return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
dialect = OracleDialect
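
For example (a sketch; the DSN is hypothetical), the cx_Oracle-specific
settings are plain dialect arguments:

    from sqlalchemy import create_engine
    engine = create_engine('oracle://scott:tiger@dsn',
                           threaded=True, auto_setinputsizes=True)
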
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, sys, StringIO, string, types, re
-
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+import datetime, string, types, re, random
+
+from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy.engine import base, default
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
from sqlalchemy.databases import information_schema as ischema
-import re
try:
import mx.DateTime.DateTime as mxDateTime
except:
mxDateTime = None
-try:
- import psycopg2 as psycopg
- #import psycopg2.psycopg1 as psycopg
-except:
+def dbapi():
try:
- import psycopg
- except:
- psycopg = None
-
+ import psycopg2 as psycopg
+ except ImportError, e:
+ try:
+ import psycopg
+ except ImportError, e2:
+ raise e
+ return psycopg
+
class PGInet(sqltypes.TypeEngine):
def get_col_spec(self):
return "INET"
mx_datetime = mxDateTime(value.year, value.month, value.day,
value.hour, value.minute,
seconds)
- return psycopg.TimestampFromMx(mx_datetime)
- return psycopg.TimestampFromMx(value)
+ return dialect.dbapi.TimestampFromMx(mx_datetime)
+ return dialect.dbapi.TimestampFromMx(value)
else:
return None
# TODO: perform appropriate psycopg1 conversion between Python DateTime/MXDateTime
# this one doesn't seem to work with the "emulation" mode
if value is not None:
- return psycopg.DateFromMx(value)
+ return dialect.dbapi.DateFromMx(value)
else:
return None
]}
class PGExecutionContext(default.DefaultExecutionContext):
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None:
- if not engine.dialect.use_oids:
+
+ def is_select(self):
+ return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
+
+ def create_cursor(self):
+ if self.dialect.server_side_cursors and self.is_select():
+ # use server-side cursors:
+ # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+ ident = "c" + hex(random.randint(0, 65535))[2:]
+ return self.connection.connection.cursor(ident)
+ else:
+ return self.connection.connection.cursor()
+
+ def get_result_proxy(self):
+ if self.dialect.server_side_cursors and self.is_select():
+ return base.BufferedRowResultProxy(self)
+ else:
+ return base.ResultProxy(self)
+
+ def post_exec(self):
+ if self.compiled.isinsert and self.last_inserted_ids is None:
+ if not self.dialect.use_oids:
pass
# will raise invalid error when they go to get them
else:
- table = compiled.statement.table
- cursor = proxy()
- if cursor.lastrowid is not None and table is not None and len(table.primary_key):
- s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid)
- c = s.compile(engine=engine)
- cursor = proxy(str(c), c.get_params())
- row = cursor.fetchone()
+ table = self.compiled.statement.table
+ if self.cursor.lastrowid is not None and table is not None and len(table.primary_key):
+ s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid)
+ row = self.connection.execute(s).fetchone()
self._last_inserted_ids = [v for v in row]
-
+ super(PGExecutionContext, self).post_exec()
+
class PGDialect(ansisql.ANSIDialect):
- def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params):
+ def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs):
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
- if module is None:
- #if psycopg is None:
- # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
- self.module = psycopg
+ if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
+ self.version = 2
else:
- self.module = module
- # figure psycopg version 1 or 2
- try:
- if self.module.__version__.startswith('2'):
- self.version = 2
- else:
- self.version = 1
- except:
self.version = 1
- ansisql.ANSIDialect.__init__(self, **params)
self.use_information_schema = use_information_schema
- # produce consistent paramstyle even if psycopg2 module not present
- if self.module is None:
- self.paramstyle = 'pyformat'
+ self.paramstyle = 'pyformat'
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
opts.update(url.query)
return ([], opts)
- def create_cursor(self, connection):
- if self.server_side_cursors:
- # use server-side cursors:
- # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
- return connection.cursor('x')
- else:
- return connection.cursor()
- def create_execution_context(self):
- return PGExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return PGExecutionContext(self, *args, **kwargs)
def max_identifier_length(self):
return 68
return PGCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return PGSchemaGenerator(*args, **kwargs)
+ return PGSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return PGSchemaDropper(*args, **kwargs)
+ return PGSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return PGDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection, **kwargs):
+ return PGDefaultRunner(connection, **kwargs)
def preparer(self):
return PGIdentifierPreparer(self)
``psycopg2`` is not nice enough to produce this correctly for
an executemany, so we do our own executemany here.
"""
-
rowcount = 0
for param in parameters:
c.execute(statement, param)
if context is not None:
context._rowcount = rowcount
- def dbapi(self):
- return self.module
-
def has_table(self, connection, table_name, schema=None):
# seems like case gets folded in pg_class...
if schema is None:
else:
colspec += " SERIAL"
else:
- colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
if column.primary_key:
# passive defaults on primary keys have to be overridden
if isinstance(column.default, schema.PassiveDefault):
- c = self.proxy("select %s" % column.default.arg)
- return c.fetchone()[0]
+ return self.connection.execute_text("select %s" % column.default.arg).scalar()
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- c = self.proxy(exc)
- return c.fetchone()[0]
- else:
- return ansisql.ANSIDefaultRunner.get_column_default(self, column)
- else:
- return ansisql.ANSIDefaultRunner.get_column_default(self, column)
+ return self.connection.execute_text(exc).scalar()
+
+ return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
def visit_sequence(self, seq):
if not seq.optional:
- c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq))
- return c.fetchone()[0]
+ return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()
else:
return None
import sqlalchemy.types as sqltypes
import datetime,time
-pysqlite2_timesupport = False # Change this if the init.d guys ever get around to supporting time cols
-
-try:
- from pysqlite2 import dbapi2 as sqlite
-except ImportError:
+def dbapi():
try:
- from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
- except ImportError:
+ from pysqlite2 import dbapi2 as sqlite
+ except ImportError, e:
try:
- sqlite = __import__('sqlite') # skip ourselves
- except:
- sqlite = None
-
+ from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+ except ImportError:
+ try:
+ sqlite = __import__('sqlite') # skip ourselves
+ except ImportError:
+ raise e
+ return sqlite
+
class SLNumeric(sqltypes.Numeric):
def get_col_spec(self):
if self.precision is None:
'BLOB' : SLBinary,
}
-if pysqlite2_timesupport:
- colspecs.update({sqltypes.Time : SLTime})
- pragma_names.update({'TIME' : SLTime})
-
def descriptor():
return {'name':'sqlite',
'description':'SQLite',
]}
class SQLiteExecutionContext(default.DefaultExecutionContext):
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- if getattr(compiled, "isinsert", False):
- self._last_inserted_ids = [proxy().lastrowid]
-
+ def post_exec(self):
+ if self.compiled.isinsert:
+ self._last_inserted_ids = [self.cursor.lastrowid]
+ super(SQLiteExecutionContext, self).post_exec()
+
class SQLiteDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
def vers(num):
return tuple([int(x) for x in num.split('.')])
- self.supports_cast = (sqlite is not None and vers(sqlite.sqlite_version) >= vers("3.2.3"))
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
def compiler(self, statement, bindparams, **kwargs):
return SQLiteCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return SQLiteSchemaGenerator(*args, **kwargs)
+ return SQLiteSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return SQLiteSchemaDropper(*args, **kwargs)
+ return SQLiteSchemaDropper(self, *args, **kwargs)
+
+ def supports_alter(self):
+ return False
def preparer(self):
return SQLiteIdentifierPreparer(self)
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
- def create_execution_context(self):
- return SQLiteExecutionContext(self)
+ def create_execution_context(self, **kwargs):
+ return SQLiteExecutionContext(self, **kwargs)
def last_inserted_ids(self):
return self.context.last_inserted_ids
def oid_column_name(self, column):
return "oid"
- def dbapi(self):
- return sqlite
-
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {})
row = cursor.fetchone()
return ansisql.ANSICompiler.binary_operator_string(self, binary)
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
- def supports_alter(self):
- return False
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
# super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
- def supports_alter(self):
- return False
+ pass
class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
raise NotImplementedError()
def type_descriptor(self, typeobj):
- """Trasform the type from generic to database-specific.
+ """Transform the type from generic to database-specific.
Provides a database-specific TypeEngine object, given the
generic object which comes from the types module. Subclasses
raise NotImplementedError()
+ def supports_alter(self):
+ """return True if the database supports ALTER TABLE."""
+ raise NotImplementedError()
+
def max_identifier_length(self):
"""Return the maximum length of identifier names.
def supports_sane_rowcount(self):
"""Indicate whether the dialect properly implements statements rowcount.
- Provided to indicate when MySQL is being used, which does not
- have standard behavior for the "rowcount" function on a statement handle.
+ This was needed for MySQL which had non-standard behavior of rowcount,
+ but this issue has since been resolved.
"""
raise NotImplementedError()
- def schemagenerator(self, engine, proxy, **params):
+ def schemagenerator(self, connection, **kwargs):
"""Return a ``schema.SchemaVisitor`` instance that can generate schemas.
+ connection
+ a Connection to use for statement execution
+
`schemagenerator()` is called via the `create()` method on Table,
Index, and others.
"""
raise NotImplementedError()
- def schemadropper(self, engine, proxy, **params):
+ def schemadropper(self, connection, **kwargs):
"""Return a ``schema.SchemaVisitor`` instance that can drop schemas.
+ connection
+ a Connection to use for statement execution
+
`schemadropper()` is called via the `drop()` method on Table,
Index, and others.
"""
raise NotImplementedError()
- def defaultrunner(self, engine, proxy, **params):
- """Return a ``schema.SchemaVisitor`` instance that can execute defaults."""
+ def defaultrunner(self, connection, **kwargs):
+ """Return a ``schema.SchemaVisitor`` instance that can execute defaults.
+
+ connection
+ a Connection to use for statement execution
+
+ """
raise NotImplementedError()
ansisql.ANSICompiler, and will produce a string representation
of the given ClauseElement and `parameters` dictionary.
- `compiler()` is called within the context of the compile() method.
"""
raise NotImplementedError()
raise NotImplementedError()
- def dbapi(self):
- """Establish a connection to the database.
-
- Subclasses override this method to provide the DBAPI module
- used to establish connections.
- """
-
- raise NotImplementedError()
-
def get_default_schema_name(self, connection):
"""Return the currently selected schema given a connection"""
raise NotImplementedError()
- def execution_context(self):
+ def create_execution_context(self, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
"""Return a new ExecutionContext object."""
-
raise NotImplementedError()
def do_begin(self, connection):
raise NotImplementedError()
- def create_cursor(self, connection):
- """Return a new cursor generated from the given connection."""
-
- raise NotImplementedError()
-
- def create_result_proxy_args(self, connection, cursor):
- """Return a dictionary of arguments that should be passed to ResultProxy()."""
-
- raise NotImplementedError()
def compile(self, clauseelement, parameters=None):
"""Compile the given ClauseElement using this Dialect.
class ExecutionContext(object):
"""A messenger object for a Dialect that corresponds to a single execution.
+ ExecutionContext should have these datamembers:
+
+ connection
+ Connection object which initiated the call to the
+ dialect to create this ExecutionContext.
+
+ dialect
+ dialect which created this ExecutionContext.
+
+ cursor
+ DBAPI cursor procured from the connection
+
+ compiled
+ if passed to constructor, sql.Compiled object being executed
+
+ compiled_parameters
+ if passed to constructor, sql.ClauseParameters object
+
+ statement
+ string version of the statement to be executed. Is either
+ passed to the constructor, or must be created from the
+ sql.Compiled object by the time pre_exec() has completed.
+
+ parameters
+ "raw" parameters suitable for direct execution by the
+ dialect. Either passed to the constructor, or must be
+ created from the sql.ClauseParameters object by the time
+ pre_exec() has completed.
+
+
The Dialect should provide an ExecutionContext via the
create_execution_context() method. The `pre_exec` and `post_exec`
- methods will be called for compiled statements, afterwhich it is
- expected that the various methods `last_inserted_ids`,
- `last_inserted_params`, etc. will contain appropriate values, if
- applicable.
+ methods will be called for compiled statements.
+
"""
- def pre_exec(self, engine, proxy, compiled, parameters):
- """Called before an execution of a compiled statement.
+ def create_cursor(self):
+ """Return a new cursor generated this ExecutionContext's connection."""
- `proxy` is a callable that takes a string statement and a bind
- parameter list/dictionary.
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ """Called before an execution of a compiled statement.
+
+ If compiled and compiled_parameters were passed to this
+ ExecutionContext, the `statement` and `parameters` datamembers
+ must be initialized by the time this method completes.
"""
raise NotImplementedError()
- def post_exec(self, engine, proxy, compiled, parameters):
+ def post_exec(self):
"""Called after the execution of a compiled statement.
-
- `proxy` is a callable that takes a string statement and a bind
- parameter list/dictionary.
+
+ If compiled was passed to this ExecutionContext,
+ the `last_inserted_ids`, `last_inserted_params`, etc.
+ datamembers should be available after this method
+ completes.
"""
raise NotImplementedError()
-
- def get_rowcount(self, cursor):
- """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
-
+
+ def get_result_proxy(self):
+ """return a ResultProxy corresponding to this ExecutionContext."""
raise NotImplementedError()
-
- def supports_sane_rowcount(self):
- """Indicate if the "rowcount" DBAPI cursor function works properly.
-
- Currently, MySQLDB does not properly implement this function.
- """
+
+ def get_rowcount(self):
+ """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
raise NotImplementedError()
This does not apply to straight textual clauses; only to
``sql.Insert`` objects compiled against a ``schema.Table`` object,
- which are executed via `statement.execute()`. The order of
+ which are executed via `execute()`. The order of
items in the list is the same as that of the Table's
'primary_key' attribute.
raise NotImplementedError()
-class Connectable(object):
+class Connectable(sql.Executor):
"""Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
def contextual_connect(self):
raise NotImplementedError()
engine = property(_not_impl, doc="The Engine which this Connectable is associated with.")
+ dialect = property(_not_impl, doc="Dialect which this Connectable is associated with.")
class Connection(Connectable):
"""Represent a single DBAPI connection returned from the underlying connection pool.
except AttributeError:
raise exceptions.InvalidRequestError("This Connection is closed")
- engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated (read only)")
+ engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
+ dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
"""When no Transaction is present, this is called after executions to provide "autocommit" behavior."""
# TODO: have the dialect determine if autocommit can be set on the connection directly without this
# extra step
- if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()):
+ if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip(), re.I):
self._commit_impl()
def _autorollback(self):
def scalar(self, object, *multiparams, **params):
return self.execute(object, *multiparams, **params).scalar()
+ def compiler(self, statement, parameters, **kwargs):
+ return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs)
+
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
if c in Connection.executors:
raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
def execute_default(self, default, **kwargs):
- return default.accept_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs))
+ return default.accept_visitor(self.__engine.dialect.defaultrunner(self))
def execute_text(self, statement, *multiparams, **params):
if len(multiparams) == 0:
parameters = multiparams[0]
else:
parameters = list(multiparams)
- cursor = self._execute_raw(statement, parameters)
- rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
- return ResultProxy(self.__engine, self, cursor, **rpargs)
+ context = self._create_execution_context(statement=statement, parameters=parameters)
+ self._execute_raw(context)
+ return context.get_result_proxy()
def _params_to_listofdicts(self, *multiparams, **params):
if len(multiparams) == 0:
param = multiparams[0]
else:
param = params
- return self.execute_compiled(elem.compile(engine=self.__engine, parameters=param), *multiparams, **params)
+ return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params)
def execute_compiled(self, compiled, *multiparams, **params):
"""Execute a sql.Compiled object."""
if not compiled.can_execute:
raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
- cursor = self.__engine.dialect.create_cursor(self.connection)
parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
if len(parameters) == 1:
parameters = parameters[0]
- def proxy(statement=None, parameters=None):
- if statement is None:
- return cursor
-
- parameters = self.__engine.dialect.convert_compiled_params(parameters)
- self._execute_raw(statement, parameters, cursor=cursor, context=context)
- return cursor
- context = self.__engine.dialect.create_execution_context()
- context.pre_exec(self.__engine, proxy, compiled, parameters)
- proxy(unicode(compiled), parameters)
- context.post_exec(self.__engine, proxy, compiled, parameters)
- rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
- return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs)
+ context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters)
+ context.pre_exec()
+ self._execute_raw(context)
+ context.post_exec()
+ return context.get_result_proxy()
+
+ def _create_execution_context(self, **kwargs):
+ return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
+
+ def _execute_raw(self, context):
+ self.__engine.logger.info(context.statement)
+ self.__engine.logger.info(repr(context.parameters))
+ if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], dict)):
+ self._executemany(context)
+ else:
+ self._execute(context)
+ self._autocommit(context.statement)
+
+ def _execute(self, context):
+ if context.parameters is None:
+ if context.dialect.positional:
+ context.parameters = ()
+ else:
+ context.parameters = {}
+ try:
+ context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
+ except Exception, e:
+ self._autorollback()
+ #self._rollback_impl()
+ if self.__close_with_result:
+ self.close()
+ raise exceptions.SQLError(context.statement, context.parameters, e)
+
+ def _executemany(self, context):
+ try:
+ context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
+ except Exception, e:
+ self._autorollback()
+ #self._rollback_impl()
+ if self.__close_with_result:
+ self.close()
+ raise exceptions.SQLError(context.statement, context.parameters, e)
# poor man's multimethod/generic function thingy
executors = {
}
def create(self, entity, **kwargs):
- """Create a table or index given an appropriate schema object."""
+ """Create a Table or Index given an appropriate Schema object."""
return self.__engine.create(entity, connection=self, **kwargs)
def drop(self, entity, **kwargs):
- """Drop a table or index given an appropriate schema object."""
+ """Drop a Table or Index given an appropriate Schema object."""
return self.__engine.drop(entity, connection=self, **kwargs)
def reflecttable(self, table, **kwargs):
- """Reflect the columns in the given table from the database."""
+ """Reflect the columns in the given string table name from the database."""
return self.__engine.reflecttable(table, connection=self, **kwargs)
def run_callable(self, callable_):
return callable_(self)
- def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs):
- if cursor is None:
- cursor = self.__engine.dialect.create_cursor(self.connection)
- if not self.__engine.dialect.supports_unicode_statements():
- # encode to ascii, with full error handling
- statement = statement.encode('ascii')
- self.__engine.logger.info(statement)
- self.__engine.logger.info(repr(parameters))
- if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
- self._executemany(cursor, statement, parameters, context=context)
- else:
- self._execute(cursor, statement, parameters, context=context)
- self._autocommit(statement)
- return cursor
-
- def _execute(self, c, statement, parameters, context=None):
- if parameters is None:
- if self.__engine.dialect.positional:
- parameters = ()
- else:
- parameters = {}
- try:
- self.__engine.dialect.do_execute(c, statement, parameters, context=context)
- except Exception, e:
- self._autorollback()
- #self._rollback_impl()
- if self.__close_with_result:
- self.close()
- raise exceptions.SQLError(statement, parameters, e)
-
- def _executemany(self, c, statement, parameters, context=None):
- try:
- self.__engine.dialect.do_executemany(c, statement, parameters, context=context)
- except Exception, e:
- self._autorollback()
- #self._rollback_impl()
- if self.__close_with_result:
- self.close()
- raise exceptions.SQLError(statement, parameters, e)
-
- def proxy(self, statement=None, parameters=None):
- """Execute the given statement string and parameter object.
-
- The parameter object is expected to be the result of a call to
- ``compiled.get_params()``. This callable is a generic version
- of a connection/cursor-specific callable that is produced
- within the execute_compiled method, and is used for objects
- that require this style of proxy when outside of an
- execute_compiled method, primarily the DefaultRunner.
- """
- parameters = self.__engine.dialect.convert_compiled_params(parameters)
- return self._execute_raw(statement, parameters)
-
class Transaction(object):
"""Represent a Transaction in progress.
self.__connection._commit_impl()
self.__is_active = False
-class Engine(sql.Executor, Connectable):
+class Engine(Connectable):
"""
Connects a ConnectionProvider, a Dialect and a CompilerFactory together to
provide a default implementation of SchemaEngine.
def __init__(self, connection_provider, dialect, echo=None):
self.connection_provider = connection_provider
- self.dialect=dialect
+ self._dialect=dialect
self.echo = echo
self.logger = logging.instance_logger(self)
name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'])
engine = property(lambda s:s)
+ dialect = property(lambda s:s._dialect)
echo = logging.echo_property()
def dispose(self):
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
if connection is None:
- conn = self.contextual_connect()
+ conn = self.contextual_connect(close_with_result=False)
else:
conn = connection
try:
- element.accept_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs))
+ element.accept_visitor(visitorcallable(conn, **kwargs))
finally:
if connection is None:
conn.close()
def convert_result_value(self, arg, engine):
raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
- def __new__(cls, *args, **kwargs):
- if cls is ResultProxy and kwargs.has_key('should_prefetch') and kwargs['should_prefetch']:
- return PrefetchingResultProxy(*args, **kwargs)
- else:
- return object.__new__(cls, *args, **kwargs)
-
- def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, column_labels=None, should_prefetch=None):
+ def __init__(self, context):
"""ResultProxy objects are constructed via the execute() method on SQLEngine."""
-
- self.connection = connection
- self.dialect = engine.dialect
- self.cursor = cursor
- self.engine = engine
+ self.context = context
self.closed = False
- self.column_labels = column_labels
- if executioncontext is not None:
- self.__executioncontext = executioncontext
- self.rowcount = executioncontext.get_rowcount(cursor)
- else:
- self.rowcount = cursor.rowcount
- self.__key_cache = {}
- self.__echo = engine.echo == 'debug'
- metadata = cursor.description
- self.props = {}
- self.keys = []
- i = 0
+ self.cursor = context.cursor
+ self.__echo = logging.is_debug_enabled(context.engine.logger)
+ self._init_metadata()
+ dialect = property(lambda s:s.context.dialect)
+ rowcount = property(lambda s:s.context.get_rowcount())
+ connection = property(lambda s:s.context.connection)
+
+ def _init_metadata(self):
+ # no-op if the result metadata has already been initialized
+ if hasattr(self, '_ResultProxy__props'):
+ return
+ self.__key_cache = {}
+ self.__props = {}
+ self.__keys = []
+ metadata = self.cursor.description
if metadata is not None:
- for item in metadata:
+ for i, item in enumerate(metadata):
# sqlite may prepend the table name to column names, so strip it off
- colname = item[0].split('.')[-1].lower()
- if typemap is not None:
- rec = (typemap.get(colname, types.NULLTYPE), i)
+ colname = item[0].split('.')[-1]
+ if self.context.typemap is not None:
+ rec = (self.context.typemap.get(colname.lower(), types.NULLTYPE), i)
else:
rec = (types.NULLTYPE, i)
if rec[0] is None:
raise DBAPIError("None for metadata " + colname)
- if self.props.setdefault(colname, rec) is not rec:
- self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0)
- self.keys.append(colname)
- self.props[i] = rec
- i+=1
-
- def _executioncontext(self):
- try:
- return self.__executioncontext
- except AttributeError:
- raise exceptions.InvalidRequestError("This ResultProxy does not have an execution context with which to complete this operation. Execution contexts are not generated for literal SQL execution.")
- executioncontext = property(_executioncontext)
+ if self.__props.setdefault(colname.lower(), rec) is not rec:
+ self.__props[colname.lower()] = (ResultProxy.AmbiguousColumn(colname), 0)
+ self.__keys.append(colname)
+ self.__props[i] = rec
def close(self):
"""Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
This method is also called automatically when all result rows
are exhausted.
"""
-
if not self.closed:
self.closed = True
self.cursor.close()
if self.connection.should_close_with_result and self.dialect.supports_autoclose_results:
self.connection.close()
-
+
def _convert_key(self, key):
"""Convert and cache a key.
metadata; then cache it locally for quick re-access.
"""
- try:
+ if key in self.__key_cache:
return self.__key_cache[key]
- except KeyError:
- if isinstance(key, int) and key in self.props:
- rec = self.props[key]
- elif isinstance(key, basestring) and key.lower() in self.props:
- rec = self.props[key.lower()]
+ else:
+ if isinstance(key, int) and key in self.__props:
+ rec = self.__props[key]
+ elif isinstance(key, basestring) and key.lower() in self.__props:
+ rec = self.__props[key.lower()]
elif isinstance(key, sql.ColumnElement):
- label = self.column_labels.get(key._label, key.name).lower()
- if label in self.props:
- rec = self.props[label]
+ label = self.context.column_labels.get(key._label, key.name).lower()
+ if label in self.__props:
+ rec = self.__props[label]
if not "rec" in locals():
raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (repr(key)))
self.__key_cache[key] = rec
return rec
-
-
+
+ keys = property(lambda s:s.__keys)
+
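# A sketch of the three key styles _convert_key resolves for row
# access, assuming a ResultProxy `result` and a Table `users`:
#
#     row = result.fetchone()
#     row['user_id']          # string column name, case-insensitive
#     row[0]                  # integer position
#     row[users.c.user_id]    # Column object, resolved via column_labels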
def _has_key(self, row, key):
try:
self._convert_key(key)
except KeyError:
return False
- def _get_col(self, row, key):
- rec = self._convert_key(key)
- return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect)
-
def __iter__(self):
while True:
row = self.fetchone()
See ExecutionContext for details.
"""
- return self.executioncontext.last_inserted_ids()
+ return self.context.last_inserted_ids()
def last_updated_params(self):
"""Return ``last_updated_params()`` from the underlying ExecutionContext.
See ExecutionContext for details.
"""
- return self.executioncontext.last_updated_params()
+ return self.context.last_updated_params()
def last_inserted_params(self):
"""Return ``last_inserted_params()`` from the underlying ExecutionContext.
See ExecutionContext for details.
"""
- return self.executioncontext.last_inserted_params()
+ return self.context.last_inserted_params()
def lastrow_has_defaults(self):
"""Return ``lastrow_has_defaults()`` from the underlying ExecutionContext.
See ExecutionContext for details.
"""
- return self.executioncontext.lastrow_has_defaults()
+ return self.context.lastrow_has_defaults()
def supports_sane_rowcount(self):
"""Return ``supports_sane_rowcount()`` from the underlying ExecutionContext.
See ExecutionContext for details.
"""
- return self.executioncontext.supports_sane_rowcount()
+ return self.context.supports_sane_rowcount()
+ def _get_col(self, row, key):
+ rec = self._convert_key(key)
+ return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect)
+
+ def _fetchone_impl(self):
+ return self.cursor.fetchone()
+ def _fetchmany_impl(self, size=None):
+ if size is None:
+ # many DBAPIs take "size" positionally and reject None;
+ # fall back to the cursor's default arraysize
+ return self.cursor.fetchmany()
+ return self.cursor.fetchmany(size)
+ def _fetchall_impl(self):
+ return self.cursor.fetchall()
+
+ def _process_row(self, row):
+ return RowProxy(self, row)
+
def fetchall(self):
"""Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
- l = []
- for row in self.cursor.fetchall():
- l.append(RowProxy(self, row))
+ l = [self._process_row(row) for row in self._fetchall_impl()]
self.close()
return l
def fetchmany(self, size=None):
"""Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
- if size is None:
- rows = self.cursor.fetchmany()
- else:
- rows = self.cursor.fetchmany(size)
- l = []
- for row in rows:
- l.append(RowProxy(self, row))
+ l = [self._process_row(row) for row in self._fetchmany_impl(size)]
if len(l) == 0:
self.close()
return l
def fetchone(self):
"""Fetch one row, just like DBAPI ``cursor.fetchone()``."""
-
- row = self.cursor.fetchone()
+ row = self._fetchone_impl()
if row is not None:
- return RowProxy(self, row)
+ return self._process_row(row)
else:
self.close()
return None
def scalar(self):
"""Fetch the first column of the first row, and close the result set."""
-
- row = self.cursor.fetchone()
+ row = self._fetchone_impl()
try:
if row is not None:
- return RowProxy(self, row)[0]
+ return self._process_row(row)[0]
else:
return None
finally:
self.close()
-class PrefetchingResultProxy(ResultProxy):
+class BufferedRowResultProxy(ResultProxy):
+ def _init_metadata(self):
+ self.__buffer_rows()
+ super(BufferedRowResultProxy, self)._init_metadata()
+
+ # this is a "growth chart" for the buffering of rows.
+ # each successive __buffer_rows call will use the next
+ # value in the list for the buffer size until the max
+ # is reached
+ size_growth = {
+ 1 : 5,
+ 5 : 10,
+ 10 : 20,
+ 20 : 50,
+ 50 : 100
+ }
+
+ def __buffer_rows(self):
+ size = getattr(self, '_bufsize', 1)
+ self.__rowbuffer = self.cursor.fetchmany(size)
+ #self.context.engine.logger.debug("Buffered %d rows" % size)
+ self._bufsize = self.size_growth.get(size, size)
+
+ def _fetchone_impl(self):
+ if self.closed:
+ return None
+ if len(self.__rowbuffer) == 0:
+ self.__buffer_rows()
+ if len(self.__rowbuffer) == 0:
+ return None
+ return self.__rowbuffer.pop(0)
+
+ def _fetchmany_impl(self, size=None):
+ if size is None:
+ return self._fetchall_impl()
+ result = []
+ for x in range(0, size):
+ row = self._fetchone_impl()
+ if row is None:
+ break
+ result.append(row)
+ return result
+
+ def _fetchall_impl(self):
+ return self.__rowbuffer + list(self.cursor.fetchall())
+
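# A sketch of the pacing the size_growth map produces; each refill
# requests more rows than the last, capping at 100:
#
#     size_growth = {1: 5, 5: 10, 10: 20, 20: 50, 50: 100}
#     size, sizes = 1, []
#     for refill in range(8):
#         sizes.append(size)
#         size = size_growth.get(size, size)
#     print sizes   # [1, 5, 10, 20, 50, 100, 100, 100]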
+class BufferedColumnResultProxy(ResultProxy):
"""ResultProxy that loads all columns into memory each time fetchone() is
called. If fetchmany() or fetchall() are called, the full grid of results
is fetched.
"""
-
def _get_col(self, row, key):
rec = self._convert_key(key)
return row[rec[1]]
+
+ def _process_row(self, row):
+ sup = super(BufferedColumnResultProxy, self)
+ row = [sup._get_col(row, i) for i in xrange(len(row))]
+ return RowProxy(self, row)
def fetchall(self):
l = []
while True:
row = self.fetchone()
- if row is not None:
- l.append(row)
- else:
+ if row is None:
break
+ l.append(row)
return l
def fetchmany(self, size=None):
+ if size is None:
+ return self.fetchall()
l = []
for i in xrange(size):
row = self.fetchone()
- if row is not None:
- l.append(row)
- else:
+ if row is None:
break
+ l.append(row)
return l
- def fetchone(self):
- sup = super(PrefetchingResultProxy, self)
- row = self.cursor.fetchone()
- if row is not None:
- row = [sup._get_col(row, i) for i in xrange(len(row))]
- return RowProxy(self, row)
- else:
- self.close()
- return None
-
class RowProxy(object):
- """Proxie a single cursor row for a parent ResultProxy.
+ """Proxy a single cursor row for a parent ResultProxy.
Mostly follows "ordered dictionary" behavior, mapping result
values to the string-based column name, the integer position of
self.__parent = parent
self.__row = row
if self.__parent._ResultProxy__echo:
- self.__parent.engine.logger.debug("Row " + repr(row))
+ self.__parent.context.engine.logger.debug("Row " + repr(row))
def close(self):
"""Close the parent ResultProxy."""
class SchemaIterator(schema.SchemaVisitor):
"""A visitor that can gather text into a buffer and execute the contents of the buffer."""
- def __init__(self, engine, proxy, **params):
+ def __init__(self, connection):
"""Construct a new SchemaIterator.
-
- engine
- the Engine used by this SchemaIterator
-
- proxy
- a callable which takes a statement and bind parameters and
- executes it, returning the cursor (the actual DBAPI cursor).
- The callable should use the same cursor repeatedly.
"""
-
- self.proxy = proxy
- self.engine = engine
+ self.connection = connection
self.buffer = StringIO.StringIO()
def append(self, s):
"""Execute the contents of the SchemaIterator's buffer."""
try:
- return self.proxy(self.buffer.getvalue(), None)
+ return self.connection.execute(self.buffer.getvalue())
finally:
self.buffer.truncate(0)
DefaultRunner to allow database-specific behavior.
"""
- def __init__(self, engine, proxy):
- self.proxy = proxy
- self.engine = engine
-
+ def __init__(self, connection):
+ self.connection = connection
+ self.dialect = connection.dialect
+
def get_column_default(self, column):
if column.default is not None:
return column.default.accept_visitor(self)
return None
def exec_default_sql(self, default):
- c = sql.select([default.arg], engine=self.engine).compile()
- return self.proxy(str(c), c.get_params()).fetchone()[0]
+ c = sql.select([default.arg]).compile(engine=self.connection)
+ return self.connection.execute_compiled(c).scalar()
def visit_column_onupdate(self, onupdate):
if isinstance(onupdate.arg, sql.ClauseElement):
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
- def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs):
+ def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
self.convert_unicode = convert_unicode
self.supports_autoclose_results = True
self.encoding = encoding
self.positional = False
self._ischema = None
- self._figure_paramstyle(default=default_paramstyle)
+ self.dbapi = dbapi
+ self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
- def create_execution_context(self):
- return DefaultExecutionContext(self)
+ def create_execution_context(self, **kwargs):
+ return DefaultExecutionContext(self, **kwargs)
def type_descriptor(self, typeobj):
"""Provide a database-specific ``TypeEngine`` object, given
# TODO: probably raise this and fill out
# db modules better
return 30
+
+ def supports_alter(self):
+ return True
def oid_column_name(self, column):
return None
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, engine, proxy):
- return base.DefaultRunner(engine, proxy)
-
- def create_cursor(self, connection):
- return connection.cursor()
-
- def create_result_proxy_args(self, connection, cursor):
- return dict(should_prefetch=False)
+ def defaultrunner(self, connection):
+ return base.DefaultRunner(connection)
def _set_paramstyle(self, style):
self._paramstyle = style
return parameters
def _figure_paramstyle(self, paramstyle=None, default='named'):
- db = self.dbapi()
if paramstyle is not None:
self._paramstyle = paramstyle
- elif db is not None:
- self._paramstyle = db.paramstyle
+ elif self.dbapi is not None:
+ self._paramstyle = self.dbapi.paramstyle
else:
self._paramstyle = default
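# A sketch of the resolution order, with `psycopg` standing in for
# any DBAPI module:
#
#     DefaultDialect(dbapi=psycopg)                      # psycopg.paramstyle
#     DefaultDialect(dbapi=psycopg, paramstyle='named')  # explicit wins
#     DefaultDialect()                                   # default_paramstyle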
raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
def _get_ischema(self):
- # We use a property for ischema so that the accessor
- # creation only happens as needed, since otherwise we
- # have a circularity problem with the generic
- # ansisql.engine()
if self._ischema is None:
import sqlalchemy.databases.information_schema as ischema
self._ischema = ischema.ISchema(self)
ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect):
+ def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
self.dialect = dialect
+ self.connection = connection
+ self.compiled = compiled
+ self.compiled_parameters = compiled_parameters
+
+ if compiled is not None:
+ self.typemap = compiled.typemap
+ self.column_labels = compiled.column_labels
+ self.statement = unicode(compiled)
+ else:
+ self.typemap = self.column_labels = None
+ self.parameters = parameters
+ self.statement = statement
- def pre_exec(self, engine, proxy, compiled, parameters):
- self._process_defaults(engine, proxy, compiled, parameters)
+ if not dialect.supports_unicode_statements():
+ self.statement = self.statement.encode('ascii')
+
+ self.cursor = self.create_cursor()
+
+ engine = property(lambda s:s.connection.engine)
+
+ def is_select(self):
+ return re.match(r'SELECT', self.statement.lstrip(), re.I)
+
+ def create_cursor(self):
+ return self.connection.connection.cursor()
+
+ def pre_exec(self):
+ self._process_defaults()
+ self.parameters = self.dialect.convert_compiled_params(self.compiled_parameters)
- def post_exec(self, engine, proxy, compiled, parameters):
+ def post_exec(self):
pass
- def get_rowcount(self, cursor):
+ def get_result_proxy(self):
+ return base.ResultProxy(self)
+
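# A sketch of the execution lifecycle the refactor centralizes in the
# ExecutionContext (`conn`, `compiled` and `params` assumed in scope):
#
#     context = dialect.create_execution_context(
#         connection=conn, compiled=compiled, compiled_parameters=params)
#     context.pre_exec()    # fire off defaults, convert parameters
#     dialect.do_execute(context.cursor, context.statement,
#                        context.parameters, context=context)
#     context.post_exec()
#     result = context.get_result_proxy()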
+ def get_rowcount(self):
if hasattr(self, '_rowcount'):
return self._rowcount
else:
- return cursor.rowcount
+ return self.cursor.rowcount
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount()
def lastrow_has_defaults(self):
return self._lastrow_has_defaults
- def set_input_sizes(self, cursor, parameters):
+ def set_input_sizes(self):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DBAPI types
from the bind parameter's ``TypeEngine`` objects.
"""
- if isinstance(parameters, list):
- plist = parameters
+ if isinstance(self.compiled_parameters, list):
+ plist = self.compiled_parameters
else:
- plist = [parameters]
+ plist = [self.compiled_parameters]
if self.dialect.positional:
inputsizes = []
for params in plist[0:1]:
for key in params.positional:
typeengine = params.binds[key].type
- dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
+ dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes.append(dbtype)
- cursor.setinputsizes(*inputsizes)
+ self.cursor.setinputsizes(*inputsizes)
else:
inputsizes = {}
for params in plist[0:1]:
for key in params.keys():
typeengine = params.binds[key].type
- dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
+ dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes[key] = dbtype
- cursor.setinputsizes(**inputsizes)
+ self.cursor.setinputsizes(**inputsizes)
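# A sketch of the two call styles produced above, with STRING and
# NUMBER standing in for a DBAPI's type constants:
#
#     cursor.setinputsizes(STRING, NUMBER)            # positional dialect
#     cursor.setinputsizes(name=STRING, id=NUMBER)    # named dialect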
- def _process_defaults(self, engine, proxy, compiled, parameters):
+ def _process_defaults(self):
"""``INSERT`` and ``UPDATE`` statements, when compiled, may
have additional columns added to their ``VALUES`` and ``SET``
lists corresponding to column defaults/onupdates that are
present on the ``Table`` object (i.e. ``ColumnDefault``,
``Sequence``, ``PassiveDefault``). This method pre-execs
those ``DefaultGenerator`` objects that require pre-execution
- and sets their values within the parameter list, and flags the
- thread-local state about ``PassiveDefault`` objects that may
+ and sets their values within the parameter list, and flags this
+ ExecutionContext about ``PassiveDefault`` objects that may
require post-fetching the row after it is inserted/updated.
This method relies upon logic within the ``ANSISQLCompiler``
statement.
"""
- if compiled is None: return
-
- if getattr(compiled, "isinsert", False):
- if isinstance(parameters, list):
- plist = parameters
+ if self.compiled.isinsert:
+ if isinstance(self.compiled_parameters, list):
+ plist = self.compiled_parameters
else:
- plist = [parameters]
- drunner = self.dialect.defaultrunner(engine, proxy)
+ plist = [self.compiled_parameters]
+ drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
self._lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
need_lastrowid=False
# check the "default" status of each column in the table
- for c in compiled.statement.table.c:
+ for c in self.compiled.statement.table.c:
# check if it will be populated by a SQL clause - we'll need that
# after execution.
- if c in compiled.inline_params:
+ if c in self.compiled.inline_params:
self._lastrow_has_defaults = True
if c.primary_key:
need_lastrowid = True
# check if its not present at all. see if theres a default
# and fire it off, and add to bind parameters. if
# its a pk, add the value to our last_inserted_ids list,
- # or, if its a SQL-side default, dont do any of that, but we'll need
+ # or, if its a SQL-side default, let it fire off on the DB side, but we'll need
# the SQL-generated value after execution.
elif not c.key in param or param.get_original(c.key) is None:
if isinstance(c.default, schema.PassiveDefault):
else:
self._last_inserted_ids = last_inserted_ids
self._last_inserted_params = param
- elif getattr(compiled, 'isupdate', False):
- if isinstance(parameters, list):
- plist = parameters
+ elif self.compiled.isupdate:
+ if isinstance(self.compiled_parameters, list):
+ plist = self.compiled_parameters
else:
- plist = [parameters]
- drunner = self.dialect.defaultrunner(engine, proxy)
+ plist = [self.compiled_parameters]
+ drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
self._lastrow_has_defaults = False
for param in plist:
# check the "onupdate" status of each column in the table
- for c in compiled.statement.table.c:
+ for c in self.compiled.statement.table.c:
# it will be populated by a SQL clause - we'll need that
# after execution.
- if c in compiled.inline_params:
+ if c in self.compiled.inline_params:
pass
# its not in the bind parameters, and theres an "onupdate" defined for the column;
# execute it and add to bind params
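# A sketch of the default styles _process_defaults distinguishes,
# on a hypothetical table:
#
#     t = Table('t', meta,
#         Column('id', Integer, Sequence('t_id_seq'), primary_key=True),
#         Column('data', String(40), default='unknown'),
#         Column('ts', DateTime, PassiveDefault(text('now()'))))
#
# the Sequence and the scalar default are resolved before the INSERT
# and placed in the bind parameters; the PassiveDefault fires inside
# the database, so the context only flags lastrow_has_defaults().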
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
+ dbapi = kwargs.pop('module', None)
+ if dbapi is None:
+ dbapi_args = {}
+ for k in util.get_func_kwargs(module.dbapi):
+ if k in kwargs:
+ dbapi_args[k] = kwargs.pop(k)
+ dbapi = module.dbapi(**dbapi_args)
+
+ dialect_args['dbapi'] = dbapi
+
# create dialect
dialect = module.dialect(**dialect_args)
# look for existing pool or create
pool = kwargs.pop('pool', None)
if pool is None:
- dbapi = kwargs.pop('module', dialect.dbapi())
- if dbapi is None:
- raise exceptions.InvalidRequestError("Can't get DBAPI module for dialect '%s'" % dialect)
-
def connect():
try:
return dbapi.connect(*cargs, **cparams)
poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool))
pool_args = {}
+
# consume pool arguments from kwargs, translating a few of the arguments
for k in util.get_cls_kwargs(poolclass):
tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k)
return threadlocal.TLEngine
ThreadLocalEngineStrategy()
+
+
+class MockEngineStrategy(EngineStrategy):
+ """Produces a single Connection object which dispatches statement executions
+ to a passed-in function"""
+ def __init__(self):
+ EngineStrategy.__init__(self, 'mock')
+
+ def create(self, name_or_url, executor, **kwargs):
+ # create url.URL object
+ u = url.make_url(name_or_url)
+
+ # get module from sqlalchemy.databases
+ module = u.get_module()
+
+ dialect_args = {}
+ # consume dialect arguments from kwargs
+ for k in util.get_cls_kwargs(module.dialect):
+ if k in kwargs:
+ dialect_args[k] = kwargs.pop(k)
+
+ # create dialect
+ dialect = module.dialect(**dialect_args)
+
+ return MockEngineStrategy.MockConnection(dialect, executor)
+
+ class MockConnection(base.Connectable):
+ def __init__(self, dialect, execute):
+ self._dialect = dialect
+ self.execute = execute
+
+ engine = property(lambda s: s)
+ dialect = property(lambda s:s._dialect)
+
+ def contextual_connect(self):
+ return self
+
+ def create(self, entity, **kwargs):
+ kwargs['checkfirst'] = False
+ entity.accept_visitor(self.dialect.schemagenerator(self, **kwargs))
+
+ def drop(self, entity, **kwargs):
+ kwargs['checkfirst'] = False
+ entity.accept_visitor(self.dialect.schemadropper(self, **kwargs))
+
+ def execute(self, object, *multiparams, **params):
+ # shadowed per-instance by the executor passed to __init__
+ raise NotImplementedError()
+
+MockEngineStrategy()
\ No newline at end of file
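# A sketch of the "mock" strategy in use, capturing DDL instead of
# executing it (a `users` Table assumed):
#
#     def dump(sql, *multiparams, **params):
#         print sql
#
#     mock = create_engine('postgres://', strategy='mock', executor=dump)
#     mock.create(users)    # CREATE TABLE printed, nothing executed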
def get_module(self):
"""Return the SQLAlchemy database module corresponding to this URL's driver name."""
+ if self.drivername == 'ansi':
+ import sqlalchemy.ansisql
+ return sqlalchemy.ansisql
+
try:
return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
except ImportError:
# py2.5 absolute imports will fix....
logging = __import__('logging')
-# turn off logging at the root sqlalchemy level
-logging.getLogger('sqlalchemy').setLevel(logging.ERROR)
+
+logging.getLogger('sqlalchemy').setLevel(logging.WARN)
default_enabled = False
def default_logging(name):
raise
if self.__pool.echo:
self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
-
+
+ _logger = property(lambda self: self.__pool.logger)
+
def invalidate(self):
if self.connection is None:
raise exceptions.InvalidRequestError("This connection is closed")
def cursor(self, *args, **kwargs):
try:
- return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+ c = self.connection.cursor(*args, **kwargs)
+ return _CursorFairy(self, c)
except Exception, e:
self.invalidate()
raise
def invalidate(self):
self.__parent.invalidate()
-
+
def close(self):
if self in self.__parent._cursors:
del self.__parent._cursors[self]
- self.cursor.close()
+ try:
+ self.cursor.close()
+ except Exception, e:
+ self.__parent._logger.warn("Error closing cursor: " + str(e))
def __getattr__(self, key):
return getattr(self.cursor, key)
return d
def __repr__(self):
- return repr(self.get_original_dict())
+ return self.__class__.__name__ + ":" + repr(self.get_original_dict())
class ClauseVisitor(object):
"""A class that knows how to traverse and visit
def __init__(self, *args, **params):
pass
- def engine_impl(self, engine):
- """Deprecated; call dialect_impl with a dialect directly."""
-
- return self.dialect_impl(engine.dialect)
-
def dialect_impl(self, dialect):
try:
return self.impl_dict[dialect]
except KeyError:
return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self))
- def _get_impl(self):
- if hasattr(self, '_impl'):
- return self._impl
- else:
- return NULLTYPE
-
- def _set_impl(self, impl):
- self._impl = impl
-
- impl = property(_get_impl, _set_impl)
-
def get_col_spec(self):
raise NotImplementedError()
def adapt(self, cls):
return cls()
-
+
+ def get_search_list(self):
+ """Return a list of classes to test for a match
+ when adapting this type to a dialect-specific type."""
+
+ return self.__class__.__mro__[0:-1]
+
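# A sketch of how adapt_type() uses get_search_list(), walking the
# returned classes until one matches the dialect's colspecs:
#
#     String(40).get_search_list()   # (String, TypeEngine, ...)
#     String().get_search_list()     # (TEXT, String, TypeEngine, ...)
#
# per the String override further below, a length-less String is
# adapted as TEXT.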
class TypeDecorator(AbstractType):
def __init__(self, *args, **kwargs):
if not hasattr(self.__class__, 'impl'):
raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
self.impl = self.__class__.impl(*args, **kwargs)
- def engine_impl(self, engine):
- return self.dialect_impl(engine.dialect)
-
def dialect_impl(self, dialect):
try:
return self.impl_dict[dialect]
except KeyError:
- # see if the dialect has an adaptation of the TypeDecorator itself
- adapted_decorator = dialect.type_descriptor(self)
- if adapted_decorator is not self:
- result = adapted_decorator.dialect_impl(dialect)
- self.impl_dict[dialect] = result
- return result
typedesc = dialect.type_descriptor(self.impl)
tt = self.copy()
if not isinstance(tt, self.__class__):
def adapt_type(typeobj, colspecs):
if isinstance(typeobj, type):
typeobj = typeobj()
-
- for t in typeobj.__class__.__mro__[0:-1]:
+ for t in typeobj.get_search_list():
try:
impltype = colspecs[t]
break
return value
class String(TypeEngine):
- def __new__(cls, *args, **kwargs):
- if cls is not String or len(args) > 0 or kwargs.has_key('length'):
- return super(String, cls).__new__(cls, *args, **kwargs)
- else:
- return super(String, TEXT).__new__(TEXT, *args, **kwargs)
-
- def __init__(self, length = None):
+ def __init__(self, length=None, convert_unicode=False):
self.length = length
+ self.convert_unicode = convert_unicode
def adapt(self, impltype):
- return impltype(length=self.length)
+ return impltype(length=self.length, convert_unicode=self.convert_unicode)
def convert_bind_param(self, value, dialect):
- if not dialect.convert_unicode or value is None or not isinstance(value, unicode):
+ if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode):
return value
else:
return value.encode(dialect.encoding)
+ def get_search_list(self):
+ l = super(String, self).get_search_list()
+ if self.length is None:
+ return (TEXT,) + l
+ else:
+ return l
+
def convert_result_value(self, value, dialect):
- if not dialect.convert_unicode or value is None or isinstance(value, unicode):
+ if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode):
return value
else:
return value.decode(dialect.encoding)
def compare_values(self, x, y):
return x == y
-class Unicode(TypeDecorator):
- impl = String
-
- def convert_bind_param(self, value, dialect):
- if value is not None and isinstance(value, unicode):
- return value.encode(dialect.encoding)
- else:
- return value
-
- def convert_result_value(self, value, dialect):
- if value is not None and not isinstance(value, unicode):
- return value.decode(dialect.encoding)
- else:
- return value
-
+class Unicode(String):
+ def __init__(self, length=None, **kwargs):
+ kwargs['convert_unicode'] = True
+ super(Unicode, self).__init__(length=length, **kwargs)
+
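# A sketch of the new Unicode behavior: conversion is now carried on
# the type itself rather than only on the engine-wide flag:
#
#     t = Table('docs', meta, Column('body', Unicode(100)))
#     # bind:   u'caf\xe9'    -> 'caf\xc3\xa9'  (via dialect.encoding)
#     # result: 'caf\xc3\xa9' -> u'caf\xe9'     (decoded on the way out)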
class Integer(TypeEngine):
"""Integer datatype."""
def convert_bind_param(self, value, dialect):
if value is not None:
- return dialect.dbapi().Binary(value)
+ return dialect.dbapi.Binary(value)
else:
return None
kw.append(vn)
return kw
+def get_func_kwargs(func):
+ """Return the full set of legal kwargs for the given `func`."""
+ # co_varnames also includes local variables; only the first
+ # co_argcount entries are argument names
+ return list(func.func_code.co_varnames[:func.func_code.co_argcount])
+
class SimpleProperty(object):
"""A *default* property accessor."""
# ensure this doesn't crash
print [t for t in metadata.table_iterator()]
buf = StringIO.StringIO()
- def foo(s, p):
+ def foo(s, p=None):
buf.write(s)
- gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo, None)
+ mock = create_engine(testbase.db.name + "://", strategy="mock", executor=foo)
+ gen = mock.dialect.schemagenerator(mock)
gen.traverse(table1)
gen.traverse(table2)
buf = buf.getvalue()
try:
compile_mappers()
except exceptions.ArgumentError, ar:
- assert str(ar) == "Cant determine relation direction for relationship 'Person.manager (Manager)' - foreign key columns are present in both the parent and the child's mapped tables. Specify 'foreign_keys' argument."
+ assert str(ar) == "Can't determine relation direction for relationship 'Person.manager (Manager)' - foreign key columns are present in both the parent and the child's mapped tables. Specify 'foreign_keys' argument.", str(ar)
clear_mappers()
'addresses':relation(Address, lazy=True)
})
mapper(Address, addresses)
- query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True)
+ query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.user_id', addresses.c.address_id])
q = create_session().query(User)
def go():
})
mapper(Address, addresses)
- selectquery = users.outerjoin(addresses).select(use_labels=True)
+ selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id])
q = create_session().query(User)
def go():
mapper(Address, addresses)
adalias = addresses.alias('adalias')
- selectquery = users.outerjoin(adalias).select(use_labels=True)
+ selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
q = create_session().query(User)
def go():
mapper(Address, addresses)
adalias = addresses.alias('adalias')
- selectquery = users.outerjoin(adalias).select(use_labels=True)
+ selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
q = create_session().query(User)
def go():
mapper(Address, addresses)
adalias = addresses.alias('adalias')
- selectquery = users.outerjoin(adalias).select(use_labels=True)
+ selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
def decorate(row):
d = {}
for c in addresses.columns:
(user7, user8, user9) = sess.query(User).select()
(address1, address2, address3, address4) = sess.query(Address).select()
- selectquery = users.outerjoin(addresses).select(use_labels=True)
+ selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id])
q = sess.query(User)
l = q.instances(selectquery.execute(), Address)
# note the result is a cartesian product
capt = []
connection = testbase.db.connect()
- def proxy(statement, parameters):
- capt.append(statement)
- capt.append(repr(parameters))
- connection.proxy(statement, parameters)
- schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy, connection)
+ ex = connection._execute
+ def proxy(context):
+ capt.append(context.statement)
+ capt.append(repr(context.parameters))
+ ex(context)
+ connection._execute = proxy
+ schemagen = testbase.db.dialect.schemagenerator(connection)
schemagen.traverse(events)
assert capt[0].strip().startswith('CREATE TABLE events')
Column('__parent', VARCHAR(20)),
Column('__row', VARCHAR(20)),
)
- shadowed.create()
+ shadowed.create(checkfirst=True)
try:
shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone()
pass # expected
r.close()
finally:
- shadowed.drop()
+ shadowed.drop(checkfirst=True)
class CompoundTest(PersistTest):
"""test compound statements like UNION, INTERSECT, particularly their ability to nest on
import sqlalchemy.engine.url as url
import sqlalchemy.types
-
+from sqlalchemy.databases import mssql, oracle
db = testbase.db
class MyDecoratedType(types.TypeDecorator):
impl = String
- def convert_bind_param(self, value, engine):
- return "BIND_IN"+ value
- def convert_result_value(self, value, engine):
- return value + "BIND_OUT"
+ def convert_bind_param(self, value, dialect):
+ return "BIND_IN"+ super(MyDecoratedType, self).convert_bind_param(value, dialect)
+ def convert_result_value(self, value, dialect):
+ return super(MyDecoratedType, self).convert_result_value(value, dialect) + "BIND_OUT"
def copy(self):
return MyDecoratedType()
-class MyUnicodeType(types.Unicode):
- def convert_bind_param(self, value, engine):
- return "UNI_BIND_IN"+ value
- def convert_result_value(self, value, engine):
- return value + "UNI_BIND_OUT"
+class MyUnicodeType(types.TypeDecorator):
+ impl = Unicode
+ def convert_bind_param(self, value, dialect):
+ return "UNI_BIND_IN"+ super(MyUnicodeType, self).convert_bind_param(value, dialect)
+ def convert_result_value(self, value, dialect):
+ return super(MyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT"
def copy(self):
return MyUnicodeType(self.impl.length)
assert t2 != t3
assert t3 != t1
- def testdecorator(self):
- t1 = Unicode(20)
- t2 = Unicode()
- assert isinstance(t1.impl, String)
- assert not isinstance(t1.impl, TEXT)
- assert (t1.impl.length == 20)
- assert isinstance(t2.impl, TEXT)
- assert t2.impl.length is None
-
-
- def testdialecttypedecorators(self):
- """test that a a Dialect can provide a dialect-specific subclass of a TypeDecorator subclass."""
- import sqlalchemy.databases.mssql as mssql
+ def testmsnvarchar(self):
dialect = mssql.MSSQLDialect()
# run the test twice to ensure the caching step works too
for x in range(0, 2):
col = Column('', Unicode(length=10))
dialect_type = col.type.dialect_impl(dialect)
- assert isinstance(dialect_type, mssql.MSUnicode)
+ assert isinstance(dialect_type, mssql.MSNVarchar)
assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
- assert isinstance(dialect_type.impl, mssql.MSString)
-
+
+ def testoracletext(self):
+ dialect = oracle.OracleDialect()
+ col = Column('', MyDecoratedType)
+ dialect_type = col.type.dialect_impl(dialect)
+ assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl)
+
class OverrideTest(PersistTest):
"""tests user-defined types, including a full type as well as a TypeDecorator"""
+ def testbasic(self):
+ print users.c.goofy4.type
+ print users.c.goofy4.type.dialect_impl(testbase.db.dialect)
+ print users.c.goofy4.type.dialect_impl(testbase.db.dialect).get_col_spec()
+
def testprocessing(self):
global users
import sys
sys.path.insert(0, './lib/')
-import os
-import unittest
-import StringIO
-import sqlalchemy.ext.proxy as proxy
-import re
+import os, unittest, StringIO, re
import sqlalchemy
from sqlalchemy import sql, engine, pool
+import sqlalchemy.engine.base as base
import optparse
from sqlalchemy.schema import BoundMetaData
from sqlalchemy.orm import clear_mappers
parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)")
parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running")
parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)")
+ parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG")
(options, args) = parser.parse_args()
sys.argv[1:] = args
db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
elif DBTYPE == 'oracle8':
db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
- opts = {'use_ansi':False}
+ opts['use_ansi'] = False
elif DBTYPE == 'mssql':
db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test'
elif DBTYPE == 'firebird':
global with_coverage
with_coverage = options.coverage
+
+ if options.serverside:
+ opts['server_side_cursors'] = True
if options.enginestrategy is not None:
opts['strategy'] = options.enginestrategy
db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts)
else:
db = engine.create_engine(db_uri, **opts)
- db = EngineAssert(db)
+
+ # decorate the dialect's create_execution_context() method
+ # to produce a wrapper
+ create_context = db.dialect.create_execution_context
+ def create_exec_context(*args, **kwargs):
+ return ExecutionContextWrapper(create_context(*args, **kwargs))
+ db.dialect.create_execution_context = create_exec_context
+
+ global testdata
+ testdata = TestData(db)
if options.topological:
from sqlalchemy.orm import unitofwork
"""overridden to not return docstrings"""
return None
-
-
class AssertMixin(PersistTest):
"""given a list-based structure of keys/properties which represent information within an object structure, and
a list of actual objects, asserts that the list of objects corresponds to the structure."""
else:
self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
def assert_sql(self, db, callable_, list, with_sequences=None):
+ global testdata
+ testdata = TestData(db)
if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'):
- db.set_assert_list(self, with_sequences)
+ testdata.set_assert_list(self, with_sequences)
else:
- db.set_assert_list(self, list)
+ testdata.set_assert_list(self, list)
try:
callable_()
finally:
- db.set_assert_list(None, None)
+ testdata.set_assert_list(None, None)
+
def assert_sql_count(self, db, callable_, count):
- db.sql_count = 0
+ global testdata
+ testdata = TestData(db)
try:
callable_()
finally:
- self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count))
+ self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % (count, testdata.sql_count))
class ORMTest(AssertMixin):
keep_mappers = False
for t in metadata.table_iterator(reverse=True):
t.delete().execute().close()
-class EngineAssert(proxy.BaseProxyEngine):
- """decorates a SQLEngine object to match the incoming queries against a set of assertions."""
+class TestData(object):
def __init__(self, engine):
self._engine = engine
-
- self.real_execution_context = engine.dialect.create_execution_context
- engine.dialect.create_execution_context = self.execution_context
-
self.logger = engine.logger
self.set_assert_list(None, None)
self.sql_count = 0
- def get_engine(self):
- return self._engine
- def set_engine(self, e):
- self._engine = e
+
def set_assert_list(self, unittest, list):
self.unittest = unittest
self.assert_list = list
if list is not None:
self.assert_list.reverse()
- def _set_echo(self, echo):
- self.engine.echo = echo
- echo = property(lambda s: s.engine.echo, _set_echo)
- def execution_context(self):
- def post_exec(engine, proxy, compiled, parameters, **kwargs):
- ctx = e
- self.engine.logger = self.logger
- statement = unicode(compiled)
- statement = re.sub(r'\n', '', statement)
-
- if self.assert_list is not None:
- item = self.assert_list[-1]
- if not isinstance(item, dict):
- item = self.assert_list.pop()
- else:
- # asserting a dictionary of statements->parameters
- # this is to specify query assertions where the queries can be in
- # multiple orderings
- if not item.has_key('_converted'):
- for key in item.keys():
- ckey = self.convert_statement(key)
- item[ckey] = item[key]
- if ckey != key:
- del item[key]
- item['_converted'] = True
- try:
- entry = item.pop(statement)
- if len(item) == 1:
- self.assert_list.pop()
- item = (statement, entry)
- except KeyError:
- self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
-
- (query, params) = item
- if callable(params):
- params = params(ctx)
- if params is not None and isinstance(params, list) and len(params) == 1:
- params = params[0]
-
- if isinstance(parameters, sql.ClauseParameters):
- parameters = parameters.get_original_dict()
- elif isinstance(parameters, list):
- parameters = [p.get_original_dict() for p in parameters]
-
- query = self.convert_statement(query)
- self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- self.sql_count += 1
- return realexec(ctx, proxy, compiled, parameters, **kwargs)
-
- e = self.real_execution_context()
- realexec = e.post_exec
- realexec.im_self.post_exec = post_exec
- return e
+class ExecutionContextWrapper(object):
+ def __init__(self, ctx):
+ self.__dict__['ctx'] = ctx
+ def __getattr__(self, key):
+ return getattr(self.ctx, key)
+ def __setattr__(self, key, value):
+ setattr(self.ctx, key, value)
+
+ def post_exec(self):
+ ctx = self.ctx
+ statement = unicode(ctx.compiled)
+ statement = re.sub(r'\n', '', statement)
+
+ if testdata.assert_list is not None:
+ item = testdata.assert_list[-1]
+ if not isinstance(item, dict):
+ item = testdata.assert_list.pop()
+ else:
+ # asserting a dictionary of statements->parameters
+ # this is to specify query assertions where the queries can be in
+ # multiple orderings
+ if not item.has_key('_converted'):
+ for key in item.keys():
+ ckey = self.convert_statement(key)
+ item[ckey] = item[key]
+ if ckey != key:
+ del item[key]
+ item['_converted'] = True
+ try:
+ entry = item.pop(statement)
+ if len(item) == 1:
+ testdata.assert_list.pop()
+ item = (statement, entry)
+ except KeyError:
+ testdata.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
+
+ (query, params) = item
+ if callable(params):
+ params = params(ctx)
+ if params is not None and isinstance(params, list) and len(params) == 1:
+ params = params[0]
+
+ parameters = ctx.compiled_parameters
+ if isinstance(parameters, sql.ClauseParameters):
+ parameters = parameters.get_original_dict()
+ elif isinstance(parameters, list):
+ parameters = [p.get_original_dict() for p in parameters]
+
+ query = self.convert_statement(query)
+ testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
+ testdata.sql_count += 1
+ self.ctx.post_exec()
def convert_statement(self, query):
- paramstyle = self.engine.dialect.paramstyle
+ paramstyle = self.ctx.dialect.paramstyle
if paramstyle == 'named':
pass
elif paramstyle =='pyformat':