]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
sequence pre-executes dont create an ExecutionContext, use straight cursor
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Sep 2007 22:42:51 +0000 (22:42 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Sep 2007 22:42:51 +0000 (22:42 +0000)
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
test/sql/constraints.py

index fb5b512e236efc1666c3c08068a4282493366b76..5b852c185f0565577ed3cdf5f76a6d2700346c53 100644 (file)
@@ -672,12 +672,8 @@ class OracleSchemaDropper(compiler.SchemaDropper):
             self.execute()
 
 class OracleDefaultRunner(base.DefaultRunner):
-    def exec_default_sql(self, default):
-        c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
-        return self.connection.execute(c).scalar()
-
     def visit_sequence(self, seq):
-        return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
+        return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL")
 
 class OracleIdentifierPreparer(compiler.IdentifierPreparer):
     def format_savepoint(self, savepoint):
index 701114b17e9752a7c485542f6d24a79594f2b4bd..f5eaf47f21367650ace9fb89ad7d8c6714f5ea88 100644 (file)
@@ -611,9 +611,9 @@ class PGSchemaDropper(compiler.SchemaDropper):
 class PGDefaultRunner(base.DefaultRunner):
     def get_column_default(self, column, isinsert=True):
         if column.primary_key:
-            # passive defaults on primary keys have to be overridden
+            # pre-execute passive defaults on primary keys
             if isinstance(column.default, schema.PassiveDefault):
-                return self.connection.execute("select %s" % column.default.arg).scalar()
+                return self.execute_string("select %s" % column.default.arg)
             elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                 sch = column.table.schema
                 # TODO: this has to build into the Sequence object so we can get the quoting
@@ -622,13 +622,13 @@ class PGDefaultRunner(base.DefaultRunner):
                     exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
                 else:
                     exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
-                return self.connection.execute(exc).scalar()
+                return self.execute_string(exc.encode(self.dialect.encoding))
 
         return super(PGDefaultRunner, self).get_column_default(column)
 
     def visit_sequence(self, seq):
         if not seq.optional:
-            return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()
+            return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).encode(self.dialect.encoding))
         else:
             return None
 
index 1ab05fe036eb3a3d70c26fdb3bbeea158b539f5c..45d84f90d1a01efd4d3f98196e2a2480a4896d82 100644 (file)
@@ -837,45 +837,50 @@ class Connection(Connectable):
         return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
 
     def __execute_raw(self, context):
-        if self.__engine._should_log_info:
-            self.__engine.logger.info(context.statement)
-            self.__engine.logger.info(repr(context.parameters))
         if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
-            self.__executemany(context)
+            self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
         else:
-            self.__execute(context)
+            if context.parameters is None:
+                if context.dialect.positional:
+                    parameters = ()
+                else:
+                    parameters = {}
+            else:
+                parameters = context.parameters
+            self._cursor_execute(context.cursor, context.statement, parameters, context=context)
         self._autocommit(context)
 
-    def __execute(self, context):
-        if context.parameters is None:
-            if context.dialect.positional:
-                context.parameters = ()
-            else:
-                context.parameters = {}
+    def _cursor_execute(self, cursor, statement, parameters, context=None):
+        if self.__engine._should_log_info:
+            self.__engine.logger.info(statement)
+            self.__engine.logger.info(repr(parameters))
         try:
-            context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
+            self.dialect.do_execute(cursor, statement, parameters, context=context)
         except Exception, e:
             if self.dialect.is_disconnect(e):
                 self.__connection.invalidate(e=e)
                 self.engine.dispose()
-            context.cursor.close()
+            cursor.close()
             self._autorollback()
             if self.__close_with_result:
                 self.close()
-            raise exceptions.DBAPIError.instance(context.statement, context.parameters, e)
+            raise exceptions.DBAPIError.instance(statement, parameters, e)
 
-    def __executemany(self, context):
+    def _cursor_executemany(self, cursor, statement, parameters, context=None):
+        if self.__engine._should_log_info:
+            self.__engine.logger.info(statement)
+            self.__engine.logger.info(repr(parameters))
         try:
-            context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
+            self.dialect.do_executemany(cursor, statement, parameters, context=context)
         except Exception, e:
             if self.dialect.is_disconnect(e):
                 self.__connection.invalidate(e=e)
                 self.engine.dispose()
-            context.cursor.close()
+            cursor.close()
             self._autorollback()
             if self.__close_with_result:
                 self.close()
-            raise exceptions.DBAPIError.instance(context.statement, context.parameters, e)
+            raise exceptions.DBAPIError.instance(statement, parameters, e)
 
     # poor man's multimethod/generic function thingy
     executors = {
@@ -1632,7 +1637,6 @@ class DefaultRunner(schema.SchemaVisitor):
 
     def __init__(self, context):
         self.context = context
-        self.connection = context._connection._branch()
         self.dialect = context.dialect
 
     def get_column_default(self, column):
@@ -1665,9 +1669,17 @@ class DefaultRunner(schema.SchemaVisitor):
         return None
 
     def exec_default_sql(self, default):
-        c = expression.select([default.arg]).compile(bind=self.connection)
-        return self.connection._execute_compiled(c).scalar()
-
+        conn = self.context.connection
+        c = expression.select([default.arg]).compile(bind=conn)
+        return conn._execute_compiled(c).scalar()
+    
+    def execute_string(self, stmt, params=None):
+        """execute a string statement, using the raw cursor,
+        and return a scalar result."""
+        conn = self.context._connection
+        conn._cursor_execute(self.context.cursor, stmt, params)
+        return self.context.cursor.fetchone()[0]
+        
     def visit_column_onupdate(self, onupdate):
         if isinstance(onupdate.arg, expression.ClauseElement):
             return self.exec_default_sql(onupdate)
index 93ba231ab559b1aed12476bae381d121486f2039..c5320ada375b32671718f1e0cad2ea0fb359fada 100644 (file)
@@ -174,12 +174,12 @@ class ConstraintTest(AssertMixin):
         capt = []
         connection = testbase.db.connect()
         # TODO: hacky, put a real connection proxy in
-        ex = connection._Connection__execute
+        ex = connection._Connection__execute_raw
         def proxy(context):
             capt.append(context.statement)
             capt.append(repr(context.parameters))
             ex(context)
-        connection._Connection__execute = proxy
+        connection._Connection__execute_raw = proxy
         schemagen = testbase.db.dialect.schemagenerator(testbase.db.dialect, connection)
         schemagen.traverse(events)