]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refinement of connection.execute() , parameter processing behavior
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 15:52:09 +0000 (15:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 15:52:09 +0000 (15:52 +0000)
- Connection's dealings with params are simplified; generation of
ClauseParameters pushed into DefaultDialect.
- simplified ClauseParameters.
- this is to make room for execute_raw() but I haven't decided how that
should look yet.

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/dialect/postgres.py
test/sql/constraints.py
test/sql/testtypes.py
test/testlib/testing.py

index d26d11c4d9b30aed0593be35d4f5f20a0b1ff8d5..2d738769d7eb9e46a115264e25b3d0bc87a1e497 100644 (file)
@@ -240,24 +240,12 @@ class ANSICompiler(engine.Compiled):
         return self.wheres.get(obj, None)
 
     def construct_params(self, params):
-        """Return a structure of bind parameters for this compiled object.
-
-        This includes bind parameters that might be compiled in via
-        the `values` argument of an ``Insert`` or ``Update`` statement
-        object, and also the given `**params`.  The keys inside of
-        `**params` can be any key that matches the
-        ``BindParameterClause`` objects compiled within this object.
-
-        The output is dependent on the paramstyle of the DBAPI being
-        used; if a named style, the return result will be a dictionary
-        with keynames matching the compiled statement.  If a
-        positional style, the output will be a list, with an iterator
-        that will return parameter values in an order corresponding to
-        the bind positions in the compiled statement.
-
-        For an executemany style of call, this method should be called
-        for each element in the list of parameter groups that will
-        ultimately be executed.
+        """Return a sql.ClauseParameters object.
+        
+        Combines the given bind parameter dictionary (string keys to object values)
+        with the _BindParamClause objects stored within this Compiled object
+        to produce a ClauseParameters structure, representing the bind arguments
+        for a single statement execution, or one element of an executemany execution.
         """
         
         if self.parameters is not None:
index cc9fd3d9a4c48dd210badfe008cc6f78ee906457..50432f61478b3f3671ead9d08af03933ffc5efea 100644 (file)
@@ -600,9 +600,9 @@ class OracleSchemaDropper(ansisql.ANSISchemaDropper):
 class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
     def exec_default_sql(self, default):
         c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
-        return self.connection.execute_compiled(c).scalar()
+        return self.connection.execute(c).scalar()
 
     def visit_sequence(self, seq):
-        return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
+        return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
 
 dialect = OracleDialect
index cd486e282392c45c54d110782b62119295053954..96ca048b11b1796478b00b07a1933fa2aed05375 100644 (file)
@@ -599,7 +599,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
         if column.primary_key:
             # passive defaults on primary keys have to be overridden
             if isinstance(column.default, schema.PassiveDefault):
-                return self.connection.execute_text("select %s" % column.default.arg).scalar()
+                return self.connection.execute("select %s" % column.default.arg).scalar()
             elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                 sch = column.table.schema
                 # TODO: this has to build into the Sequence object so we can get the quoting
@@ -608,7 +608,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
                     exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
                 else:
                     exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
-                return self.connection.execute_text(exc).scalar()
+                return self.connection.execute(exc).scalar()
 
         return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
 
index 9beb08826a2c0d5ee06b195017b7d5827457e80e..5bc72db8acc1d7f8fc945d2d30cc6a9dbb0751f0 100644 (file)
@@ -47,15 +47,6 @@ class Dialect(sql.AbstractDialect):
 
         raise NotImplementedError()
 
-    def convert_compiled_params(self, parameters):
-        """Build DBAPI execute arguments from a [sqlalchemy.sql#ClauseParameters] instance.
-
-        Returns an array or dictionary suitable to pass directly to this ``Dialect`` instance's DBAPI's
-        execute method.
-        """
-
-        raise NotImplementedError()
-
     def dbapi_type_map(self):
         """return a mapping of DBAPI type objects present in this Dialect's DBAPI
         mapped to TypeEngine implementations used by the dialect. 
@@ -295,19 +286,18 @@ class ExecutionContext(object):
         compiled
             if passed to constructor, sql.Compiled object being executed
         
-        compiled_parameters
-            if passed to constructor, sql.ClauseParameters object
-             
         statement
             string version of the statement to be executed.  Is either
             passed to the constructor, or must be created from the 
             sql.Compiled object by the time pre_exec() has completed.
             
         parameters
-            "raw" parameters suitable for direct execution by the
-            dialect.  Either passed to the constructor, or must be
-            created from the sql.ClauseParameters object by the time 
-            pre_exec() has completed.
+            bind parameters passed to the execute() method.  for
+            compiled statements, this is a dictionary or list 
+            of dictionaries.  for textual statements, it should
+            be in a format suitable for the dialect's paramstyle
+            (i.e. dict or list of dicts for non positional,
+            list or list of lists/tuples for positional).
             
     
     The Dialect should provide an ExecutionContext via the
@@ -317,24 +307,28 @@ class ExecutionContext(object):
     """
 
     def create_cursor(self):
-        """Return a new cursor generated this ExecutionContext's connection."""
+        """Return a new cursor generated from this ExecutionContext's connection.
+        
+        Some dialects may wish to change the behavior of connection.cursor(),
+        such as postgres which may return a PG "server side" cursor.
+        """
 
         raise NotImplementedError()
 
-    def pre_exec(self):
+    def pre_execution(self):
         """Called before an execution of a compiled statement.
         
-        If compiled and compiled_parameters were passed to this
+        If a compiled statement was passed to this
         ExecutionContext, the `statement` and `parameters` datamembers
         must be initialized after this statement is complete.
         """
 
         raise NotImplementedError()
 
-    def post_exec(self):
+    def post_execution(self):
         """Called after the execution of a compiled statement.
         
-        If compiled was passed to this ExecutionContext,
+        If a compiled statement was passed to this ExecutionContext,
         the `last_insert_ids`, `last_inserted_params`, etc. 
         datamembers should be available after this method
         completes.
@@ -342,8 +336,11 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
     
-    def get_result_proxy(self):
-        """return a ResultProxy corresponding to this ExecutionContext."""
+    def result(self):
+        """return a result object corresponding to this ExecutionContext.
+        
+        Returns a ResultProxy."""
+        
         raise NotImplementedError()
         
     def get_rowcount(self):
@@ -450,12 +447,9 @@ class Compiled(sql.ClauseVisitor):
     def construct_params(self, params):
         """Return the bind params for this compiled object.
 
-        Will start with the default parameters specified when this
-        ``Compiled`` object was first constructed, and will override
-        those values with those sent via `**params`, which are
-        key/value pairs.  Each key should match one of the
-        ``_BindParamClause`` objects compiled into this object; either
-        the `key` or `shortname` property of the ``_BindParamClause``.
+        params is a dict of string/object pairs whos 
+        values will override bind values compiled in
+        to the statement.
         """
         raise NotImplementedError()
 
@@ -465,7 +459,7 @@ class Compiled(sql.ClauseVisitor):
         e = self.bind
         if e is None:
             raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
-        return e.execute_compiled(self, *multiparams, **params)
+        return e._execute_compiled(self, multiparams, params)
 
     def scalar(self, *multiparams, **params):
         """Execute this compiled object and return the result's scalar value."""
@@ -693,75 +687,66 @@ class Connection(Connectable):
     def execute(self, object, *multiparams, **params):
         for c in type(object).__mro__:
             if c in Connection.executors:
-                return Connection.executors[c](self, object, *multiparams, **params)
+                return Connection.executors[c](self, object, multiparams, params)
         else:
             raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
 
-    def execute_default(self, default, **kwargs):
+    def _execute_default(self, default, multiparams=None, params=None):
         return self.__engine.dialect.defaultrunner(self).traverse_single(default)
 
-    def execute_text(self, statement, *multiparams, **params):
-        if len(multiparams) == 0:
+    def _execute_text(self, statement, multiparams, params):
+        parameters = self.__distill_params(multiparams, params)
+        context = self.__create_execution_context(statement=statement, parameters=parameters)
+        self.__execute_raw(context)
+        return context.result()
+
+    def __distill_params(self, multiparams, params):
+        if multiparams is None or len(multiparams) == 0:
             parameters = params or None
         elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)):
             parameters = multiparams[0]
         else:
             parameters = list(multiparams)
-        context = self._create_execution_context(statement=statement, parameters=parameters)
-        self._execute_raw(context)
-        return context.get_result_proxy()
-
-    def _params_to_listofdicts(self, *multiparams, **params):
-        if len(multiparams) == 0:
-            return [params]
-        elif len(multiparams) == 1:
-            if multiparams[0] == None:
-                return [{}]
-            elif isinstance (multiparams[0], (list, tuple)):
-                return multiparams[0]
-            else:
-                return [multiparams[0]]
-        else:
-            return multiparams
-    
-    def execute_function(self, func, *multiparams, **params):
-        return self.execute_clauseelement(func.select(), *multiparams, **params)
+        return parameters
+
+    def _execute_function(self, func, multiparams, params):
+        return self._execute_clauseelement(func.select(), multiparams, params)
         
-    def execute_clauseelement(self, elem, *multiparams, **params):
-        executemany = len(multiparams) > 0
+    def _execute_clauseelement(self, elem, multiparams=None, params=None):
+        executemany = multiparams is not None and len(multiparams) > 0
         if executemany:
             param = multiparams[0]
         else:
             param = params
-        return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params)
+        return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param), multiparams, params)
 
-    def execute_compiled(self, compiled, *multiparams, **params):
+    def _execute_compiled(self, compiled, multiparams=None, params=None):
         """Execute a sql.Compiled object."""
         if not compiled.can_execute:
             raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
-        parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
-        if len(parameters) == 1:
-            parameters = parameters[0]
-        context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters)
-        context.pre_exec()
-        self._execute_raw(context)
-        context.post_exec()
-        return context.get_result_proxy()
-    
-    def _create_execution_context(self, **kwargs):
+
+        params = self.__distill_params(multiparams, params)
+        context = self.__create_execution_context(compiled=compiled, parameters=params)
+        
+        context.pre_execution()
+        self.__execute_raw(context)
+        context.post_execution()
+        return context.result()
+            
+    def __create_execution_context(self, **kwargs):
         return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
         
-    def _execute_raw(self, context):
+    def __execute_raw(self, context):
         if logging.is_info_enabled(self.__engine.logger):
             self.__engine.logger.info(context.statement)
             self.__engine.logger.info(repr(context.parameters))
         if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
-            self._executemany(context)
+            self.__executemany(context)
         else:
-            self._execute(context)
+            self.__execute(context)
         self._autocommit(context.statement)
 
-    def _execute(self, context):
+    def __execute(self, context):
         if context.parameters is None:
             if context.dialect.positional:
                 context.parameters = ()
@@ -778,7 +763,7 @@ class Connection(Connectable):
                 self.close()
             raise exceptions.SQLError(context.statement, context.parameters, e)
 
-    def _executemany(self, context):
+    def __executemany(self, context):
         try:
             context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
         except Exception, e:
@@ -792,11 +777,11 @@ class Connection(Connectable):
 
     # poor man's multimethod/generic function thingy
     executors = {
-        sql._Function : execute_function,
-        sql.ClauseElement : execute_clauseelement,
-        sql.ClauseVisitor : execute_compiled,
-        schema.SchemaItem:execute_default,
-        str.__mro__[-2] : execute_text
+        sql._Function : _execute_function,
+        sql.ClauseElement : _execute_clauseelement,
+        sql.ClauseVisitor : _execute_compiled,
+        schema.SchemaItem:_execute_default,
+        str.__mro__[-2] : _execute_text
     }
 
     def create(self, entity, **kwargs):
@@ -934,10 +919,10 @@ class Engine(Connectable):
 
         self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
 
-    def execute_default(self, default, **kwargs):
+    def _execute_default(self, default):
         connection = self.contextual_connect()
         try:
-            return connection.execute_default(default, **kwargs)
+            return connection._execute_default(default)
         finally:
             connection.close()
 
@@ -1006,9 +991,9 @@ class Engine(Connectable):
     def scalar(self, statement, *multiparams, **params):
         return self.execute(statement, *multiparams, **params).scalar()
 
-    def execute_compiled(self, compiled, *multiparams, **params):
+    def _execute_compiled(self, compiled, multiparams, params):
         connection = self.contextual_connect(close_with_result=True)
-        return connection.execute_compiled(compiled, *multiparams, **params)
+        return connection._execute_compiled(compiled, multiparams, params)
 
     def compiler(self, statement, parameters, **kwargs):
         return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
@@ -1509,7 +1494,7 @@ class DefaultRunner(schema.SchemaVisitor):
 
     def exec_default_sql(self, default):
         c = sql.select([default.arg]).compile(bind=self.connection)
-        return self.connection.execute_compiled(c).scalar()
+        return self.connection._execute_compiled(c).scalar()
 
     def visit_column_onupdate(self, onupdate):
         if isinstance(onupdate.arg, sql.ClauseElement):
index f87551c2a87b06f52e48935dcde8f6d655ea5278..dfdc1baaa495f619595025dba1809a4832a8baa6 100644 (file)
@@ -127,23 +127,6 @@ class DefaultDialect(base.Dialect):
 
     paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
 
-    def convert_compiled_params(self, parameters):
-        executemany = parameters is not None and isinstance(parameters, list)
-        # the bind params are a CompiledParams object.  but all the DBAPI's hate
-        # that object (or similar).  so convert it to a clean
-        # dictionary/list/tuple of dictionary/tuple of list
-        if parameters is not None:
-           if self.positional:
-                if executemany:
-                    parameters = [p.get_raw_list() for p in parameters]
-                else:
-                    parameters = parameters.get_raw_list()
-           else:
-                if executemany:
-                    parameters = [p.get_raw_dict() for p in parameters]
-                else:
-                    parameters = parameters.get_raw_dict()
-        return parameters
 
     def _figure_paramstyle(self, paramstyle=None, default='named'):
         if paramstyle is not None:
@@ -172,19 +155,26 @@ class DefaultDialect(base.Dialect):
     ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
 
 class DefaultExecutionContext(base.ExecutionContext):
-    def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
+    def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
         self.dialect = dialect
         self.connection = connection
         self.compiled = compiled
-        self.compiled_parameters = compiled_parameters
         
         if compiled is not None:
             self.typemap = compiled.typemap
             self.column_labels = compiled.column_labels
             self.statement = unicode(compiled)
+            if parameters is None:
+                self.compiled_parameters = compiled.construct_params({})
+            elif not isinstance(parameters, (list, tuple)):
+                self.compiled_parameters = compiled.construct_params(parameters)
+            else:
+                self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters]
+                if len(self.compiled_parameters) == 1:
+                    self.compiled_parameters = self.compiled_parameters[0]
         else:
             self.typemap = self.column_labels = None
-            self.parameters = self._encode_param_keys(parameters)
+            self.parameters = self.__encode_param_keys(parameters)
             self.statement = statement
 
         if not dialect.supports_unicode_statements():
@@ -194,7 +184,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         
     engine = property(lambda s:s.connection.engine)
     
-    def _encode_param_keys(self, params):
+    def __encode_param_keys(self, params):
         """apply string encoding to the keys of dictionary-based bind parameters"""
         if self.dialect.positional or self.dialect.supports_unicode_statements():
             return params
@@ -209,6 +199,25 @@ class DefaultExecutionContext(base.ExecutionContext):
                 return [proc(d) for d in params]
             else:
                 return proc(params)
+
+    def __convert_compiled_params(self, parameters):
+        executemany = parameters is not None and isinstance(parameters, list)
+        encode = not self.dialect.supports_unicode_statements()
+        # the bind params are a CompiledParams object.  but all the DBAPI's hate
+        # that object (or similar).  so convert it to a clean
+        # dictionary/list/tuple of dictionary/tuple of list
+        if parameters is not None:
+           if self.dialect.positional:
+                if executemany:
+                    parameters = [p.get_raw_list() for p in parameters]
+                else:
+                    parameters = parameters.get_raw_list()
+           else:
+                if executemany:
+                    parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters]
+                else:
+                    parameters = parameters.get_raw_dict(encode_keys=encode)
+        return parameters
                 
     def is_select(self):
         """return TRUE if the statement is expected to have result rows."""
@@ -217,10 +226,19 @@ class DefaultExecutionContext(base.ExecutionContext):
 
     def create_cursor(self):
         return self.connection.connection.cursor()
-        
+
+    def pre_execution(self):
+        self.pre_exec()
+    
+    def post_execution(self):
+        self.post_exec()
+    
+    def result(self):
+        return self.get_result_proxy()
+            
     def pre_exec(self):
         self._process_defaults()
-        self.parameters = self._encode_param_keys(self.dialect.convert_compiled_params(self.compiled_parameters))
+        self.parameters = self.__convert_compiled_params(self.compiled_parameters)
 
     def post_exec(self):
         pass
@@ -279,26 +297,10 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.cursor.setinputsizes(**inputsizes)
 
     def _process_defaults(self):
-        """``INSERT`` and ``UPDATE`` statements, when compiled, may
-        have additional columns added to their ``VALUES`` and ``SET``
-        lists corresponding to column defaults/onupdates that are
-        present on the ``Table`` object (i.e. ``ColumnDefault``,
-        ``Sequence``, ``PassiveDefault``).  This method pre-execs
-        those ``DefaultGenerator`` objects that require pre-execution
-        and sets their values within the parameter list, and flags this
-        ExecutionContext about ``PassiveDefault`` objects that may
-        require post-fetching the row after it is inserted/updated.
-
-        This method relies upon logic within the ``ANSISQLCompiler``
-        in its `visit_insert` and `visit_update` methods that add the
-        appropriate column clauses to the statement when its being
-        compiled, so that these parameters can be bound to the
-        statement.
-        """
+        """generate default values for compiled insert/update statements,
+        and generate last_inserted_ids() collection."""
 
-        # TODO: this calculation of defaults is one of the places SA slows down inserts.
-        # look into optimizing this for a list of params where theres no defaults defined
-        # (i.e. analyze the first batch of params).
+        # TODO: cleanup
         if self.compiled.isinsert:
             if isinstance(self.compiled_parameters, list):
                 plist = self.compiled_parameters
index 7004094aa14d9a11d9212aff62417af883c8dfd5..7a278053747fb3c933b96b2efeee2de65d6244d0 100644 (file)
@@ -778,7 +778,7 @@ class DefaultGenerator(SchemaItem):
     def execute(self, bind=None, **kwargs):
         if bind is None:
             bind = self._get_engine(raiseerr=True)
-        return bind.execute_default(self, **kwargs)
+        return bind._execute_default(self, **kwargs)
 
     def __repr__(self):
         return "DefaultGenerator()"
index fb8c8ff8595d4ca4c12d69dfd6e7fd3159c65aa1..3c23d67398370000cfc0261eccde9896e11997d1 100644 (file)
@@ -781,52 +781,42 @@ class ClauseParameters(object):
     def __init__(self, dialect, positional=None):
         super(ClauseParameters, self).__init__()
         self.dialect = dialect
-        self.binds = {}
-        self.binds_to_names = {}
-        self.binds_to_values = {}
+        self.__binds = {}
         self.positional = positional or []
 
     def set_parameter(self, bindparam, value, name):
-        self.binds[bindparam.key] = bindparam
-        self.binds[name] = bindparam
-        self.binds_to_names[bindparam] = name
-        self.binds_to_values[bindparam] = value
+        self.__binds[name] = [bindparam, name, value]
         
     def get_original(self, key):
-        """Return the given parameter as it was originally placed in
-        this ``ClauseParameters`` object, without any ``Type``
-        conversion."""
-        return self.binds_to_values[self.binds[key]]
+        return self.__binds[key][2]
 
     def get_processed(self, key):
-        bind = self.binds[key]
-        value = self.binds_to_values[bind]
+        (bind, name, value) = self.__binds[key]
         return bind.typeprocess(value, self.dialect)
    
     def keys(self):
-        return self.binds_to_names.values()
+        return self.__binds.keys()
  
     def __getitem__(self, key):
         return self.get_processed(key)
         
     def __contains__(self, key):
-        return key in self.binds
+        return key in self.__binds
     
     def set_value(self, key, value):
-        bind = self.binds[key]
-        self.binds_to_values[bind] = value
+        self.__binds[key][2] = value
             
     def get_original_dict(self):
-        return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()])
+        return dict([(name, value) for (b, name, value) in self.__binds.values()])
 
     def get_raw_list(self):
         return [self.get_processed(key) for key in self.positional]
 
-    def get_raw_dict(self):
-        d = {}
-        for k in self.binds_to_names.values():
-            d[k] = self.get_processed(k)
-        return d
+    def get_raw_dict(self, encode_keys=False):
+        if encode_keys:
+            return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()])
+        else:
+            return dict([(key, self.get_processed(key)) for key in self.keys()])
 
     def __repr__(self):
         return self.__class__.__name__ + ":" + repr(self.get_original_dict())
index 550966d0a40b978de7dc960eb64203c46fb83fb3..f80ddcadd66280bf69ef600d67fc7b8f82a98327 100644 (file)
@@ -56,7 +56,7 @@ class DomainReflectionTest(AssertMixin):
 
     @testing.supported('postgres')
     def test_crosschema_domain_is_reflected(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         table = Table('crosschema', metadata, autoload=True)
         self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value")
         self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
index 1c2bd1b57c20e417b59d6ceb5546fbe734507832..3120185d59d9c01598bc7a3f8b746a215bcaaf7b 100644 (file)
@@ -172,12 +172,13 @@ class ConstraintTest(AssertMixin):
 
         capt = []
         connection = testbase.db.connect()
-        ex = connection._execute
+        # TODO: hacky, put a real connection proxy in
+        ex = connection._Connection__execute
         def proxy(context):
             capt.append(context.statement)
             capt.append(repr(context.parameters))
             ex(context)
-        connection._execute = proxy
+        connection._Connection__execute = proxy
         schemagen = testbase.db.dialect.schemagenerator(connection)
         schemagen.traverse(events)
         
index 28e7db3a3ed88f45906effaed730d9aca328c14a..8dbeda19af4c6c2c2dc35bb39b9c754165fddaf3 100644 (file)
@@ -341,7 +341,7 @@ class DateTest(AssertMixin):
                          Sequence('datetest_id_seq', optional=True),
                          primary_key=True),
                 Column('adate', Date), Column('adatetime', DateTime))
-        t.create()
+        t.create(checkfirst=True)
         try:
             d1 = datetime.date(2007, 10, 30)
             t.insert().execute(adate=d1, adatetime=d1)
@@ -353,7 +353,7 @@ class DateTest(AssertMixin):
             self.assert_(x.adatetime.__class__ == datetime.datetime)
 
         finally:
-            t.drop()
+            t.drop(checkfirst=True)
 
 class IntervalTest(AssertMixin):
     def setUpAll(self):
index 0361cfb682360920e11c3d70e329fda874435e41..213772e9e167975129c0f3f195137b80d5baa7e9 100644 (file)
@@ -76,7 +76,7 @@ class ExecutionContextWrapper(object):
     def __setattr__(self, key, value):
         setattr(self.ctx, key, value)
         
-    def post_exec(self):
+    def post_execution(self):
         ctx = self.ctx
         statement = unicode(ctx.compiled)
         statement = re.sub(r'\n', '', ctx.statement)
@@ -123,7 +123,7 @@ class ExecutionContextWrapper(object):
                 statement = statement[:-25]
             testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
         testdata.sql_count += 1
-        self.ctx.post_exec()
+        self.ctx.post_execution()
         
     def convert_statement(self, query):
         paramstyle = self.ctx.dialect.paramstyle