]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged the "execcontext" branch, refactors engine/dialect codepaths
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Apr 2007 21:36:11 +0000 (21:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Apr 2007 21:36:11 +0000 (21:36 +0000)
- much more functionality moved into ExecutionContext, which impacted
the API used by dialects to some degree
- ResultProxy and subclasses now designed sanely
- merged patch for #522, Unicode subclasses String directly,
MSNVarchar implements for MS-SQL, removed MSUnicode.
- String moves its "VARCHAR"/"TEXT" switchy thing into
"get_search_list()" function, which VARCHAR and CHAR can override
to not return TEXT in any case (didnt do the latter yet)
- implements server side cursors for postgres, unit tests, #514
- includes overhaul of dbapi import strategy #480, all dbapi
importing happens in dialect method "dbapi()", is only called
inside of create_engine() for default and threadlocal strategies.
Dialect subclasses have a datamember "dbapi" referencing the loaded
module which may be None.
- added "mock" engine strategy, doesnt require DBAPI module and
gives you a "Connecition" which just sends all executes to a callable.
can be used to create string output of create_all()/drop_all().

24 files changed:
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/url.py
lib/sqlalchemy/logging.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/engine/reflection.py
test/orm/inheritance5.py
test/orm/mapper.py
test/sql/constraints.py
test/sql/query.py
test/sql/testtypes.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index fc8077167f0a137ed7870ba0d56c59a7f8f4c377..41a2ac3837f27a872e1917a83f773f3fc9525fc5 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,5 +1,22 @@
 0.3.7
+- engines
+    - SA default loglevel is now "WARN".  we have a few warnings
+      for things that should be available by default.
+    - cleanup of DBAPI import strategies across all engines 
+      [ticket:480]
+    - refactoring of engine internals which reduces complexity,
+      number of codepaths; places more state inside of ExecutionContext
+      to allow more dialect control of cursor handling, result sets.
+      ResultProxy totally refactored and also has two versions of
+      "buffered" result sets used for different purposes.
+    - server side cursor support fully functional in postgres
+      [ticket:514].
 - sql:
+    - the Unicode type is now a direct subclass of String, which now
+      contains all the "convert_unicode" logic.  This helps the variety
+      of unicode situations that occur in db's such as MS-SQL to be
+      better handled and allows subclassing of the Unicode datatype.
+      [ticket:522]
     - column labels are now generated in the compilation phase, which
       means their lengths are dialect-dependent.  So on oracle a label
       that gets truncated to 30 chars will go out to 63 characters
@@ -11,7 +28,8 @@
       full statement being compiled.  this means the same statement
       will produce the same string across application restarts and
       allowing DB query plan caching to work better.
-    - preliminary support for unicode table and column names added.
+    - preliminary support for unicode table names, column names and 
+      SQL statements added, for databases which can support them.
     - fix for fetchmany() "size" argument being positional in most
       dbapis [ticket:505]
     - sending None as an argument to func.<something> will produce
index a75263d915b44d60c3bd61ab1275907819b30839..03053b998c18b1f90c854e9360c751fcc79947f6 100644 (file)
@@ -49,14 +49,11 @@ class ANSIDialect(default.DefaultDialect):
     def create_connect_args(self):
         return ([],{})
 
-    def dbapi(self):
-        return None
+    def schemagenerator(self, *args, **kwargs):
+        return ANSISchemaGenerator(self, *args, **kwargs)
 
-    def schemagenerator(self, *args, **params):
-        return ANSISchemaGenerator(*args, **params)
-
-    def schemadropper(self, *args, **params):
-        return ANSISchemaDropper(*args, **params)
+    def schemadropper(self, *args, **kwargs):
+        return ANSISchemaDropper(self, *args, **kwargs)
 
     def compiler(self, statement, parameters, **kwargs):
         return ANSICompiler(self, statement, parameters, **kwargs)
@@ -97,6 +94,9 @@ class ANSICompiler(sql.Compiled):
         
         sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
 
+        # if we are insert/update.  set to true when we visit an INSERT or UPDATE
+        self.isinsert = self.isupdate = False
+        
         # a dictionary of bind parameter keys to _BindParamClause instances.
         self.binds = {}
         
@@ -789,13 +789,12 @@ class ANSISchemaBase(engine.SchemaIterator):
         return alterables
 
 class ANSISchemaGenerator(ANSISchemaBase):
-    def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
-        super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
+    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+        super(ANSISchemaGenerator, self).__init__(connection, **kwargs)
         self.checkfirst = checkfirst
         self.tables = tables and util.Set(tables) or None
-        self.connection = connection
-        self.preparer = self.engine.dialect.preparer()
-        self.dialect = self.engine.dialect
+        self.preparer = dialect.preparer()
+        self.dialect = dialect
 
     def get_column_specification(self, column, first_pk=False):
         raise NotImplementedError()
@@ -804,7 +803,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
         collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
         for table in collection:
             table.accept_visitor(self)
-        if self.supports_alter():
+        if self.dialect.supports_alter():
             for alterable in self.find_alterables(collection):
                 self.add_foreignkey(alterable)
 
@@ -857,7 +856,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
 
     def _compile(self, tocompile, parameters):
         """compile the given string/parameters using this SchemaGenerator's dialect."""
-        compiler = self.engine.dialect.compiler(tocompile, parameters)
+        compiler = self.dialect.compiler(tocompile, parameters)
         compiler.compile()
         return compiler
 
@@ -880,11 +879,8 @@ class ANSISchemaGenerator(ANSISchemaBase):
         self.append("PRIMARY KEY ")
         self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
 
-    def supports_alter(self):
-        return True
-
     def visit_foreign_key_constraint(self, constraint):
-        if constraint.use_alter and self.supports_alter():
+        if constraint.use_alter and self.dialect.supports_alter():
             return
         self.append(", \n\t ")
         self.define_foreign_key(constraint)
@@ -927,25 +923,21 @@ class ANSISchemaGenerator(ANSISchemaBase):
         self.execute()
 
 class ANSISchemaDropper(ANSISchemaBase):
-    def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
-        super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
+    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+        super(ANSISchemaDropper, self).__init__(connection, **kwargs)
         self.checkfirst = checkfirst
         self.tables = tables
-        self.connection = connection
-        self.preparer = self.engine.dialect.preparer()
-        self.dialect = self.engine.dialect
+        self.preparer = dialect.preparer()
+        self.dialect = dialect
 
     def visit_metadata(self, metadata):
         collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or  self.dialect.has_table(self.connection, t.name, schema=t.schema))]
-        if self.supports_alter():
+        if self.dialect.supports_alter():
             for alterable in self.find_alterables(collection):
                 self.drop_foreignkey(alterable)
         for table in collection:
             table.accept_visitor(self)
 
-    def supports_alter(self):
-        return True
-
     def visit_index(self, index):
         self.append("\nDROP INDEX " + index.name)
         self.execute()
@@ -1099,3 +1091,5 @@ class ANSIIdentifierPreparer(object):
         """Prepare a quoted column name with table name."""
         
         return self.format_column(column, use_table=True, name=column_name)
+
+dialect = ANSIDialect
index 91a0869c610a4a5079fc081816d44cd9d40b3ff2..2ab88101a99240cfa591c501f8059409e12d6ec9 100644 (file)
@@ -15,12 +15,9 @@ import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 
-try:
+def dbapi():
     import kinterbasdb
-except:
-    kinterbasdb = None
-
-dbmodule = kinterbasdb
+    return kinterbasdb
 
 _initialized_kb = False
 
@@ -33,7 +30,6 @@ class FBNumeric(sqltypes.Numeric):
             return "NUMERIC(%(precision)s, %(length)s)" % { 'precision': self.precision,
                                                             'length' : self.length }
 
-
 class FBInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -111,24 +107,11 @@ class FBExecutionContext(default.DefaultExecutionContext):
 
 
 class FBDialect(ansisql.ANSIDialect):
-    def __init__(self, module = None, **params):
-        global _initialized_kb
-        self.module = module or dbmodule
-        self.opts = {}
-
-        if not _initialized_kb:
-            _initialized_kb = True
-            type_conv = params.get('type_conv', 200) or 200
-            if isinstance(type_conv, types.StringTypes):
-                type_conv = int(type_conv)
-
-            concurrency_level = params.get('concurrency_level', 1) or 1
-            if isinstance(concurrency_level, types.StringTypes):
-                concurrency_level = int(concurrency_level)
+    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
+        ansisql.ANSIDialect.__init__(self, **kwargs)
 
-            if kinterbasdb is not None:
-                kinterbasdb.init(type_conv=type_conv, concurrency_level=concurrency_level)
-        ansisql.ANSIDialect.__init__(self, **params)
+        self.type_conv = type_conv
+        self.concurrency_level= concurrency_level
 
     def create_connect_args(self, url):
         opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
@@ -136,15 +119,17 @@ class FBDialect(ansisql.ANSIDialect):
             opts['host'] = "%s/%s" % (opts['host'], opts['port'])
             del opts['port']
         opts.update(url.query)
-        # pop arguments that we took at the module level
-        opts.pop('type_conv', None)
-        opts.pop('concurrency_level', None)
-        self.opts = opts
 
-        return ([], self.opts)
+        type_conv = opts.pop('type_conv', self.type_conv)
+        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
+        global _initialized_kb
+        if not _initialized_kb and self.dbapi is not None:
+            _initialized_kb = True
+            self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
+        return ([], opts)
 
-    def create_execution_context(self):
-        return FBExecutionContext(self)
+    def create_execution_context(self, *args, **kwargs):
+        return FBExecutionContext(self, *args, **kwargs)
 
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
@@ -156,13 +141,13 @@ class FBDialect(ansisql.ANSIDialect):
         return FBCompiler(self, statement, bindparams, **kwargs)
 
     def schemagenerator(self, *args, **kwargs):
-        return FBSchemaGenerator(*args, **kwargs)
+        return FBSchemaGenerator(self, *args, **kwargs)
 
     def schemadropper(self, *args, **kwargs):
-        return FBSchemaDropper(*args, **kwargs)
+        return FBSchemaDropper(self, *args, **kwargs)
 
-    def defaultrunner(self, engine, proxy):
-        return FBDefaultRunner(engine, proxy)
+    def defaultrunner(self, connection):
+        return FBDefaultRunner(connection)
 
     def preparer(self):
         return FBIdentifierPreparer(self)
@@ -292,9 +277,6 @@ class FBDialect(ansisql.ANSIDialect):
         for name,value in fks.iteritems():
             table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
 
-    def last_inserted_ids(self):
-        return self.context.last_inserted_ids
-
     def do_execute(self, cursor, statement, parameters, **kwargs):
         cursor.execute(statement, parameters or [])
 
@@ -304,15 +286,6 @@ class FBDialect(ansisql.ANSIDialect):
     def do_commit(self, connection):
         connection.commit(True)
 
-    def connection(self):
-        """Returns a managed DBAPI connection from this SQLEngine's connection pool."""
-        c = self._pool.connect()
-        c.supportsTransactions = 0
-        return c
-
-    def dbapi(self):
-        return self.module
-
 
 class FBCompiler(ansisql.ANSICompiler):
     """Firebird specific idiosincrasies"""
@@ -364,7 +337,7 @@ class FBCompiler(ansisql.ANSICompiler):
 class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
         default = self.get_column_default_string(column)
         if default is not None:
@@ -388,11 +361,11 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper):
 
 class FBDefaultRunner(ansisql.ANSIDefaultRunner):
     def exec_default_sql(self, default):
-        c = sql.select([default.arg], from_obj=["rdb$database"], engine=self.engine).compile()
-        return self.proxy(str(c), c.get_params()).fetchone()[0]
+        c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.engine)
+        return self.connection.execute_compiled(c).scalar()
 
     def visit_sequence(self, seq):
-        return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0]
+        return self.connection.execute_text("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").scalar()
 
 
 RESERVED_WORDS = util.Set(
index 1852edefb8a734b87999c990625a499ee09c8aad..6d2ff66cd594475eda59d7b73d1538bb4ac4d704 100644 (file)
@@ -52,7 +52,22 @@ import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 
-
+def dbapi(module_name=None):
+    if module_name:
+        try:
+            dialect_cls = dialect_mapping[module_name]
+            return dialect_cls.import_dbapi()
+        except KeyError:
+            raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+    else:
+        for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]:
+            try:
+                return dialect_cls.import_dbapi()
+            except ImportError, e:
+                pass
+        else:
+            raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
+    
 class MSNumeric(sqltypes.Numeric):
     def convert_result_value(self, value, dialect):
         return value
@@ -142,9 +157,6 @@ class MSString(sqltypes.String):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
 
 class MSNVarchar(MSString):
-    """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True.  """
-    impl = sqltypes.Unicode
-
     def get_col_spec(self):
         if self.length:
             return "NVARCHAR(%(length)s)" % {'length' : self.length}
@@ -154,19 +166,7 @@ class MSNVarchar(MSString):
             return "NTEXT"
 
 class AdoMSNVarchar(MSNVarchar):
-    def convert_bind_param(self, value, dialect):
-        return value
-
-    def convert_result_value(self, value, dialect):
-        return value        
-
-class MSUnicode(sqltypes.Unicode):
-    """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl."""
-    impl = MSNVarchar
-
-class AdoMSUnicode(MSUnicode):
-    impl = AdoMSNVarchar
-
+    """overrides bindparam/result processing to not convert any unicode strings"""
     def convert_bind_param(self, value, dialect):
         return value
 
@@ -215,9 +215,9 @@ def descriptor():
     ]}
 
 class MSSQLExecutionContext(default.DefaultExecutionContext):
-    def __init__(self, dialect):
+    def __init__(self, *args, **kwargs):
         self.IINSERT = self.HASIDENT = False
-        super(MSSQLExecutionContext, self).__init__(dialect)
+        super(MSSQLExecutionContext, self).__init__(*args, **kwargs)
 
     def _has_implicit_sequence(self, column):
         if column.primary_key and column.autoincrement:
@@ -227,14 +227,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
                     return True
         return False
 
-    def pre_exec(self, engine, proxy, compiled, parameters, **kwargs):
+    def pre_exec(self):
         """MS-SQL has a special mode for inserting non-NULL values
         into IDENTITY columns.
         
         Activate it if the feature is turned on and needed.
         """
-        if getattr(compiled, "isinsert", False):
-            tbl = compiled.statement.table
+        if self.compiled.isinsert:
+            tbl = self.compiled.statement.table
             if not hasattr(tbl, 'has_sequence'):
                 tbl.has_sequence = None
                 for column in tbl.c:
@@ -243,39 +243,43 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
                         break
 
             self.HASIDENT = bool(tbl.has_sequence)
-            if engine.dialect.auto_identity_insert and self.HASIDENT:
-                if isinstance(parameters, list):
-                    self.IINSERT = tbl.has_sequence.key in parameters[0]
+            if self.dialect.auto_identity_insert and self.HASIDENT:
+                if isinstance(self.compiled_parameters, list):
+                    self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0]
                 else:
-                    self.IINSERT = tbl.has_sequence.key in parameters
+                    self.IINSERT = tbl.has_sequence.key in self.compiled_parameters
             else:
                 self.IINSERT = False
 
             if self.IINSERT:
-                proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
+                # TODO: quoting rules for table name here ?
+                self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name)
 
-        super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs)
+        super(MSSQLExecutionContext, self).pre_exec()
 
-    def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
+    def post_exec(self):
         """Turn off the INDENTITY_INSERT mode if it's been activated,
         and fetch recently inserted IDENTIFY values (works only for
         one column).
         """
         
-        if getattr(compiled, "isinsert", False):
+        if self.compiled.isinsert:
             if self.IINSERT:
-                proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name)
+                # TODO: quoting rules for table name here ?
+                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name)
                 self.IINSERT = False
             elif self.HASIDENT:
-                cursor = proxy("SELECT @@IDENTITY AS lastrowid")
-                row = cursor.fetchone()
+                self.cursor.execute("SELECT @@IDENTITY AS lastrowid")
+                row = self.cursor.fetchone()
                 self._last_inserted_ids = [int(row[0])]
                 # print "LAST ROW ID", self._last_inserted_ids
             self.HASIDENT = False
+        super(MSSQLExecutionContext, self).post_exec()
 
 
 class MSSQLDialect(ansisql.ANSIDialect):
     colspecs = {
+        sqltypes.Unicode : MSNVarchar,
         sqltypes.Integer : MSInteger,
         sqltypes.Smallinteger: MSSmallInteger,
         sqltypes.Numeric : MSNumeric,
@@ -283,7 +287,6 @@ class MSSQLDialect(ansisql.ANSIDialect):
         sqltypes.DateTime : MSDateTime,
         sqltypes.Date : MSDate,
         sqltypes.String : MSString,
-        sqltypes.Unicode : MSUnicode,
         sqltypes.Binary : MSBinary,
         sqltypes.Boolean : MSBoolean,
         sqltypes.TEXT : MSText,
@@ -296,7 +299,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         'smallint' : MSSmallInteger,
         'tinyint' : MSTinyInteger,
         'varchar' : MSString,
-        'nvarchar' : MSUnicode,
+        'nvarchar' : MSNVarchar,
         'char' : MSChar,
         'nchar' : MSNChar,
         'text' : MSText,
@@ -312,30 +315,16 @@ class MSSQLDialect(ansisql.ANSIDialect):
         'image' : MSBinary
     }
 
-    def __new__(cls, module_name=None, *args, **kwargs):
-        module = kwargs.get('module', None)
+    def __new__(cls, dbapi=None, *args, **kwargs):
         if cls != MSSQLDialect:
             return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
-        if module_name:
-            dialect = dialect_mapping.get(module_name)
-            if not dialect:
-                raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name)
-            if not hasattr(dialect, 'module'):
-                raise dialect.saved_import_error
+        if dbapi:
+            dialect = dialect_mapping.get(dbapi.__name__)
             return dialect(*args, **kwargs)
-        elif module:
-            return object.__new__(cls, *args, **kwargs)
         else:
-            for dialect in dialect_preference:
-                if hasattr(dialect, 'module'):
-                    return dialect(*args, **kwargs)
-            #raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
-            else:
-                return object.__new__(cls, *args, **kwargs)
+            return object.__new__(cls, *args, **kwargs)
                 
-    def __init__(self, module_name=None, module=None, auto_identity_insert=True, **params):
-        if not hasattr(self, 'module'):
-            self.module = module
+    def __init__(self, auto_identity_insert=True, **params):
         super(MSSQLDialect, self).__init__(**params)
         self.auto_identity_insert = auto_identity_insert
         self.text_as_varchar = False
@@ -352,8 +341,8 @@ class MSSQLDialect(ansisql.ANSIDialect):
             self.text_as_varchar = bool(opts.pop('text_as_varchar'))
         return self.make_connect_string(opts)
 
-    def create_execution_context(self):
-        return MSSQLExecutionContext(self)
+    def create_execution_context(self, *args, **kwargs):
+        return MSSQLExecutionContext(self, *args, **kwargs)
 
     def type_descriptor(self, typeobj):
         newobj = sqltypes.adapt_type(typeobj, self.colspecs)
@@ -373,13 +362,13 @@ class MSSQLDialect(ansisql.ANSIDialect):
         return MSSQLCompiler(self, statement, bindparams, **kwargs)
 
     def schemagenerator(self, *args, **kwargs):
-        return MSSQLSchemaGenerator(*args, **kwargs)
+        return MSSQLSchemaGenerator(self, *args, **kwargs)
 
     def schemadropper(self, *args, **kwargs):
-        return MSSQLSchemaDropper(*args, **kwargs)
+        return MSSQLSchemaDropper(self, *args, **kwargs)
 
-    def defaultrunner(self, engine, proxy):
-        return MSSQLDefaultRunner(engine, proxy)
+    def defaultrunner(self, connection, **kwargs):
+        return MSSQLDefaultRunner(connection, **kwargs)
 
     def preparer(self):
         return MSSQLIdentifierPreparer(self)
@@ -411,19 +400,12 @@ class MSSQLDialect(ansisql.ANSIDialect):
     def raw_connection(self, connection):
         """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
         try:
+            # TODO: probably want to move this to individual dialect subclasses to 
+            # save on the exception throw + simplify
             return connection.connection.__dict__['_pymssqlCnx__cnx']
         except:
             return connection.connection.adoConn
 
-    def connection(self):
-        """returns a managed DBAPI connection from this SQLEngine's connection pool."""
-        c = self._pool.connect()
-        c.supportsTransactions = 0
-        return c
-  
-    def dbapi(self):
-        return self.module
-
     def uppercase_table(self, t):
         # convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive
         t.name = t.name.upper()
@@ -558,13 +540,14 @@ class MSSQLDialect(ansisql.ANSIDialect):
             table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
 
 class MSSQLDialect_pymssql(MSSQLDialect):
-    try:
+    def import_dbapi(cls):
         import pymssql as module
         # pymmsql doesn't have a Binary method.  we use string
+        # TODO: monkeypatching here is less than ideal
         module.Binary = lambda st: str(st)
-    except ImportError, e:
-        saved_import_error = e
-
+        return module
+    import_dbapi = classmethod(import_dbapi)
+    
     def supports_sane_rowcount(self):
         return True
 
@@ -578,7 +561,7 @@ class MSSQLDialect_pymssql(MSSQLDialect):
     def create_connect_args(self, url):
         r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
         if hasattr(self, 'query_timeout'):
-            self.module._mssql.set_query_timeout(self.query_timeout)
+            self.dbapi._mssql.set_query_timeout(self.query_timeout)
         return r
 
     def make_connect_string(self, keys):
@@ -621,15 +604,16 @@ class MSSQLDialect_pymssql(MSSQLDialect):
 ##        r.fetch_array()
 
 class MSSQLDialect_pyodbc(MSSQLDialect):
-    try:
+    
+    def import_dbapi(cls):
         import pyodbc as module
-    except ImportError, e:
-        saved_import_error = e
-
+        return module
+    import_dbapi = classmethod(import_dbapi)
+    
     colspecs = MSSQLDialect.colspecs.copy()
-    colspecs[sqltypes.Unicode] = AdoMSUnicode
+    colspecs[sqltypes.Unicode] = AdoMSNVarchar
     ischema_names = MSSQLDialect.ischema_names.copy()
-    ischema_names['nvarchar'] = AdoMSUnicode
+    ischema_names['nvarchar'] = AdoMSNVarchar
 
     def supports_sane_rowcount(self):
         return False
@@ -648,15 +632,15 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
 
 
 class MSSQLDialect_adodbapi(MSSQLDialect):
-    try:
+    def import_dbapi(cls):
         import adodbapi as module
-    except ImportError, e:
-        saved_import_error = e
+        return module
+    import_dbapi = classmethod(import_dbapi)
 
     colspecs = MSSQLDialect.colspecs.copy()
-    colspecs[sqltypes.Unicode] = AdoMSUnicode
+    colspecs[sqltypes.Unicode] = AdoMSNVarchar
     ischema_names = MSSQLDialect.ischema_names.copy()
-    ischema_names['nvarchar'] = AdoMSUnicode
+    ischema_names['nvarchar'] = AdoMSNVarchar
 
     def supports_sane_rowcount(self):
         return True
@@ -676,13 +660,11 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
             connectors.append("Integrated Security=SSPI")
         return [[";".join (connectors)], {}]
 
-
 dialect_mapping = {
     'pymssql':  MSSQLDialect_pymssql,
     'pyodbc':   MSSQLDialect_pyodbc,
     'adodbapi': MSSQLDialect_adodbapi
     }
-dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]
 
 
 class MSSQLCompiler(ansisql.ANSICompiler):
@@ -770,7 +752,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
 
 class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
         
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
         if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
@@ -797,6 +779,7 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
         self.execute()
 
 class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
+    # TODO: does ms-sql have standalone sequences ?
     pass
 
 class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
index 5fc63234a0e94bdca6170a5faabe285e798188d1..65ccb6af19f86d818c030be8053c65560f8701c2 100644 (file)
@@ -12,12 +12,9 @@ import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 from array import array
 
-try:
+def dbapi():
     import MySQLdb as mysql
-    import MySQLdb.constants.CLIENT as CLIENT_FLAGS
-except:
-    mysql = None
-    CLIENT_FLAGS = None
+    return mysql
 
 def kw_colspec(self, spec):
     if self.unsigned:
@@ -158,8 +155,6 @@ class MSLongText(MSText):
             return "LONGTEXT"
 
 class MSString(sqltypes.String):
-    def __init__(self, length=None, *extra, **kwargs):
-        sqltypes.String.__init__(self, length=length)
     def get_col_spec(self):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
 
@@ -277,16 +272,12 @@ def descriptor():
     ]}
 
 class MySQLExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
-        if getattr(compiled, "isinsert", False):
-            self._last_inserted_ids = [proxy().lastrowid]
+    def post_exec(self):
+        if self.compiled.isinsert:
+            self._last_inserted_ids = [self.cursor.lastrowid]
 
 class MySQLDialect(ansisql.ANSIDialect):
-    def __init__(self, module = None, **kwargs):
-        if module is None:
-            self.module = mysql
-        else:
-            self.module = module
+    def __init__(self, **kwargs):
         ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
 
     def create_connect_args(self, url):
@@ -305,14 +296,18 @@ class MySQLDialect(ansisql.ANSIDialect):
         # TODO: what about options like "ssl", "cursorclass" and "conv" ?
 
         client_flag = opts.get('client_flag', 0)
-        if CLIENT_FLAGS is not None:
-            client_flag |= CLIENT_FLAGS.FOUND_ROWS
+        if self.dbapi is not None:
+            try:
+                import MySQLdb.constants.CLIENT as CLIENT_FLAGS
+                client_flag |= CLIENT_FLAGS.FOUND_ROWS
+            except:
+                pass
         opts['client_flag'] = client_flag
 
         return [[], opts]
 
-    def create_execution_context(self):
-        return MySQLExecutionContext(self)
+    def create_execution_context(self, *args, **kwargs):
+        return MySQLExecutionContext(self, *args, **kwargs)
 
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
@@ -324,10 +319,10 @@ class MySQLDialect(ansisql.ANSIDialect):
         return MySQLCompiler(self, statement, bindparams, **kwargs)
 
     def schemagenerator(self, *args, **kwargs):
-        return MySQLSchemaGenerator(*args, **kwargs)
+        return MySQLSchemaGenerator(self, *args, **kwargs)
 
     def schemadropper(self, *args, **kwargs):
-        return MySQLSchemaDropper(*args, **kwargs)
+        return MySQLSchemaDropper(self, *args, **kwargs)
 
     def preparer(self):
         return MySQLIdentifierPreparer(self)
@@ -337,14 +332,14 @@ class MySQLDialect(ansisql.ANSIDialect):
             rowcount = cursor.executemany(statement, parameters)
             if context is not None:
                 context._rowcount = rowcount
-        except mysql.OperationalError, o:
+        except self.dbapi.OperationalError, o:
             if o.args[0] == 2006 or o.args[0] == 2014:
                 cursor.invalidate()
             raise o
     def do_execute(self, cursor, statement, parameters, **kwargs):
         try:
             cursor.execute(statement, parameters)
-        except mysql.OperationalError, o:
+        except self.dbapi.OperationalError, o:
             if o.args[0] == 2006 or o.args[0] == 2014:
                 cursor.invalidate()
             raise o
@@ -361,11 +356,9 @@ class MySQLDialect(ansisql.ANSIDialect):
             self._default_schema_name = text("select database()", self).scalar()
         return self._default_schema_name
 
-    def dbapi(self):
-        return self.module
-
     def has_table(self, connection, table_name, schema=None):
-        cursor = connection.execute("show table status like '" + table_name + "'")
+        cursor = connection.execute("show table status like %s", [table_name])
+        print "CURSOR", cursor, "ROWCOUNT", cursor.rowcount, "REAL RC", cursor.cursor.rowcount
         return bool( not not cursor.rowcount )
 
     def reflecttable(self, connection, table):
@@ -492,8 +485,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
 
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, first_pk=False):
-        t = column.type.engine_impl(self.engine)
-        colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
index adea127bfedab6130f8ef44150d9c9758f45f354..5377759a2aa61aa019825281e032b1cd586b07b6 100644 (file)
@@ -8,15 +8,13 @@
 import sys, StringIO, string, re
 
 from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
-import sqlalchemy.engine.default as default
+from sqlalchemy.engine import default, base
 import sqlalchemy.types as sqltypes
 
-try:
+def dbapi():
     import cx_Oracle
-except:
-    cx_Oracle = None
+    return cx_Oracle
 
-ORACLE_BINARY_TYPES = [getattr(cx_Oracle, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(cx_Oracle, k)]
 
 class OracleNumeric(sqltypes.Numeric):
     def get_col_spec(self):
@@ -149,26 +147,32 @@ def descriptor():
     ]}
 
 class OracleExecutionContext(default.DefaultExecutionContext):
-    def pre_exec(self, engine, proxy, compiled, parameters):
-        super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters)
+    def pre_exec(self):
+        super(OracleExecutionContext, self).pre_exec()
         if self.dialect.auto_setinputsizes:
-                self.set_input_sizes(proxy(), parameters)
+            self.set_input_sizes()
+
+    def get_result_proxy(self):
+        if self.cursor.description is not None:
+            for column in self.cursor.description:
+                type_code = column[1]
+                if type_code in self.dialect.ORACLE_BINARY_TYPES:
+                    return base.BufferedColumnResultProxy(self)
+        
+        return base.ResultProxy(self)
 
 class OracleDialect(ansisql.ANSIDialect):
-    def __init__(self, use_ansi=True, auto_setinputsizes=True, module=None, threaded=True, **kwargs):
+    def __init__(self, use_ansi=True, auto_setinputsizes=True, threaded=True, **kwargs):
+        ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs)
         self.use_ansi = use_ansi
         self.threaded = threaded
-        if module is None:
-            self.module = cx_Oracle
-        else:
-            self.module = module
-        self.supports_timestamp = hasattr(self.module, 'TIMESTAMP' )
+        self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
         self.auto_setinputsizes = auto_setinputsizes
-        ansisql.ANSIDialect.__init__(self, **kwargs)
-
-    def dbapi(self):
-        return self.module
-
+        if self.dbapi is not None:
+            self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
+        else:
+            self.ORACLE_BINARY_TYPES = []
+            
     def create_connect_args(self, url):
         if url.database:
             # if we have a database, then we have a remote host
@@ -177,7 +181,7 @@ class OracleDialect(ansisql.ANSIDialect):
                 port = int(port)
             else:
                 port = 1521
-            dsn = self.module.makedsn(url.host,port,url.database)
+            dsn = self.dbapi.makedsn(url.host,port,url.database)
         else:
             # we have a local tnsname
             dsn = url.host
@@ -206,20 +210,20 @@ class OracleDialect(ansisql.ANSIDialect):
         else:
             return "rowid"
 
-    def create_execution_context(self):
-        return OracleExecutionContext(self)
+    def create_execution_context(self, *args, **kwargs):
+        return OracleExecutionContext(self, *args, **kwargs)
 
     def compiler(self, statement, bindparams, **kwargs):
         return OracleCompiler(self, statement, bindparams, **kwargs)
 
     def schemagenerator(self, *args, **kwargs):
-        return OracleSchemaGenerator(*args, **kwargs)
+        return OracleSchemaGenerator(self, *args, **kwargs)
 
     def schemadropper(self, *args, **kwargs):
-        return OracleSchemaDropper(*args, **kwargs)
+        return OracleSchemaDropper(self, *args, **kwargs)
 
-    def defaultrunner(self, engine, proxy):
-        return OracleDefaultRunner(engine, proxy)
+    def defaultrunner(self, connection, **kwargs):
+        return OracleDefaultRunner(connection, **kwargs)
 
     def has_table(self, connection, table_name, schema=None):
         cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()})
@@ -405,15 +409,6 @@ class OracleDialect(ansisql.ANSIDialect):
         if context is not None:
             context._rowcount = rowcount
 
-    def create_result_proxy_args(self, connection, cursor):
-        args = super(OracleDialect, self).create_result_proxy_args(connection, cursor)
-        if cursor and cursor.description:
-            for column in cursor.description:
-                type_code = column[1]
-                if type_code in ORACLE_BINARY_TYPES:
-                    args['should_prefetch'] = True
-                    break
-        return args
 
 OracleDialect.logger = logging.class_logger(OracleDialect)
 
@@ -569,7 +564,7 @@ class OracleCompiler(ansisql.ANSICompiler):
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
@@ -579,22 +574,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
         return colspec
 
     def visit_sequence(self, sequence):
-        if not self.engine.dialect.has_sequence(self.connection, sequence.name):
+        if not self.dialect.has_sequence(self.connection, sequence.name):
             self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
 
 class OracleSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_sequence(self, sequence):
-        if self.engine.dialect.has_sequence(self.connection, sequence.name):
+        if self.dialect.has_sequence(self.connection, sequence.name):
             self.append("DROP SEQUENCE %s" % sequence.name)
             self.execute()
 
 class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
     def exec_default_sql(self, default):
         c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile()
-        return self.proxy(str(c), c.get_params()).fetchone()[0]
+        return self.connection.execute_compiled(c).scalar()
 
     def visit_sequence(self, seq):
-        return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0]
+        return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
 
 dialect = OracleDialect
index d83607793ef0fad7239e7fd4436863a0a7953b6a..2943d163e5960d39941ab7c02c41f620242d386a 100644 (file)
@@ -4,33 +4,28 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import datetime, sys, StringIO, string, types, re
-
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+import datetime, string, types, re, random
+
+from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy.engine import base, default
 import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
 from sqlalchemy.databases import information_schema as ischema
-import re
 
 try:
     import mx.DateTime.DateTime as mxDateTime
 except:
     mxDateTime = None
 
-try:
-    import psycopg2 as psycopg
-    #import psycopg2.psycopg1 as psycopg
-except:
+def dbapi():
     try:
-        import psycopg
-    except:
-        psycopg = None
-
+        import psycopg2 as psycopg
+    except ImportError, e:
+        try:
+            import psycopg
+        except ImportError, e2:
+            raise e
+    return psycopg
+    
 class PGInet(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "INET"
@@ -74,8 +69,8 @@ class PG1DateTime(sqltypes.DateTime):
                 mx_datetime = mxDateTime(value.year, value.month, value.day,
                                          value.hour, value.minute,
                                          seconds)
-                return psycopg.TimestampFromMx(mx_datetime)
-            return psycopg.TimestampFromMx(value)
+                return dialect.dbapi.TimestampFromMx(mx_datetime)
+            return dialect.dbapi.TimestampFromMx(value)
         else:
             return None
 
@@ -101,7 +96,7 @@ class PG1Date(sqltypes.Date):
         # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
         # this one doesnt seem to work with the "emulation" mode
         if value is not None:
-            return psycopg.DateFromMx(value)
+            return dialect.dbapi.DateFromMx(value)
         else:
             return None
 
@@ -219,44 +214,49 @@ def descriptor():
     ]}
 
 class PGExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
-        if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None:
-            if not engine.dialect.use_oids:
+
+    def is_select(self):
+        return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
+    
+    def create_cursor(self):
+        if self.dialect.server_side_cursors and self.is_select():
+            # use server-side cursors:
+            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+            ident = "c" + hex(random.randint(0, 65535))[2:]
+            return self.connection.connection.cursor(ident)
+        else:
+            return self.connection.connection.cursor()
+
+    def get_result_proxy(self):
+        if self.dialect.server_side_cursors and self.is_select():
+            return base.BufferedRowResultProxy(self)
+        else:
+            return base.ResultProxy(self)
+    
+    def post_exec(self):
+        if self.compiled.isinsert and self.last_inserted_ids is None:
+            if not self.dialect.use_oids:
                 pass
                 # will raise invalid error when they go to get them
             else:
-                table = compiled.statement.table
-                cursor = proxy()
-                if cursor.lastrowid is not None and table is not None and len(table.primary_key):
-                    s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid)
-                    c = s.compile(engine=engine)
-                    cursor = proxy(str(c), c.get_params())
-                    row = cursor.fetchone()
+                table = self.compiled.statement.table
+                if self.cursor.lastrowid is not None and table is not None and len(table.primary_key):
+                    s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid)
+                    row = self.connection.execute(s).fetchone()
                 self._last_inserted_ids = [v for v in row]
-
+        super(PGExecutionContext, self).post_exec()
+        
 class PGDialect(ansisql.ANSIDialect):
-    def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params):
+    def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs):
+        ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
         self.use_oids = use_oids
         self.server_side_cursors = server_side_cursors
-        if module is None:
-            #if psycopg is None:
-            #    raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
-            self.module = psycopg
+        if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
+            self.version = 2
         else:
-            self.module = module
-        # figure psycopg version 1 or 2
-        try:
-            if self.module.__version__.startswith('2'):
-                self.version = 2
-            else:
-                self.version = 1
-        except:
             self.version = 1
-        ansisql.ANSIDialect.__init__(self, **params)
         self.use_information_schema = use_information_schema
-        # produce consistent paramstyle even if psycopg2 module not present
-        if self.module is None:
-            self.paramstyle = 'pyformat'
+        self.paramstyle = 'pyformat'
 
     def create_connect_args(self, url):
         opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
@@ -268,16 +268,9 @@ class PGDialect(ansisql.ANSIDialect):
         opts.update(url.query)
         return ([], opts)
 
-    def create_cursor(self, connection):
-        if self.server_side_cursors:
-            # use server-side cursors:
-            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-            return connection.cursor('x')
-        else:
-            return connection.cursor()
 
-    def create_execution_context(self):
-        return PGExecutionContext(self)
+    def create_execution_context(self, *args, **kwargs):
+        return PGExecutionContext(self, *args, **kwargs)
 
     def max_identifier_length(self):
         return 68
@@ -292,13 +285,13 @@ class PGDialect(ansisql.ANSIDialect):
         return PGCompiler(self, statement, bindparams, **kwargs)
 
     def schemagenerator(self, *args, **kwargs):
-        return PGSchemaGenerator(*args, **kwargs)
+        return PGSchemaGenerator(self, *args, **kwargs)
 
     def schemadropper(self, *args, **kwargs):
-        return PGSchemaDropper(*args, **kwargs)
+        return PGSchemaDropper(self, *args, **kwargs)
 
-    def defaultrunner(self, engine, proxy):
-        return PGDefaultRunner(engine, proxy)
+    def defaultrunner(self, connection, **kwargs):
+        return PGDefaultRunner(connection, **kwargs)
 
     def preparer(self):
         return PGIdentifierPreparer(self)
@@ -326,7 +319,6 @@ class PGDialect(ansisql.ANSIDialect):
         ``psycopg2`` is not nice enough to produce this correctly for
         an executemany, so we do our own executemany here.
         """
-
         rowcount = 0
         for param in parameters:
             c.execute(statement, param)
@@ -334,9 +326,6 @@ class PGDialect(ansisql.ANSIDialect):
         if context is not None:
             context._rowcount = rowcount
 
-    def dbapi(self):
-        return self.module
-
     def has_table(self, connection, table_name, schema=None):
         # seems like case gets folded in pg_class...
         if schema is None:
@@ -542,7 +531,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
             else:
                 colspec += " SERIAL"
         else:
-            colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
@@ -567,8 +556,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
         if column.primary_key:
             # passive defaults on primary keys have to be overridden
             if isinstance(column.default, schema.PassiveDefault):
-                c = self.proxy("select %s" % column.default.arg)
-                return c.fetchone()[0]
+                return self.connection.execute_text("select %s" % column.default.arg).scalar()
             elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                 sch = column.table.schema
                 # TODO: this has to build into the Sequence object so we can get the quoting
@@ -577,17 +565,13 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
                     exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
                 else:
                     exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
-                c = self.proxy(exc)
-                return c.fetchone()[0]
-            else:
-                return ansisql.ANSIDefaultRunner.get_column_default(self, column)
-        else:
-            return ansisql.ANSIDefaultRunner.get_column_default(self, column)
+                return self.connection.execute_text(exc).scalar()
+
+        return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
 
     def visit_sequence(self, seq):
         if not seq.optional:
-            c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq))
-            return c.fetchone()[0]
+            return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()
         else:
             return None
 
index b29be9eeddf1af98c0284530dd4b48424a2b5745..9270f2a5ffa2136b404e8cb64cd5c1dd992eaba0 100644 (file)
@@ -12,19 +12,19 @@ import sqlalchemy.engine.default as default
 import sqlalchemy.types as sqltypes
 import datetime,time
 
-pysqlite2_timesupport = False   # Change this if the init.d guys ever get around to supporting time cols
-
-try:
-    from pysqlite2 import dbapi2 as sqlite
-except ImportError:
+def dbapi():
     try:
-        from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
-    except ImportError:
+        from pysqlite2 import dbapi2 as sqlite
+    except ImportError, e:
         try:
-            sqlite = __import__('sqlite') # skip ourselves
-        except:
-            sqlite = None
-
+            from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+        except ImportError:
+            try:
+                sqlite = __import__('sqlite') # skip ourselves
+            except ImportError:
+                raise e
+    return sqlite
+    
 class SLNumeric(sqltypes.Numeric):
     def get_col_spec(self):
         if self.precision is None:
@@ -140,10 +140,6 @@ pragma_names = {
     'BLOB' : SLBinary,
 }
 
-if pysqlite2_timesupport:
-    colspecs.update({sqltypes.Time : SLTime})
-    pragma_names.update({'TIME' : SLTime})
-
 def descriptor():
     return {'name':'sqlite',
     'description':'SQLite',
@@ -152,25 +148,29 @@ def descriptor():
     ]}
 
 class SQLiteExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
-        if getattr(compiled, "isinsert", False):
-            self._last_inserted_ids = [proxy().lastrowid]
-
+    def post_exec(self):
+        if self.compiled.isinsert:
+            self._last_inserted_ids = [self.cursor.lastrowid]
+        super(SQLiteExecutionContext, self).post_exec()
+        
 class SQLiteDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
+        ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
         def vers(num):
             return tuple([int(x) for x in num.split('.')])
-        self.supports_cast = (sqlite is not None and vers(sqlite.sqlite_version) >= vers("3.2.3"))
-        ansisql.ANSIDialect.__init__(self, **kwargs)
+        self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
 
     def compiler(self, statement, bindparams, **kwargs):
         return SQLiteCompiler(self, statement, bindparams, **kwargs)
 
     def schemagenerator(self, *args, **kwargs):
-        return SQLiteSchemaGenerator(*args, **kwargs)
+        return SQLiteSchemaGenerator(self, *args, **kwargs)
 
     def schemadropper(self, *args, **kwargs):
-        return SQLiteSchemaDropper(*args, **kwargs)
+        return SQLiteSchemaDropper(self, *args, **kwargs)
+
+    def supports_alter(self):
+        return False
 
     def preparer(self):
         return SQLiteIdentifierPreparer(self)
@@ -182,8 +182,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
-    def create_execution_context(self):
-        return SQLiteExecutionContext(self)
+    def create_execution_context(self, **kwargs):
+        return SQLiteExecutionContext(self, **kwargs)
 
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
@@ -191,9 +191,6 @@ class SQLiteDialect(ansisql.ANSIDialect):
     def oid_column_name(self, column):
         return "oid"
 
-    def dbapi(self):
-        return sqlite
-
     def has_table(self, connection, table_name, schema=None):
         cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {})
         row = cursor.fetchone()
@@ -321,11 +318,9 @@ class SQLiteCompiler(ansisql.ANSICompiler):
             return ansisql.ANSICompiler.binary_operator_string(self, binary)
 
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def supports_alter(self):
-        return False
 
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
@@ -345,8 +340,7 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     #        super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
 
 class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
-    def supports_alter(self):
-        return False
+    pass
 
 class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
     def __init__(self, dialect):
index 0baaeb8268baba91df940ea12cb207b64ff8556d..d8a9c52998e5bee12ef046f8497cee4cb1dd6f02 100644 (file)
@@ -83,7 +83,7 @@ class Dialect(sql.AbstractDialect):
         raise NotImplementedError()
 
     def type_descriptor(self, typeobj):
-        """Trasform the type from generic to database-specific.
+        """Transform the type from generic to database-specific.
 
         Provides a database-specific TypeEngine object, given the
         generic object which comes from the types module.  Subclasses
@@ -105,6 +105,10 @@ class Dialect(sql.AbstractDialect):
 
         raise NotImplementedError()
 
+    def supports_alter(self):
+        """return True if the database supports ALTER TABLE."""
+        raise NotImplementedError()
+
     def max_identifier_length(self):
         """Return the maximum length of identifier names.
         
@@ -118,32 +122,43 @@ class Dialect(sql.AbstractDialect):
     def supports_sane_rowcount(self):
         """Indicate whether the dialect properly implements statements rowcount.
 
-        Provided to indicate when MySQL is being used, which does not
-        have standard behavior for the "rowcount" function on a statement handle.
+        This was needed for MySQL which had non-standard behavior of rowcount,
+        but this issue has since been resolved.
         """
 
         raise NotImplementedError()
 
-    def schemagenerator(self, engine, proxy, **params):
+    def schemagenerator(self, connection, **kwargs):
         """Return a ``schema.SchemaVisitor`` instance that can generate schemas.
 
+            connection
+                a Connection to use for statement execution
+                
         `schemagenerator()` is called via the `create()` method on Table,
         Index, and others.
         """
 
         raise NotImplementedError()
 
-    def schemadropper(self, engine, proxy, **params):
+    def schemadropper(self, connection, **kwargs):
         """Return a ``schema.SchemaVisitor`` instance that can drop schemas.
 
+            connection
+                a Connection to use for statement execution
+
         `schemadropper()` is called via the `drop()` method on Table,
         Index, and others.
         """
 
         raise NotImplementedError()
 
-    def defaultrunner(self, engine, proxy, **params):
-        """Return a ``schema.SchemaVisitor`` instance that can execute defaults."""
+    def defaultrunner(self, connection, **kwargs):
+        """Return a ``schema.SchemaVisitor`` instance that can execute defaults.
+        
+            connection
+                a Connection to use for statement execution
+        
+        """
 
         raise NotImplementedError()
 
@@ -154,7 +169,6 @@ class Dialect(sql.AbstractDialect):
         ansisql.ANSICompiler, and will produce a string representation
         of the given ClauseElement and `parameters` dictionary.
 
-        `compiler()` is called within the context of the compile() method.
         """
 
         raise NotImplementedError()
@@ -188,23 +202,13 @@ class Dialect(sql.AbstractDialect):
 
         raise NotImplementedError()
 
-    def dbapi(self):
-        """Establish a connection to the database.
-
-        Subclasses override this method to provide the DBAPI module
-        used to establish connections.
-        """
-
-        raise NotImplementedError()
-
     def get_default_schema_name(self, connection):
         """Return the currently selected schema given a connection"""
 
         raise NotImplementedError()
 
-    def execution_context(self):
+    def create_execution_context(self, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
         """Return a new ExecutionContext object."""
-
         raise NotImplementedError()
 
     def do_begin(self, connection):
@@ -232,15 +236,6 @@ class Dialect(sql.AbstractDialect):
 
         raise NotImplementedError()
 
-    def create_cursor(self, connection):
-        """Return a new cursor generated from the given connection."""
-
-        raise NotImplementedError()
-
-    def create_result_proxy_args(self, connection, cursor):
-        """Return a dictionary of arguments that should be passed to ResultProxy()."""
-
-        raise NotImplementedError()
 
     def compile(self, clauseelement, parameters=None):
         """Compile the given ClauseElement using this Dialect.
@@ -255,42 +250,74 @@ class Dialect(sql.AbstractDialect):
 class ExecutionContext(object):
     """A messenger object for a Dialect that corresponds to a single execution.
 
+    ExecutionContext should have these datamembers:
+    
+        connection
+            Connection object which initiated the call to the
+            dialect to create this ExecutionContext.
+
+        dialect
+            dialect which created this ExecutionContext.
+            
+        cursor
+            DBAPI cursor procured from the connection
+            
+        compiled
+            if passed to constructor, sql.Compiled object being executed
+        
+        compiled_parameters
+            if passed to constructor, sql.ClauseParameters object
+             
+        statement
+            string version of the statement to be executed.  Is either
+            passed to the constructor, or must be created from the 
+            sql.Compiled object by the time pre_exec() has completed.
+            
+        parameters
+            "raw" parameters suitable for direct execution by the
+            dialect.  Either passed to the constructor, or must be
+            created from the sql.ClauseParameters object by the time 
+            pre_exec() has completed.
+            
+    
     The Dialect should provide an ExecutionContext via the
     create_execution_context() method.  The `pre_exec` and `post_exec`
-    methods will be called for compiled statements, afterwhich it is
-    expected that the various methods `last_inserted_ids`,
-    `last_inserted_params`, etc.  will contain appropriate values, if
-    applicable.
+    methods will be called for compiled statements.
+    
     """
 
-    def pre_exec(self, engine, proxy, compiled, parameters):
-        """Called before an execution of a compiled statement.
+    def create_cursor(self):
+        """Return a new cursor generated this ExecutionContext's connection."""
 
-        `proxy` is a callable that takes a string statement and a bind
-        parameter list/dictionary.
+        raise NotImplementedError()
+
+    def pre_exec(self):
+        """Called before an execution of a compiled statement.
+        
+        If compiled and compiled_parameters were passed to this
+        ExecutionContext, the `statement` and `parameters` datamembers
+        must be initialized after this statement is complete.
         """
 
         raise NotImplementedError()
 
-    def post_exec(self, engine, proxy, compiled, parameters):
+    def post_exec(self):
         """Called after the execution of a compiled statement.
-
-        `proxy` is a callable that takes a string statement and a bind
-        parameter list/dictionary.
+        
+        If compiled was passed to this ExecutionContext,
+        the `last_insert_ids`, `last_inserted_params`, etc. 
+        datamembers should be available after this method
+        completes.
         """
 
         raise NotImplementedError()
-
-    def get_rowcount(self, cursor):
-        """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
-
+    
+    def get_result_proxy(self):
+        """return a ResultProxy corresponding to this ExecutionContext."""
         raise NotImplementedError()
-
-    def supports_sane_rowcount(self):
-        """Indicate if the "rowcount" DBAPI cursor function works properly.
-
-        Currently, MySQLDB does not properly implement this function.
-        """
+        
+    def get_rowcount(self):
+        """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
 
         raise NotImplementedError()
 
@@ -299,7 +326,7 @@ class ExecutionContext(object):
 
         This does not apply to straight textual clauses; only to
         ``sql.Insert`` objects compiled against a ``schema.Table`` object,
-        which are executed via `statement.execute()`.  The order of
+        which are executed via `execute()`.  The order of
         items in the list is the same as that of the Table's
         'primary_key' attribute.
 
@@ -337,7 +364,7 @@ class ExecutionContext(object):
         raise NotImplementedError()
 
 
-class Connectable(object):
+class Connectable(sql.Executor):
     """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
 
     def contextual_connect(self):
@@ -362,6 +389,7 @@ class Connectable(object):
         raise NotImplementedError()
 
     engine = property(_not_impl, doc="The Engine which this Connectable is associated with.")
+    dialect = property(_not_impl, doc="Dialect which this Connectable is associated with.")
 
 class Connection(Connectable):
     """Represent a single DBAPI connection returned from the underlying connection pool.
@@ -385,7 +413,8 @@ class Connection(Connectable):
         except AttributeError:
             raise exceptions.InvalidRequestError("This Connection is closed")
 
-    engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated (read only)")
+    engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
+    dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
     connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
     should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
 
@@ -429,7 +458,7 @@ class Connection(Connectable):
         """When no Transaction is present, this is called after executions to provide "autocommit" behavior."""
         # TODO: have the dialect determine if autocommit can be set on the connection directly without this
         # extra step
-        if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()):
+        if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip(), re.I):
             self._commit_impl()
 
     def _autorollback(self):
@@ -448,6 +477,9 @@ class Connection(Connectable):
     def scalar(self, object, *multiparams, **params):
         return self.execute(object, *multiparams, **params).scalar()
 
+    def compiler(self, statement, parameters, **kwargs):
+        return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs)
+
     def execute(self, object, *multiparams, **params):
         for c in type(object).__mro__:
             if c in Connection.executors:
@@ -456,7 +488,7 @@ class Connection(Connectable):
             raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
 
     def execute_default(self, default, **kwargs):
-        return default.accept_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs))
+        return default.accept_visitor(self.__engine.dialect.defaultrunner(self))
 
     def execute_text(self, statement, *multiparams, **params):
         if len(multiparams) == 0:
@@ -465,9 +497,9 @@ class Connection(Connectable):
             parameters = multiparams[0]
         else:
             parameters = list(multiparams)
-        cursor = self._execute_raw(statement, parameters)
-        rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
-        return ResultProxy(self.__engine, self, cursor, **rpargs)
+        context = self._create_execution_context(statement=statement, parameters=parameters)
+        self._execute_raw(context)
+        return context.get_result_proxy()
 
     def _params_to_listofdicts(self, *multiparams, **params):
         if len(multiparams) == 0:
@@ -491,29 +523,57 @@ class Connection(Connectable):
             param = multiparams[0]
         else:
             param = params
-        return self.execute_compiled(elem.compile(engine=self.__engine, parameters=param), *multiparams, **params)
+        return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params)
 
     def execute_compiled(self, compiled, *multiparams, **params):
         """Execute a sql.Compiled object."""
         if not compiled.can_execute:
             raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
-        cursor = self.__engine.dialect.create_cursor(self.connection)
         parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
         if len(parameters) == 1:
             parameters = parameters[0]
-        def proxy(statement=None, parameters=None):
-            if statement is None:
-                return cursor
-
-            parameters = self.__engine.dialect.convert_compiled_params(parameters)
-            self._execute_raw(statement, parameters, cursor=cursor, context=context)
-            return cursor
-        context = self.__engine.dialect.create_execution_context()
-        context.pre_exec(self.__engine, proxy, compiled, parameters)
-        proxy(unicode(compiled), parameters)
-        context.post_exec(self.__engine, proxy, compiled, parameters)
-        rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
-        return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs)
+        context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters)
+        context.pre_exec()
+        self._execute_raw(context)
+        context.post_exec()
+        return context.get_result_proxy()
+    
+    def _create_execution_context(self, **kwargs):
+        return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
+        
+    def _execute_raw(self, context):
+        self.__engine.logger.info(context.statement)
+        self.__engine.logger.info(repr(context.parameters))
+        if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], dict)):
+            self._executemany(context)
+        else:
+            self._execute(context)
+        self._autocommit(context.statement)
+
+    def _execute(self, context):
+        if context.parameters is None:
+            if context.dialect.positional:
+                context.parameters = ()
+            else:
+                context.parameters = {}
+        try:
+            context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
+        except Exception, e:
+            self._autorollback()
+            #self._rollback_impl()
+            if self.__close_with_result:
+                self.close()
+            raise exceptions.SQLError(context.statement, context.parameters, e)
+
+    def _executemany(self, context):
+        try:
+            context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
+        except Exception, e:
+            self._autorollback()
+            #self._rollback_impl()
+            if self.__close_with_result:
+                self.close()
+            raise exceptions.SQLError(context.statement, context.parameters, e)
 
     # poor man's multimethod/generic function thingy
     executors = {
@@ -525,17 +585,17 @@ class Connection(Connectable):
     }
 
     def create(self, entity, **kwargs):
-        """Create a table or index given an appropriate schema object."""
+        """Create a Table or Index given an appropriate Schema object."""
 
         return self.__engine.create(entity, connection=self, **kwargs)
 
     def drop(self, entity, **kwargs):
-        """Drop a table or index given an appropriate schema object."""
+        """Drop a Table or Index given an appropriate Schema object."""
 
         return self.__engine.drop(entity, connection=self, **kwargs)
 
     def reflecttable(self, table, **kwargs):
-        """Reflect the columns in the given table from the database."""
+        """Reflect the columns in the given string table name from the database."""
 
         return self.__engine.reflecttable(table, connection=self, **kwargs)
 
@@ -545,59 +605,6 @@ class Connection(Connectable):
     def run_callable(self, callable_):
         return callable_(self)
 
-    def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs):
-        if cursor is None:
-            cursor = self.__engine.dialect.create_cursor(self.connection)
-        if not self.__engine.dialect.supports_unicode_statements():
-            # encode to ascii, with full error handling
-            statement = statement.encode('ascii')
-        self.__engine.logger.info(statement)
-        self.__engine.logger.info(repr(parameters))
-        if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
-            self._executemany(cursor, statement, parameters, context=context)
-        else:
-            self._execute(cursor, statement, parameters, context=context)
-        self._autocommit(statement)
-        return cursor
-
-    def _execute(self, c, statement, parameters, context=None):
-        if parameters is None:
-            if self.__engine.dialect.positional:
-                parameters = ()
-            else:
-                parameters = {}
-        try:
-            self.__engine.dialect.do_execute(c, statement, parameters, context=context)
-        except Exception, e:
-            self._autorollback()
-            #self._rollback_impl()
-            if self.__close_with_result:
-                self.close()
-            raise exceptions.SQLError(statement, parameters, e)
-
-    def _executemany(self, c, statement, parameters, context=None):
-        try:
-            self.__engine.dialect.do_executemany(c, statement, parameters, context=context)
-        except Exception, e:
-            self._autorollback()
-            #self._rollback_impl()
-            if self.__close_with_result:
-                self.close()
-            raise exceptions.SQLError(statement, parameters, e)
-
-    def proxy(self, statement=None, parameters=None):
-        """Execute the given statement string and parameter object.
-
-        The parameter object is expected to be the result of a call to
-        ``compiled.get_params()``.  This callable is a generic version
-        of a connection/cursor-specific callable that is produced
-        within the execute_compiled method, and is used for objects
-        that require this style of proxy when outside of an
-        execute_compiled method, primarily the DefaultRunner.
-        """
-        parameters = self.__engine.dialect.convert_compiled_params(parameters)
-        return self._execute_raw(statement, parameters)
-
 class Transaction(object):
     """Represent a Transaction in progress.
 
@@ -630,7 +637,7 @@ class Transaction(object):
             self.__connection._commit_impl()
             self.__is_active = False
 
-class Engine(sql.Executor, Connectable):
+class Engine(Connectable):
     """
     Connects a ConnectionProvider, a Dialect and a CompilerFactory together to
     provide a default implementation of SchemaEngine.
@@ -638,12 +645,13 @@ class Engine(sql.Executor, Connectable):
 
     def __init__(self, connection_provider, dialect, echo=None):
         self.connection_provider = connection_provider
-        self.dialect=dialect
+        self._dialect=dialect
         self.echo = echo
         self.logger = logging.instance_logger(self)
 
     name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'])
     engine = property(lambda s:s)
+    dialect = property(lambda s:s._dialect)
     echo = logging.echo_property()
 
     def dispose(self):
@@ -678,11 +686,11 @@ class Engine(sql.Executor, Connectable):
 
     def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
         if connection is None:
-            conn = self.contextual_connect()
+            conn = self.contextual_connect(close_with_result=False)
         else:
             conn = connection
         try:
-            element.accept_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs))
+            element.accept_visitor(visitorcallable(conn, **kwargs))
         finally:
             if connection is None:
                 conn.close()
@@ -807,55 +815,39 @@ class ResultProxy(object):
         def convert_result_value(self, arg, engine):
             raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
 
-    def __new__(cls, *args, **kwargs):
-        if cls is ResultProxy and kwargs.has_key('should_prefetch') and kwargs['should_prefetch']:
-            return PrefetchingResultProxy(*args, **kwargs)
-        else:
-            return object.__new__(cls, *args, **kwargs)
-
-    def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, column_labels=None, should_prefetch=None):
+    def __init__(self, context):
         """ResultProxy objects are constructed via the execute() method on SQLEngine."""
-
-        self.connection = connection
-        self.dialect = engine.dialect
-        self.cursor = cursor
-        self.engine = engine
+        self.context = context
         self.closed = False
-        self.column_labels = column_labels
-        if executioncontext is not None:
-            self.__executioncontext = executioncontext
-            self.rowcount = executioncontext.get_rowcount(cursor)
-        else:
-            self.rowcount = cursor.rowcount
-        self.__key_cache = {}
-        self.__echo = engine.echo == 'debug'
-        metadata = cursor.description
-        self.props = {}
-        self.keys = []
-        i = 0
+        self.cursor = context.cursor
+        self.__echo = logging.is_debug_enabled(context.engine.logger)
+        self._init_metadata()
         
+    dialect = property(lambda s:s.context.dialect)
+    rowcount = property(lambda s:s.context.get_rowcount())
+    connection = property(lambda s:s.context.connection)
+    
+    def _init_metadata(self):
+        if hasattr(self, '_ResultProxy__props'):
+            return
+        self.__key_cache = {}
+        self.__props = {}
+        self.__keys = []
+        metadata = self.cursor.description
         if metadata is not None:
-            for item in metadata:
+            for i, item in enumerate(metadata):
                 # sqlite possibly prepending table name to colnames so strip
-                colname = item[0].split('.')[-1].lower()
-                if typemap is not None:
-                    rec = (typemap.get(colname, types.NULLTYPE), i)
+                colname = item[0].split('.')[-1]
+                if self.context.typemap is not None:
+                    rec = (self.context.typemap.get(colname.lower(), types.NULLTYPE), i)
                 else:
                     rec = (types.NULLTYPE, i)
                 if rec[0] is None:
                     raise DBAPIError("None for metadata " + colname)
-                if self.props.setdefault(colname, rec) is not rec:
-                    self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0)
-                self.keys.append(colname)
-                self.props[i] = rec
-                i+=1
-
-    def _executioncontext(self):
-        try:
-            return self.__executioncontext
-        except AttributeError:
-            raise exceptions.InvalidRequestError("This ResultProxy does not have an execution context with which to complete this operation.  Execution contexts are not generated for literal SQL execution.")
-    executioncontext = property(_executioncontext)
+                if self.__props.setdefault(colname.lower(), rec) is not rec:
+                    self.__props[colname.lower()] = (ResultProxy.AmbiguousColumn(colname), 0)
+                self.__keys.append(colname)
+                self.__props[i] = rec
 
     def close(self):
         """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
@@ -867,13 +859,12 @@ class ResultProxy(object):
         This method is also called automatically when all result rows
         are exhausted.
         """
-
         if not self.closed:
             self.closed = True
             self.cursor.close()
             if self.connection.should_close_with_result and self.dialect.supports_autoclose_results:
                 self.connection.close()
-
+            
     def _convert_key(self, key):
         """Convert and cache a key.
 
@@ -882,25 +873,26 @@ class ResultProxy(object):
         metadata; then cache it locally for quick re-access.
         """
 
-        try:
+        if key in self.__key_cache:
             return self.__key_cache[key]
-        except KeyError:
-            if isinstance(key, int) and key in self.props:
-                rec = self.props[key]
-            elif isinstance(key, basestring) and key.lower() in self.props:
-                rec = self.props[key.lower()]
+        else:
+            if isinstance(key, int) and key in self.__props:
+                rec = self.__props[key]
+            elif isinstance(key, basestring) and key.lower() in self.__props:
+                rec = self.__props[key.lower()]
             elif isinstance(key, sql.ColumnElement):
-                label = self.column_labels.get(key._label, key.name).lower()
-                if label in self.props:
-                    rec = self.props[label]
+                label = self.context.column_labels.get(key._label, key.name).lower()
+                if label in self.__props:
+                    rec = self.__props[label]
                         
             if not "rec" in locals():
                 raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (repr(key)))
 
             self.__key_cache[key] = rec
             return rec
-            
-
+    
+    keys = property(lambda s:s.__keys)
+    
     def _has_key(self, row, key):
         try:
             self._convert_key(key)
@@ -908,10 +900,6 @@ class ResultProxy(object):
         except KeyError:
             return False
 
-    def _get_col(self, row, key):
-        rec = self._convert_key(key)
-        return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect)
-
     def __iter__(self):
         while True:
             row = self.fetchone()
@@ -926,7 +914,7 @@ class ResultProxy(object):
         See ExecutionContext for details.
         """
 
-        return self.executioncontext.last_inserted_ids()
+        return self.context.last_inserted_ids()
 
     def last_updated_params(self):
         """Return ``last_updated_params()`` from the underlying ExecutionContext.
@@ -934,7 +922,7 @@ class ResultProxy(object):
         See ExecutionContext for details.
         """
 
-        return self.executioncontext.last_updated_params()
+        return self.context.last_updated_params()
 
     def last_inserted_params(self):
         """Return ``last_inserted_params()`` from the underlying ExecutionContext.
@@ -942,7 +930,7 @@ class ResultProxy(object):
         See ExecutionContext for details.
         """
 
-        return self.executioncontext.last_inserted_params()
+        return self.context.last_inserted_params()
 
     def lastrow_has_defaults(self):
         """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext.
@@ -950,7 +938,7 @@ class ResultProxy(object):
         See ExecutionContext for details.
         """
 
-        return self.executioncontext.lastrow_has_defaults()
+        return self.context.lastrow_has_defaults()
 
     def supports_sane_rowcount(self):
         """Return ``supports_sane_rowcount()`` from the underlying ExecutionContext.
@@ -958,71 +946,122 @@ class ResultProxy(object):
         See ExecutionContext for details.
         """
 
-        return self.executioncontext.supports_sane_rowcount()
+        return self.context.supports_sane_rowcount()
 
+    def _get_col(self, row, key):
+        rec = self._convert_key(key)
+        return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect)
+    
+    def _fetchone_impl(self):
+        return self.cursor.fetchone()
+    def _fetchmany_impl(self, size=None):
+        return self.cursor.fetchmany(size)
+    def _fetchall_impl(self):
+        return self.cursor.fetchall()
+        
+    def _process_row(self, row):
+        return RowProxy(self, row)
+            
     def fetchall(self):
         """Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
 
-        l = []
-        for row in self.cursor.fetchall():
-            l.append(RowProxy(self, row))
+        l = [self._process_row(row) for row in self._fetchall_impl()]
         self.close()
         return l
 
     def fetchmany(self, size=None):
         """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
 
-        if size is None:
-            rows = self.cursor.fetchmany()
-        else:
-            rows = self.cursor.fetchmany(size)
-        l = []
-        for row in rows:
-            l.append(RowProxy(self, row))
+        l = [self._process_row(row) for row in self._fetchmany_impl(size)]
         if len(l) == 0:
             self.close()
         return l
 
     def fetchone(self):
         """Fetch one row, just like DBAPI ``cursor.fetchone()``."""
-
-        row = self.cursor.fetchone()
+        row = self._fetchone_impl()
         if row is not None:
-            return RowProxy(self, row)
+            return self._process_row(row)
         else:
             self.close()
             return None
 
     def scalar(self):
         """Fetch the first column of the first row, and close the result set."""
-
-        row = self.cursor.fetchone()
+        row = self._fetchone_impl()
         try:
             if row is not None:
-                return RowProxy(self, row)[0]
+                return self._process_row(row)[0]
             else:
                 return None
         finally:
             self.close()
 
-class PrefetchingResultProxy(ResultProxy):
+class BufferedRowResultProxy(ResultProxy):
+    def _init_metadata(self):
+        self.__buffer_rows()
+        super(BufferedRowResultProxy, self)._init_metadata()
+    
+    # this is a "growth chart" for the buffering of rows.
+    # each successive __buffer_rows call will use the next
+    # value in the list for the buffer size until the max
+    # is reached
+    size_growth = {
+        1 : 5,
+        5 : 10,
+        10 : 20,
+        20 : 50,
+        50 : 100
+    }
+    
+    def __buffer_rows(self):
+        size = getattr(self, '_bufsize', 1)
+        self.__rowbuffer = self.cursor.fetchmany(size)
+        #self.context.engine.logger.debug("Buffered %d rows" % size)
+        self._bufsize = self.size_growth.get(size, size)
+    
+    def _fetchone_impl(self):
+        if self.closed:
+            return None
+        if len(self.__rowbuffer) == 0:
+            self.__buffer_rows()
+            if len(self.__rowbuffer) == 0:
+                return None
+        return self.__rowbuffer.pop(0)
+
+    def _fetchmany_impl(self, size=None):
+        result = []
+        for x in range(0, size):
+            row = self._fetchone_impl()
+            if row is None:
+                break
+            result.append(row)
+        return result
+        
+    def _fetchall_impl(self):
+        return self.__rowbuffer + list(self.cursor.fetchall())
+
+class BufferedColumnResultProxy(ResultProxy):
     """ResultProxy that loads all columns into memory each time fetchone() is
     called.  If fetchmany() or fetchall() are called, the full grid of results
     is fetched.
     """
-
     def _get_col(self, row, key):
         rec = self._convert_key(key)
         return row[rec[1]]
+    
+    def _process_row(self, row):
+        sup = super(BufferedColumnResultProxy, self)
+        row = [sup._get_col(row, i) for i in xrange(len(row))]
+        return RowProxy(self, row)
 
     def fetchall(self):
         l = []
         while True:
             row = self.fetchone()
-            if row is not None:
-                l.append(row)
-            else:
+            if row is None:
                 break
+            l.append(row)
         return l
 
     def fetchmany(self, size=None):
@@ -1031,24 +1070,13 @@ class PrefetchingResultProxy(ResultProxy):
         l = []
         for i in xrange(size):
             row = self.fetchone()
-            if row is not None:
-                l.append(row)
-            else:
+            if row is None:
                 break
+            l.append(row)
         return l
 
-    def fetchone(self):
-        sup = super(PrefetchingResultProxy, self)
-        row = self.cursor.fetchone()
-        if row is not None:
-            row = [sup._get_col(row, i) for i in xrange(len(row))]
-            return RowProxy(self, row)
-        else:
-            self.close()
-            return None
-
 class RowProxy(object):
-    """Proxie a single cursor row for a parent ResultProxy.
+    """Proxy a single cursor row for a parent ResultProxy.
 
     Mostly follows "ordered dictionary" behavior, mapping result
     values to the string-based column name, the integer position of
@@ -1063,7 +1091,7 @@ class RowProxy(object):
         self.__parent = parent
         self.__row = row
         if self.__parent._ResultProxy__echo:
-            self.__parent.engine.logger.debug("Row " + repr(row))
+            self.__parent.context.engine.logger.debug("Row " + repr(row))
 
     def close(self):
         """Close the parent ResultProxy."""
@@ -1115,20 +1143,10 @@ class RowProxy(object):
 class SchemaIterator(schema.SchemaVisitor):
     """A visitor that can gather text into a buffer and execute the contents of the buffer."""
 
-    def __init__(self, engine, proxy, **params):
+    def __init__(self, connection):
         """Construct a new SchemaIterator.
-
-        engine
-          the Engine used by this SchemaIterator
-
-        proxy
-          a callable which takes a statement and bind parameters and
-          executes it, returning the cursor (the actual DBAPI cursor).
-          The callable should use the same cursor repeatedly.
         """
-
-        self.proxy = proxy
-        self.engine = engine
+        self.connection = connection
         self.buffer = StringIO.StringIO()
 
     def append(self, s):
@@ -1140,7 +1158,7 @@ class SchemaIterator(schema.SchemaVisitor):
         """Execute the contents of the SchemaIterator's buffer."""
 
         try:
-            return self.proxy(self.buffer.getvalue(), None)
+            return self.connection.execute(self.buffer.getvalue())
         finally:
             self.buffer.truncate(0)
 
@@ -1154,10 +1172,10 @@ class DefaultRunner(schema.SchemaVisitor):
     DefaultRunner to allow database-specific behavior.
     """
 
-    def __init__(self, engine, proxy):
-        self.proxy = proxy
-        self.engine = engine
-
+    def __init__(self, connection):
+        self.connection = connection
+        self.dialect = connection.dialect
+        
     def get_column_default(self, column):
         if column.default is not None:
             return column.default.accept_visitor(self)
@@ -1188,8 +1206,8 @@ class DefaultRunner(schema.SchemaVisitor):
         return None
 
     def exec_default_sql(self, default):
-        c = sql.select([default.arg], engine=self.engine).compile()
-        return self.proxy(str(c), c.get_params()).fetchone()[0]
+        c = sql.select([default.arg]).compile(engine=self.connection)
+        return self.connection.execute_compiled(c).scalar()
 
     def visit_column_onupdate(self, onupdate):
         if isinstance(onupdate.arg, sql.ClauseElement):
index 86563cd7cbcedfb5d4c6a876b9f57003bd5dba15..ceecee364fb9a2b570c49c5083424cc37f5da792 100644 (file)
@@ -26,16 +26,17 @@ class PoolConnectionProvider(base.ConnectionProvider):
 class DefaultDialect(base.Dialect):
     """Default implementation of Dialect"""
 
-    def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs):
+    def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
         self.convert_unicode = convert_unicode
         self.supports_autoclose_results = True
         self.encoding = encoding
         self.positional = False
         self._ischema = None
-        self._figure_paramstyle(default=default_paramstyle)
+        self.dbapi = dbapi
+        self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
 
-    def create_execution_context(self):
-        return DefaultExecutionContext(self)
+    def create_execution_context(self, **kwargs):
+        return DefaultExecutionContext(self, **kwargs)
 
     def type_descriptor(self, typeobj):
         """Provide a database-specific ``TypeEngine`` object, given
@@ -56,6 +57,9 @@ class DefaultDialect(base.Dialect):
         # TODO: probably raise this and fill out
         # db modules better
         return 30
+
+    def supports_alter(self):
+        return True
         
     def oid_column_name(self, column):
         return None
@@ -92,14 +96,8 @@ class DefaultDialect(base.Dialect):
     def do_execute(self, cursor, statement, parameters, **kwargs):
         cursor.execute(statement, parameters)
 
-    def defaultrunner(self, engine, proxy):
-        return base.DefaultRunner(engine, proxy)
-
-    def create_cursor(self, connection):
-        return connection.cursor()
-
-    def create_result_proxy_args(self, connection, cursor):
-        return dict(should_prefetch=False)
+    def defaultrunner(self, connection):
+        return base.DefaultRunner(connection)
 
     def _set_paramstyle(self, style):
         self._paramstyle = style
@@ -126,11 +124,10 @@ class DefaultDialect(base.Dialect):
         return parameters
 
     def _figure_paramstyle(self, paramstyle=None, default='named'):
-        db = self.dbapi()
         if paramstyle is not None:
             self._paramstyle = paramstyle
-        elif db is not None:
-            self._paramstyle = db.paramstyle
+        elif self.dbapi is not None:
+            self._paramstyle = self.dbapi.paramstyle
         else:
             self._paramstyle = default
 
@@ -146,10 +143,6 @@ class DefaultDialect(base.Dialect):
             raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
 
     def _get_ischema(self):
-        # We use a property for ischema so that the accessor
-        # creation only happens as needed, since otherwise we
-        # have a circularity problem with the generic
-        # ansisql.engine()
         if self._ischema is None:
             import sqlalchemy.databases.information_schema as ischema
             self._ischema = ischema.ISchema(self)
@@ -157,20 +150,49 @@ class DefaultDialect(base.Dialect):
     ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
 
 class DefaultExecutionContext(base.ExecutionContext):
-    def __init__(self, dialect):
+    def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
         self.dialect = dialect
+        self.connection = connection
+        self.compiled = compiled
+        self.compiled_parameters = compiled_parameters
+        
+        if compiled is not None:
+            self.typemap = compiled.typemap
+            self.column_labels = compiled.column_labels
+            self.statement = unicode(compiled)
+        else:
+            self.typemap = self.column_labels = None
+            self.parameters = parameters
+            self.statement = statement
 
-    def pre_exec(self, engine, proxy, compiled, parameters):
-        self._process_defaults(engine, proxy, compiled, parameters)
+        if not dialect.supports_unicode_statements():
+            self.statement = self.statement.encode('ascii')
+        
+        self.cursor = self.create_cursor()
+        
+    engine = property(lambda s:s.connection.engine)
+    
+    def is_select(self):
+        return re.match(r'SELECT', self.statement.lstrip(), re.I)
+
+    def create_cursor(self):
+        return self.connection.connection.cursor()
+        
+    def pre_exec(self):
+        self._process_defaults()
+        self.parameters = self.dialect.convert_compiled_params(self.compiled_parameters)
 
-    def post_exec(self, engine, proxy, compiled, parameters):
+    def post_exec(self):
         pass
 
-    def get_rowcount(self, cursor):
+    def get_result_proxy(self):
+        return base.ResultProxy(self)
+
+    def get_rowcount(self):
         if hasattr(self, '_rowcount'):
             return self._rowcount
         else:
-            return cursor.rowcount
+            return self.cursor.rowcount
 
     def supports_sane_rowcount(self):
         return self.dialect.supports_sane_rowcount()
@@ -187,44 +209,44 @@ class DefaultExecutionContext(base.ExecutionContext):
     def lastrow_has_defaults(self):
         return self._lastrow_has_defaults
 
-    def set_input_sizes(self, cursor, parameters):
+    def set_input_sizes(self):
         """Given a cursor and ClauseParameters, call the appropriate
         style of ``setinputsizes()`` on the cursor, using DBAPI types
         from the bind parameter's ``TypeEngine`` objects.
         """
 
-        if isinstance(parameters, list):
-            plist = parameters
+        if isinstance(self.compiled_parameters, list):
+            plist = self.compiled_parameters
         else:
-            plist = [parameters]
+            plist = [self.compiled_parameters]
         if self.dialect.positional:
             inputsizes = []
             for params in plist[0:1]:
                 for key in params.positional:
                     typeengine = params.binds[key].type
-                    dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
+                    dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
                     if dbtype is not None:
                         inputsizes.append(dbtype)
-            cursor.setinputsizes(*inputsizes)
+            self.cursor.setinputsizes(*inputsizes)
         else:
             inputsizes = {}
             for params in plist[0:1]:
                 for key in params.keys():
                     typeengine = params.binds[key].type
-                    dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
+                    dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
                     if dbtype is not None:
                         inputsizes[key] = dbtype
-            cursor.setinputsizes(**inputsizes)
+            self.cursor.setinputsizes(**inputsizes)
 
-    def _process_defaults(self, engine, proxy, compiled, parameters):
+    def _process_defaults(self):
         """``INSERT`` and ``UPDATE`` statements, when compiled, may
         have additional columns added to their ``VALUES`` and ``SET``
         lists corresponding to column defaults/onupdates that are
         present on the ``Table`` object (i.e. ``ColumnDefault``,
         ``Sequence``, ``PassiveDefault``).  This method pre-execs
         those ``DefaultGenerator`` objects that require pre-execution
-        and sets their values within the parameter list, and flags the
-        thread-local state about ``PassiveDefault`` objects that may
+        and sets their values within the parameter list, and flags this
+        ExecutionContext about ``PassiveDefault`` objects that may
         require post-fetching the row after it is inserted/updated.
 
         This method relies upon logic within the ``ANSISQLCompiler``
@@ -234,30 +256,28 @@ class DefaultExecutionContext(base.ExecutionContext):
         statement.
         """
 
-        if compiled is None: return
-
-        if getattr(compiled, "isinsert", False):
-            if isinstance(parameters, list):
-                plist = parameters
+        if self.compiled.isinsert:
+            if isinstance(self.compiled_parameters, list):
+                plist = self.compiled_parameters
             else:
-                plist = [parameters]
-            drunner = self.dialect.defaultrunner(engine, proxy)
+                plist = [self.compiled_parameters]
+            drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
             self._lastrow_has_defaults = False
             for param in plist:
                 last_inserted_ids = []
                 need_lastrowid=False
                 # check the "default" status of each column in the table
-                for c in compiled.statement.table.c:
+                for c in self.compiled.statement.table.c:
                     # check if it will be populated by a SQL clause - we'll need that
                     # after execution.
-                    if c in compiled.inline_params:
+                    if c in self.compiled.inline_params:
                         self._lastrow_has_defaults = True
                         if c.primary_key:
                             need_lastrowid = True
                     # check if its not present at all.  see if theres a default
                     # and fire it off, and add to bind parameters.  if
                     # its a pk, add the value to our last_inserted_ids list,
-                    # or, if its a SQL-side default, dont do any of that, but we'll need
+                    # or, if its a SQL-side default, let it fire off on the DB side, but we'll need
                     # the SQL-generated value after execution.
                     elif not c.key in param or param.get_original(c.key) is None:
                         if isinstance(c.default, schema.PassiveDefault):
@@ -278,19 +298,19 @@ class DefaultExecutionContext(base.ExecutionContext):
                 else:
                     self._last_inserted_ids = last_inserted_ids
                 self._last_inserted_params = param
-        elif getattr(compiled, 'isupdate', False):
-            if isinstance(parameters, list):
-                plist = parameters
+        elif self.compiled.isupdate:
+            if isinstance(self.compiled_parameters, list):
+                plist = self.compiled_parameters
             else:
-                plist = [parameters]
-            drunner = self.dialect.defaultrunner(engine, proxy)
+                plist = [self.compiled_parameters]
+            drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
             self._lastrow_has_defaults = False
             for param in plist:
                 # check the "onupdate" status of each column in the table
-                for c in compiled.statement.table.c:
+                for c in self.compiled.statement.table.c:
                     # it will be populated by a SQL clause - we'll need that
                     # after execution.
-                    if c in compiled.inline_params:
+                    if c in self.compiled.inline_params:
                         pass
                     # its not in the bind parameters, and theres an "onupdate" defined for the column;
                     # execute it and add to bind params
index 8ac721b77c3be94ab12ba04d82d912fd7764109b..1b760fca8b2f28d6ccc2174b8cee2bff5caeb654 100644 (file)
@@ -50,6 +50,16 @@ class DefaultEngineStrategy(EngineStrategy):
             if k in kwargs:
                 dialect_args[k] = kwargs.pop(k)
 
+        dbapi = kwargs.pop('module', None)
+        if dbapi is None:
+            dbapi_args = {}
+            for k in util.get_func_kwargs(module.dbapi):
+                if k in kwargs:
+                    dbapi_args[k] = kwargs.pop(k)
+            dbapi = module.dbapi(**dbapi_args)
+        
+        dialect_args['dbapi'] = dbapi
+        
         # create dialect
         dialect = module.dialect(**dialect_args)
 
@@ -60,10 +70,6 @@ class DefaultEngineStrategy(EngineStrategy):
         # look for existing pool or create
         pool = kwargs.pop('pool', None)
         if pool is None:
-            dbapi = kwargs.pop('module', dialect.dbapi())
-            if dbapi is None:
-                raise exceptions.InvalidRequestError("Can't get DBAPI module for dialect '%s'" % dialect)
-
             def connect():
                 try:
                     return dbapi.connect(*cargs, **cparams)
@@ -73,6 +79,7 @@ class DefaultEngineStrategy(EngineStrategy):
 
             poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool))
             pool_args = {}
+
             # consume pool arguments from kwargs, translating a few of the arguments
             for k in util.get_cls_kwargs(poolclass):
                 tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k)
@@ -139,3 +146,52 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy):
         return threadlocal.TLEngine
 
 ThreadLocalEngineStrategy()
+
+
+class MockEngineStrategy(EngineStrategy):
+    """Produces a single Connection object which dispatches statement executions
+    to a passed-in function"""
+    def __init__(self):
+        EngineStrategy.__init__(self, 'mock')
+        
+    def create(self, name_or_url, executor, **kwargs):
+        # create url.URL object
+        u = url.make_url(name_or_url)
+
+        # get module from sqlalchemy.databases
+        module = u.get_module()
+
+        dialect_args = {}
+        # consume dialect arguments from kwargs
+        for k in util.get_cls_kwargs(module.dialect):
+            if k in kwargs:
+                dialect_args[k] = kwargs.pop(k)
+
+        # create dialect
+        dialect = module.dialect(**dialect_args)
+
+        return MockEngineStrategy.MockConnection(dialect, executor)
+
+    class MockConnection(base.Connectable):
+        def __init__(self, dialect, execute):
+            self._dialect = dialect
+            self.execute = execute
+
+        engine = property(lambda s: s)
+        dialect = property(lambda s:s._dialect)
+        
+        def contextual_connect(self):
+            return self
+
+        def create(self, entity, **kwargs):
+            kwargs['checkfirst'] = False
+            entity.accept_visitor(self.dialect.schemagenerator(self, **kwargs))
+
+        def drop(self, entity, **kwargs):
+            kwargs['checkfirst'] = False
+            entity.accept_visitor(self.dialect.schemadropper(self, **kwargs))
+
+        def execute(self, object, *multiparams, **params):
+            raise NotImplementedError()
+
+MockEngineStrategy()
\ No newline at end of file
index edb8cf32e8a559777c1cc9b1cf0745689e3f10f5..faa0ffc11cd94e53099adc2e2f0a8b24c682e6b9 100644 (file)
@@ -71,6 +71,10 @@ class URL(object):
 
     def get_module(self):
         """Return the SQLAlchemy database module corresponding to this URL's driver name."""
+        if self.drivername == 'ansi':
+            import sqlalchemy.ansisql
+            return sqlalchemy.ansisql
+            
         try:
             return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
         except ImportError:
index 6f4368707988ee41aab02ab92c45ecd34827bebd..91326233a6c57ed478cc2e0d34acaed88e806646 100644 (file)
@@ -31,8 +31,8 @@ import sys
 # py2.5 absolute imports will fix....
 logging = __import__('logging')
 
-# turn off logging at the root sqlalchemy level
-logging.getLogger('sqlalchemy').setLevel(logging.ERROR)
+
+logging.getLogger('sqlalchemy').setLevel(logging.WARN)
 
 default_enabled = False
 def default_logging(name):
index 787fd059f288bb31cd2bcdaf678705a18939c70d..8d559aff52f92143c6d4a5d6ef3298dbebf323cb 100644 (file)
@@ -237,7 +237,9 @@ class _ConnectionFairy(object):
             raise
         if self.__pool.echo:
             self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
-
+    
+    _logger = property(lambda self: self.__pool.logger)
+         
     def invalidate(self):
         if self.connection is None:
             raise exceptions.InvalidRequestError("This connection is closed")
@@ -248,7 +250,8 @@ class _ConnectionFairy(object):
 
     def cursor(self, *args, **kwargs):
         try:
-            return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+            c = self.connection.cursor(*args, **kwargs)
+            return _CursorFairy(self, c)
         except Exception, e:
             self.invalidate()
             raise
@@ -307,11 +310,14 @@ class _CursorFairy(object):
 
     def invalidate(self):
         self.__parent.invalidate()
-
+    
     def close(self):
         if self in self.__parent._cursors:
             del self.__parent._cursors[self]
-            self.cursor.close()
+            try:
+                self.cursor.close()
+            except Exception, e:
+                self.__parent._logger.warn("Error closing cursor: " + str(e))
 
     def __getattr__(self, key):
         return getattr(self.cursor, key)
index 87cbdaf0c3cb5846f73b043ed3386216a2f5e01a..f6c2315ae916ec792703d63901383a6892852d1c 100644 (file)
@@ -508,7 +508,7 @@ class ClauseParameters(object):
         return d
 
     def __repr__(self):
-        return repr(self.get_original_dict())
+        return self.__class__.__name__ + ":" + repr(self.get_original_dict())
 
 class ClauseVisitor(object):
     """A class that knows how to traverse and visit
index 86e323c6ea330a6c234615519ba75f8105823f99..7d7dbeeedf4d51d01f961f0d861f65b9862781ad 100644 (file)
@@ -53,28 +53,12 @@ class TypeEngine(AbstractType):
     def __init__(self, *args, **params):
         pass
 
-    def engine_impl(self, engine):
-        """Deprecated; call dialect_impl with a dialect directly."""
-
-        return self.dialect_impl(engine.dialect)
-
     def dialect_impl(self, dialect):
         try:
             return self.impl_dict[dialect]
         except KeyError:
             return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self))
 
-    def _get_impl(self):
-        if hasattr(self, '_impl'):
-            return self._impl
-        else:
-            return NULLTYPE
-
-    def _set_impl(self, impl):
-        self._impl = impl
-
-    impl = property(_get_impl, _set_impl)
-
     def get_col_spec(self):
         raise NotImplementedError()
 
@@ -86,26 +70,25 @@ class TypeEngine(AbstractType):
 
     def adapt(self, cls):
         return cls()
-
+    
+    def get_search_list(self):
+        """return a list of classes to test for a match 
+        when adapting this type to a dialect-specific type.
+        
+        """
+        
+        return self.__class__.__mro__[0:-1]
+        
 class TypeDecorator(AbstractType):
     def __init__(self, *args, **kwargs):
         if not hasattr(self.__class__, 'impl'):
             raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
         self.impl = self.__class__.impl(*args, **kwargs)
 
-    def engine_impl(self, engine):
-        return self.dialect_impl(engine.dialect)
-
     def dialect_impl(self, dialect):
         try:
             return self.impl_dict[dialect]
         except:
-            # see if the dialect has an adaptation of the TypeDecorator itself
-            adapted_decorator = dialect.type_descriptor(self)
-            if adapted_decorator is not self:
-                result = adapted_decorator.dialect_impl(dialect)
-                self.impl_dict[dialect] = result
-                return result
             typedesc = dialect.type_descriptor(self.impl)
             tt = self.copy()
             if not isinstance(tt, self.__class__):
@@ -168,8 +151,7 @@ def to_instance(typeobj):
 def adapt_type(typeobj, colspecs):
     if isinstance(typeobj, type):
         typeobj = typeobj()
-
-    for t in typeobj.__class__.__mro__[0:-1]:
+    for t in typeobj.get_search_list():
         try:
             impltype = colspecs[t]
             break
@@ -198,26 +180,28 @@ class NullTypeEngine(TypeEngine):
         return value
 
 class String(TypeEngine):
-    def __new__(cls, *args, **kwargs):
-        if cls is not String or len(args) > 0 or kwargs.has_key('length'):
-            return super(String, cls).__new__(cls, *args, **kwargs)
-        else:
-            return super(String, TEXT).__new__(TEXT, *args, **kwargs)
-
-    def __init__(self, length = None):
+    def __init__(self, length=None, convert_unicode=False):
         self.length = length
+        self.convert_unicode = convert_unicode
 
     def adapt(self, impltype):
-        return impltype(length=self.length)
+        return impltype(length=self.length, convert_unicode=self.convert_unicode)
 
     def convert_bind_param(self, value, dialect):
-        if not dialect.convert_unicode or value is None or not isinstance(value, unicode):
+        if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode):
             return value
         else:
             return value.encode(dialect.encoding)
 
+    def get_search_list(self):
+        l = super(String, self).get_search_list()
+        if self.length is None:
+            return (TEXT,) + l
+        else:
+            return l
+
     def convert_result_value(self, value, dialect):
-        if not dialect.convert_unicode or value is None or isinstance(value, unicode):
+        if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode):
             return value
         else:
             return value.decode(dialect.encoding)
@@ -228,21 +212,11 @@ class String(TypeEngine):
     def compare_values(self, x, y):
         return x == y
 
-class Unicode(TypeDecorator):
-    impl = String
-
-    def convert_bind_param(self, value, dialect):
-         if value is not None and isinstance(value, unicode):
-              return value.encode(dialect.encoding)
-         else:
-              return value
-
-    def convert_result_value(self, value, dialect):
-         if value is not None and not isinstance(value, unicode):
-             return value.decode(dialect.encoding)
-         else:
-             return value
-
+class Unicode(String):
+    def __init__(self, length=None, **kwargs):
+        kwargs['convert_unicode'] = True
+        super(Unicode, self).__init__(length=length, **kwargs)
+    
 class Integer(TypeEngine):
     """Integer datatype."""
 
@@ -310,7 +284,7 @@ class Binary(TypeEngine):
 
     def convert_bind_param(self, value, dialect):
         if value is not None:
-            return dialect.dbapi().Binary(value)
+            return dialect.dbapi.Binary(value)
         else:
             return None
 
index dadcf0ddeea1eacbd8c3a2dee2a2345760eca271..238f12493fc8aaa2b0a38068989247eae60faa6b 100644 (file)
@@ -94,6 +94,10 @@ def get_cls_kwargs(cls):
                     kw.append(vn)
     return kw
 
+def get_func_kwargs(func):
+    """Return the full set of legal kwargs for the given `func`."""
+    return [vn for vn in func.func_code.co_varnames]
+
 class SimpleProperty(object):
     """A *default* property accessor."""
 
index 62cd92b6e6e358ef2a50b25a46ac6ef8965b500a..0c6323c10ff5cd22b56955da55f61f636d8eab00 100644 (file)
@@ -498,9 +498,10 @@ class SchemaTest(PersistTest):
         # insure this doesnt crash
         print [t for t in metadata.table_iterator()]
         buf = StringIO.StringIO()
-        def foo(s, p):
+        def foo(s, p=None):
             buf.write(s)
-        gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo, None)
+        gen = create_engine(testbase.db.name + "://", strategy="mock", executor=foo)
+        gen = gen.dialect.schemagenerator(gen)
         gen.traverse(table1)
         gen.traverse(table2)
         buf = buf.getvalue()
index f92e70df3ad0e0db8fc0dcfcb074a81ef8be5974..bdc9e02e121f4e10acc42a6c63a157db48c9a5e1 100644 (file)
@@ -42,7 +42,7 @@ class RelationTest1(testbase.ORMTest):
         try:
             compile_mappers()
         except exceptions.ArgumentError, ar:
-            assert str(ar) == "Cant determine relation direction for relationship 'Person.manager (Manager)' - foreign key columns are present in both the parent and the child's mapped tables.  Specify 'foreign_keys' argument."
+            assert str(ar) == "Can't determine relation direction for relationship 'Person.manager (Manager)' - foreign key columns are present in both the parent and the child's mapped tables.  Specify 'foreign_keys' argument.", str(ar)
 
         clear_mappers()
 
index 6f80df38f50fe975f2ee74416ebb361f8992f24f..839a5172e69d8a6ee92dc208faee5519f91c6f16 100644 (file)
@@ -1332,7 +1332,7 @@ class InstancesTest(MapperSuperTest):
             'addresses':relation(Address, lazy=True)
         })
         mapper(Address, addresses)
-        query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True)
+        query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.user_id', addresses.c.address_id])
         q = create_session().query(User)
         
         def go():
@@ -1348,7 +1348,7 @@ class InstancesTest(MapperSuperTest):
         })
         mapper(Address, addresses)
         
-        selectquery = users.outerjoin(addresses).select(use_labels=True)
+        selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id])
         q = create_session().query(User)
         
         def go():
@@ -1363,7 +1363,7 @@ class InstancesTest(MapperSuperTest):
         mapper(Address, addresses)
 
         adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True)
+        selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
         q = create_session().query(User)
 
         def go():
@@ -1378,7 +1378,7 @@ class InstancesTest(MapperSuperTest):
         mapper(Address, addresses)
 
         adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True)
+        selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
         q = create_session().query(User)
 
         def go():
@@ -1393,7 +1393,7 @@ class InstancesTest(MapperSuperTest):
         mapper(Address, addresses)
 
         adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True)
+        selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
         def decorate(row):
             d = {}
             for c in addresses.columns:
@@ -1418,7 +1418,7 @@ class InstancesTest(MapperSuperTest):
         (user7, user8, user9) = sess.query(User).select()
         (address1, address2, address3, address4) = sess.query(Address).select()
         
-        selectquery = users.outerjoin(addresses).select(use_labels=True)
+        selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id])
         q = sess.query(User)
         l = q.instances(selectquery.execute(), Address)
         # note the result is a cartesian product
index 231a491b5272b31571d19cdfd092462ad624d614..d695e824c757989d74ed5daff15b5ccbfee54f2c 100644 (file)
@@ -172,11 +172,13 @@ class ConstraintTest(testbase.AssertMixin):
 
         capt = []
         connection = testbase.db.connect()
-        def proxy(statement, parameters):
-            capt.append(statement)
-            capt.append(repr(parameters))
-            connection.proxy(statement, parameters)
-        schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy, connection)
+        ex = connection._execute
+        def proxy(context):
+            capt.append(context.statement)
+            capt.append(repr(context.parameters))
+            ex(context)
+        connection._execute = proxy
+        schemagen = testbase.db.dialect.schemagenerator(connection)
         schemagen.traverse(events)
         
         assert capt[0].strip().startswith('CREATE TABLE events')
index 3c3e2334c085e157e07d18674ed8e471851ccdf6..08c766a0df27ac6169bd38eed009479d9b8bb0a0 100644 (file)
@@ -357,7 +357,7 @@ class QueryTest(PersistTest):
                          Column('__parent', VARCHAR(20)),
                          Column('__row', VARCHAR(20)),
         )
-        shadowed.create()
+        shadowed.create(checkfirst=True)
         try:
             shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
             r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone()
@@ -374,7 +374,7 @@ class QueryTest(PersistTest):
                 pass # expected
             r.close()
         finally:
-            shadowed.drop()
+            shadowed.drop(checkfirst=True)
 
 class CompoundTest(PersistTest):
     """test compound statements like UNION, INTERSECT, particularly their ability to nest on
index 97e95d38926c0081a5b284ab39c6f55a9ba90a05..d1256b31a564bbad018bbdb5a422526107fe13c4 100644 (file)
@@ -6,7 +6,7 @@ import string,datetime, re, sys, os
 import sqlalchemy.engine.url as url
 
 import sqlalchemy.types
-
+from sqlalchemy.databases import mssql, oracle
 
 db = testbase.db
 
@@ -22,18 +22,19 @@ class MyType(types.TypeEngine):
 
 class MyDecoratedType(types.TypeDecorator):
     impl = String
-    def convert_bind_param(self, value, engine):
-        return "BIND_IN"+ value
-    def convert_result_value(self, value, engine):
-        return value + "BIND_OUT"
+    def convert_bind_param(self, value, dialect):
+        return "BIND_IN"+ super(MyDecoratedType, self).convert_bind_param(value, dialect)
+    def convert_result_value(self, value, dialect):
+        return super(MyDecoratedType, self).convert_result_value(value, dialect) + "BIND_OUT"
     def copy(self):
         return MyDecoratedType()
         
-class MyUnicodeType(types.Unicode):
-    def convert_bind_param(self, value, engine):
-        return "UNI_BIND_IN"+ value
-    def convert_result_value(self, value, engine):
-        return value + "UNI_BIND_OUT"
+class MyUnicodeType(types.TypeDecorator):
+    impl = Unicode
+    def convert_bind_param(self, value, dialect):
+        return "UNI_BIND_IN"+ super(MyUnicodeType, self).convert_bind_param(value, dialect)
+    def convert_result_value(self, value, dialect):
+        return super(MyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT"
     def copy(self):
         return MyUnicodeType(self.impl.length)
 
@@ -52,31 +53,29 @@ class AdaptTest(PersistTest):
         assert t2 != t3
         assert t3 != t1
     
-    def testdecorator(self):
-        t1 = Unicode(20)
-        t2 = Unicode()
-        assert isinstance(t1.impl, String)
-        assert not isinstance(t1.impl, TEXT)
-        assert (t1.impl.length == 20)
-        assert isinstance(t2.impl, TEXT)
-        assert t2.impl.length is None
-
-
-    def testdialecttypedecorators(self):
-        """test that a a Dialect can provide a dialect-specific subclass of a TypeDecorator subclass."""
-        import sqlalchemy.databases.mssql as mssql
+    def testmsnvarchar(self):
         dialect = mssql.MSSQLDialect()
         # run the test twice to insure the caching step works too
         for x in range(0, 1):
             col = Column('', Unicode(length=10))
             dialect_type = col.type.dialect_impl(dialect)
-            assert isinstance(dialect_type, mssql.MSUnicode)
+            assert isinstance(dialect_type, mssql.MSNVarchar)
             assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
-            assert isinstance(dialect_type.impl, mssql.MSString)
-            
+
+    def testoracletext(self):
+        dialect = oracle.OracleDialect()
+        col = Column('', MyDecoratedType)
+        dialect_type = col.type.dialect_impl(dialect)
+        assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl)
+    
 class OverrideTest(PersistTest):
     """tests user-defined types, including a full type as well as a TypeDecorator"""
 
+    def testbasic(self):
+        print users.c.goofy4.type
+        print users.c.goofy4.type.dialect_impl(testbase.db.dialect)
+        print users.c.goofy4.type.dialect_impl(testbase.db.dialect).get_col_spec()
+        
     def testprocessing(self):
 
         global users
index 8a1d9ee59a827b7ebbf04ffaf4b26265f2376c2f..aae455673fca326bf77d2f63615ddbda2ff2d061 100644 (file)
@@ -1,12 +1,9 @@
 import sys
 sys.path.insert(0, './lib/')
-import os
-import unittest
-import StringIO
-import sqlalchemy.ext.proxy as proxy
-import re
+import os, unittest, StringIO, re
 import sqlalchemy
 from sqlalchemy import sql, engine, pool
+import sqlalchemy.engine.base as base
 import optparse
 from sqlalchemy.schema import BoundMetaData
 from sqlalchemy.orm import clear_mappers
@@ -49,6 +46,7 @@ def parse_argv():
     parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)")
     parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running")
     parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)")
+    parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG")
     
     (options, args) = parser.parse_args()
     sys.argv[1:] = args
@@ -73,7 +71,7 @@ def parse_argv():
             db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
         elif DBTYPE == 'oracle8':
             db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
-            opts = {'use_ansi':False}
+            opts['use_ansi'] = False
         elif DBTYPE == 'mssql':
             db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test'
         elif DBTYPE == 'firebird':
@@ -94,6 +92,9 @@ def parse_argv():
     
     global with_coverage
     with_coverage = options.coverage
+
+    if options.serverside:
+        opts['server_side_cursors'] = True
     
     if options.enginestrategy is not None:
         opts['strategy'] = options.enginestrategy    
@@ -101,7 +102,16 @@ def parse_argv():
         db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts)
     else:
         db = engine.create_engine(db_uri, **opts)
-    db = EngineAssert(db)
+
+    # decorate the dialect's create_execution_context() method
+    # to produce a wrapper
+    create_context = db.dialect.create_execution_context
+    def create_exec_context(*args, **kwargs):
+        return ExecutionContextWrapper(create_context(*args, **kwargs))
+    db.dialect.create_execution_context = create_exec_context
+    
+    global testdata
+    testdata = TestData(db)
     
     if options.topological:
         from sqlalchemy.orm import unitofwork
@@ -172,8 +182,6 @@ class PersistTest(unittest.TestCase):
         """overridden to not return docstrings"""
         return None
 
-
-
 class AssertMixin(PersistTest):
     """given a list-based structure of keys/properties which represent information within an object structure, and
     a list of actual objects, asserts that the list of objects corresponds to the structure."""
@@ -197,20 +205,24 @@ class AssertMixin(PersistTest):
             else:
                 self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
     def assert_sql(self, db, callable_, list, with_sequences=None):
+        global testdata
+        testdata = TestData(db)
         if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'):
-            db.set_assert_list(self, with_sequences)
+            testdata.set_assert_list(self, with_sequences)
         else:
-            db.set_assert_list(self, list)
+            testdata.set_assert_list(self, list)
         try:
             callable_()
         finally:
-            db.set_assert_list(None, None)
+            testdata.set_assert_list(None, None)
+
     def assert_sql_count(self, db, callable_, count):
-        db.sql_count = 0
+        global testdata
+        testdata = TestData(db)
         try:
             callable_()
         finally:
-            self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count))
+            self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % (count, testdata.sql_count))
 
 class ORMTest(AssertMixin):
     keep_mappers = False
@@ -233,83 +245,73 @@ class ORMTest(AssertMixin):
             for t in metadata.table_iterator(reverse=True):
                 t.delete().execute().close()
 
-class EngineAssert(proxy.BaseProxyEngine):
-    """decorates a SQLEngine object to match the incoming queries against a set of assertions."""
+class TestData(object):
     def __init__(self, engine):
         self._engine = engine
-
-        self.real_execution_context = engine.dialect.create_execution_context
-        engine.dialect.create_execution_context = self.execution_context
-        
         self.logger = engine.logger
         self.set_assert_list(None, None)
         self.sql_count = 0
-    def get_engine(self):
-        return self._engine
-    def set_engine(self, e):
-        self._engine = e
+        
     def set_assert_list(self, unittest, list):
         self.unittest = unittest
         self.assert_list = list
         if list is not None:
             self.assert_list.reverse()
-    def _set_echo(self, echo):
-        self.engine.echo = echo
-    echo = property(lambda s: s.engine.echo, _set_echo)
     
-    def execution_context(self):
-        def post_exec(engine, proxy, compiled, parameters, **kwargs):
-            ctx = e
-            self.engine.logger = self.logger
-            statement = unicode(compiled)
-            statement = re.sub(r'\n', '', statement)
-
-            if self.assert_list is not None:
-                item = self.assert_list[-1]
-                if not isinstance(item, dict):
-                    item = self.assert_list.pop()
-                else:
-                    # asserting a dictionary of statements->parameters
-                    # this is to specify query assertions where the queries can be in 
-                    # multiple orderings
-                    if not item.has_key('_converted'):
-                        for key in item.keys():
-                            ckey = self.convert_statement(key)
-                            item[ckey] = item[key]
-                            if ckey != key:
-                                del item[key]
-                        item['_converted'] = True
-                    try:
-                        entry = item.pop(statement)
-                        if len(item) == 1:
-                            self.assert_list.pop()
-                        item = (statement, entry)
-                    except KeyError:
-                        self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
-
-                (query, params) = item
-                if callable(params):
-                    params = params(ctx)
-                if params is not None and isinstance(params, list) and len(params) == 1:
-                    params = params[0]
-                
-                if isinstance(parameters, sql.ClauseParameters):
-                    parameters = parameters.get_original_dict()
-                elif isinstance(parameters, list):
-                    parameters = [p.get_original_dict() for p in parameters]
-                        
-                query = self.convert_statement(query)
-                self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
-            self.sql_count += 1
-            return realexec(ctx, proxy, compiled, parameters, **kwargs)
-
-        e = self.real_execution_context()
-        realexec = e.post_exec
-        realexec.im_self.post_exec = post_exec
-        return e
+class ExecutionContextWrapper(object):
+    def __init__(self, ctx):
+        self.__dict__['ctx'] = ctx
+    def __getattr__(self, key):
+        return getattr(self.ctx, key)
+    def __setattr__(self, key, value):
+        setattr(self.ctx, key, value)
+        
+    def post_exec(self):
+        ctx = self.ctx
+        statement = unicode(ctx.compiled)
+        statement = re.sub(r'\n', '', ctx.statement)
+
+        if testdata.assert_list is not None:
+            item = testdata.assert_list[-1]
+            if not isinstance(item, dict):
+                item = testdata.assert_list.pop()
+            else:
+                # asserting a dictionary of statements->parameters
+                # this is to specify query assertions where the queries can be in 
+                # multiple orderings
+                if not item.has_key('_converted'):
+                    for key in item.keys():
+                        ckey = self.convert_statement(key)
+                        item[ckey] = item[key]
+                        if ckey != key:
+                            del item[key]
+                    item['_converted'] = True
+                try:
+                    entry = item.pop(statement)
+                    if len(item) == 1:
+                        testdata.assert_list.pop()
+                    item = (statement, entry)
+                except KeyError:
+                    self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
+
+            (query, params) = item
+            if callable(params):
+                params = params(ctx)
+            if params is not None and isinstance(params, list) and len(params) == 1:
+                params = params[0]
+            
+            if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
+                parameters = ctx.compiled_parameters.get_original_dict()
+            elif isinstance(ctx.compiled_parameters, list):
+                parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
+                    
+            query = self.convert_statement(query)
+            testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
+        testdata.sql_count += 1
+        self.ctx.post_exec()
         
     def convert_statement(self, query):
-        paramstyle = self.engine.dialect.paramstyle
+        paramstyle = self.ctx.dialect.paramstyle
         if paramstyle == 'named':
             pass
         elif paramstyle =='pyformat':