]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more inlines
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 03:48:39 +0000 (22:48 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 03:48:39 +0000 (22:48 -0500)
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/oracle/zxjdbc.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py

index b7d6631388a0edf09271fa009d0b6f90d1808757..d705067d021e38109ae1155ba845f955a1474e5a 100644 (file)
@@ -327,7 +327,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
                                                         self.out_parameters[name]
         
     def create_cursor(self):
-        c = self._connection.connection.cursor()
+        c = self._dbapi_connection.cursor()
         if self.dialect.arraysize:
             c.arraysize = self.dialect.arraysize
 
index d742654a0d0fee70d438484d76ce048bb462848c..67139f5d9c040d3be411d5af3bb6958ce0af0839 100644 (file)
@@ -109,7 +109,7 @@ class OracleExecutionContext_zxjdbc(OracleExecutionContext):
         return base.ResultProxy(self)
 
     def create_cursor(self):
-        cursor = self._connection.connection.cursor()
+        cursor = self._dbapi_connection.cursor()
         cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
         return cursor
 
index b3f42c330607188992737a1c774ed1e39534f1e2..59251900cd46701a7456f54b863213c08e623c52 100644 (file)
@@ -172,9 +172,9 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
             # use server-side cursors:
             # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
             ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
-            return self._connection.connection.cursor(ident)
+            return self._dbapi_connection.cursor(ident)
         else:
-            return self._connection.connection.cursor()
+            return self._dbapi_connection.cursor()
 
     def get_result_proxy(self):
         # TODO: ouch
index f3f32f833f5ce2991b4ec48f3da6a7398f8f3951..4e11117f7b11c8d941381ca7d7ae214e36bd50cd 100644 (file)
@@ -889,16 +889,19 @@ class Connection(Connectable):
         try:
             return self.__connection
         except AttributeError:
-            if self.__invalid:
-                if self.__transaction is not None:
-                    raise exc.InvalidRequestError(
-                                    "Can't reconnect until invalid "
-                                    "transaction is rolled back")
-                self.__connection = self.engine.raw_connection()
-                self.__invalid = False
-                return self.__connection
-            raise exc.ResourceClosedError("This Connection is closed")
-
+            return self._revalidate_connection()
+        
+    def _revalidate_connection(self):
+        if self.__invalid:
+            if self.__transaction is not None:
+                raise exc.InvalidRequestError(
+                                "Can't reconnect until invalid "
+                                "transaction is rolled back")
+            self.__connection = self.engine.raw_connection()
+            self.__invalid = False
+            return self.__connection
+        raise exc.ResourceClosedError("This Connection is closed")
+        
     @property
     def _connection_is_valid(self):
         # use getattr() for is_valid to support exceptions raised in
@@ -1214,9 +1217,14 @@ class Connection(Connectable):
         """Execute a schema.ColumnDefault object."""
         
         try:
+            try:
+                conn = self.__connection
+            except AttributeError:
+                conn = self._revalidate_connection()
+            
             dialect = self.dialect
             ctx = dialect.execution_ctx_cls._init_default(
-                                dialect, self)
+                                dialect, self, conn)
         except Exception, e:
             self._handle_dbapi_exception(e, None, None, None, None)
             raise
@@ -1306,7 +1314,12 @@ class Connection(Connectable):
         a :class:`.ResultProxy`."""
         
         try:
-            context = constructor(dialect, self, *args)
+            try:
+                conn = self.__connection
+            except AttributeError:
+                conn = self._revalidate_connection()
+            
+            context = constructor(dialect, self, conn, *args)
         except Exception, e:
             self._handle_dbapi_exception(e, 
                         statement, parameters, 
index 3fc5910676017dd82912c8c57abbbff92542b565..757e42d0323ef7d38c5aab4e105fa2ae68a982a1 100644 (file)
@@ -322,14 +322,17 @@ class DefaultExecutionContext(base.ExecutionContext):
     result_map = None
     compiled = None
     statement = None
+    _is_implicit_returning = False
+    _is_explicit_returning = False
     
     @classmethod
-    def _init_ddl(cls, dialect, connection, compiled_ddl):
+    def _init_ddl(cls, dialect, connection, dbapi_connection, compiled_ddl):
         """Initialize execution context for a DDLElement construct."""
         
         self = cls.__new__(cls)
         self.dialect = dialect
-        self._connection = self.root_connection = connection
+        self.root_connection = connection
+        self._dbapi_connection = dbapi_connection
         self.engine = connection.engine
 
         self.compiled = compiled = compiled_ddl
@@ -357,12 +360,13 @@ class DefaultExecutionContext(base.ExecutionContext):
         return self
         
     @classmethod
-    def _init_compiled(cls, dialect, connection, compiled, parameters):
+    def _init_compiled(cls, dialect, connection, dbapi_connection, compiled, parameters):
         """Initialize execution context for a Compiled construct."""
 
         self = cls.__new__(cls)
         self.dialect = dialect
-        self._connection = self.root_connection = connection
+        self.root_connection = connection
+        self._dbapi_connection = dbapi_connection
         self.engine = connection.engine
 
         self.compiled = compiled
@@ -389,6 +393,11 @@ class DefaultExecutionContext(base.ExecutionContext):
         self.isinsert = compiled.isinsert
         self.isupdate = compiled.isupdate
         self.isdelete = compiled.isdelete
+        
+        if self.isinsert or self.isupdate or self.isdelete:
+            self._is_explicit_returning = compiled.statement._returning
+            self._is_implicit_returning = compiled.returning and \
+                                            not compiled.statement._returning
 
         if not parameters:
             self.compiled_parameters = [compiled.construct_params()]
@@ -444,12 +453,13 @@ class DefaultExecutionContext(base.ExecutionContext):
         return self
     
     @classmethod
-    def _init_statement(cls, dialect, connection, statement, parameters):
+    def _init_statement(cls, dialect, connection, dbapi_connection, statement, parameters):
         """Initialize execution context for a string SQL statement."""
 
         self = cls.__new__(cls)
         self.dialect = dialect
-        self._connection = self.root_connection = connection
+        self.root_connection = connection
+        self._dbapi_connection = dbapi_connection
         self.engine = connection.engine
 
         # plain text statement
@@ -486,12 +496,13 @@ class DefaultExecutionContext(base.ExecutionContext):
         return self
     
     @classmethod
-    def _init_default(cls, dialect, connection):
+    def _init_default(cls, dialect, connection, dbapi_connection):
         """Initialize execution context for a ColumnDefault construct."""
 
         self = cls.__new__(cls)
         self.dialect = dialect
-        self._connection = self.root_connection = connection
+        self.root_connection = connection
+        self._dbapi_connection = dbapi_connection
         self.engine = connection.engine
         self.execution_options = connection._execution_options
         self.cursor = self.create_cursor()
@@ -514,17 +525,6 @@ class DefaultExecutionContext(base.ExecutionContext):
         else:
             return autocommit
             
-    @util.memoized_property
-    def _is_explicit_returning(self):
-        return self.compiled and \
-            getattr(self.compiled.statement, '_returning', False)
-    
-    @util.memoized_property
-    def _is_implicit_returning(self):
-        return self.compiled and \
-            bool(self.compiled.returning) and \
-            not self.compiled.statement._returning
-    
     def _execute_scalar(self, stmt):
         """Execute a string statement on the current cursor, returning a
         scalar result.
@@ -535,7 +535,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         
         """
 
-        conn = self._connection
+        conn = self.root_connection
         if isinstance(stmt, unicode) and \
             not self.dialect.supports_unicode_statements:
             stmt = stmt.encode(self.dialect.encoding)
@@ -550,13 +550,13 @@ class DefaultExecutionContext(base.ExecutionContext):
     
     @property
     def connection(self):
-        return self._connection._branch()
+        return self.root_connection._branch()
 
     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_REGEXP.match(statement)
 
     def create_cursor(self):
-        return self._connection.connection.cursor()
+        return self._dbapi_connection.cursor()
 
     def pre_exec(self):
         pass
@@ -610,7 +610,7 @@ class DefaultExecutionContext(base.ExecutionContext):
     
     def post_insert(self):
         if self.dialect.postfetch_lastrowid and \
-            (not len(self.inserted_primary_key) or \
+            (not self.inserted_primary_key or \
                         None in self.inserted_primary_key):
             
             table = self.compiled.statement.table
@@ -664,7 +664,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             try:
                 self.cursor.setinputsizes(*inputsizes)
             except Exception, e:
-                self._connection._handle_dbapi_exception(e, None, None, None, self)
+                self.root_connection._handle_dbapi_exception(e, None, None, None, self)
                 raise
         else:
             inputsizes = {}
@@ -678,7 +678,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             try:
                 self.cursor.setinputsizes(**inputsizes)
             except Exception, e:
-                self._connection._handle_dbapi_exception(e, None, None, None, self)
+                self.root_connection._handle_dbapi_exception(e, None, None, None, self)
                 raise
 
     def _exec_default(self, default):
index cf1e28f50a7325c5d323b53022bece4905629298..8474ebaccb2146e73241627755e3f88853140755 100644 (file)
@@ -277,25 +277,24 @@ class SQLCompiler(engine.Compiled):
         if params:
             pd = {}
             for bindparam, name in self.bind_names.iteritems():
-                for paramname in (bindparam.key, name):
-                    if paramname in params:
-                        pd[name] = params[paramname]
-                        break
-                else:
-                    if bindparam.required:
-                        if _group_number:
-                            raise exc.InvalidRequestError(
-                                            "A value is required for bind parameter %r, "
-                                            "in parameter group %d" % 
-                                            (bindparam.key, _group_number))
-                        else:
-                            raise exc.InvalidRequestError(
-                                            "A value is required for bind parameter %r" 
-                                            % bindparam.key)
-                    elif bindparam.callable:
-                        pd[name] = bindparam.callable()
+                if bindparam.key in params:
+                    pd[name] = params[bindparam.key]
+                elif name in params:
+                    pd[name] = params[name]
+                elif bindparam.required:
+                    if _group_number:
+                        raise exc.InvalidRequestError(
+                                        "A value is required for bind parameter %r, "
+                                        "in parameter group %d" % 
+                                        (bindparam.key, _group_number))
                     else:
-                        pd[name] = bindparam.value
+                        raise exc.InvalidRequestError(
+                                        "A value is required for bind parameter %r" 
+                                        % bindparam.key)
+                elif bindparam.callable:
+                    pd[name] = bindparam.callable()
+                else:
+                    pd[name] = bindparam.value
             return pd
         else:
             pd = {}