]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged the patch from #516 + fixes
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Apr 2007 22:03:06 +0000 (22:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Apr 2007 22:03:06 +0000 (22:03 +0000)
- improves the framework for auto-invalidation of connections that have
lost their underlying database - the error catching/invalidate
step is totally moved to the connection pool.
- added better condition checking for do_rollback() and do_commit() including
SQLError excepetion wrapping

CHANGES
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/pool.py

diff --git a/CHANGES b/CHANGES
index 41a2ac3837f27a872e1917a83f773f3fc9525fc5..083114e36b971950d9cad2bb46d1dd2df492e4fa 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -11,6 +11,9 @@
       "buffered" result sets used for different purposes.
     - server side cursor support fully functional in postgres
       [ticket:514].
+    - improved framework for auto-invalidation of connections that have
+      lost their underlying database - the error catching/invalidate
+      step is totally moved to the connection pool. #516
 - sql:
     - the Unicode type is now a direct subclass of String, which now
       contains all the "convert_unicode" logic.  This helps the variety
index 6d2ff66cd594475eda59d7b73d1538bb4ac4d704..013e78c6af5d08529573d9153bd83ca77ad8fa7a 100644 (file)
@@ -553,6 +553,7 @@ class MSSQLDialect_pymssql(MSSQLDialect):
 
     def do_rollback(self, connection):
         # pymssql throws an error on repeated rollbacks. Ignore it.
+        # TODO: this is normal behavior for most DBs.  are we sure we want to ignore it ?
         try:
             connection.rollback()
         except:
@@ -571,6 +572,11 @@ class MSSQLDialect_pymssql(MSSQLDialect):
             del keys['port']
         return [[], keys]
 
+    def get_disconnect_checker(self):
+        def disconnect_checker(e):
+            return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
+        return disconnect_checker
+
 
 ##    This code is leftover from the initial implementation, for reference
 ##    def do_begin(self, connection):
@@ -630,6 +636,11 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
             connectors.append ("TrustedConnection=Yes")
         return [[";".join (connectors)], {}]
 
+    def get_disconnect_checker(self):
+        def disconnect_checker(e):
+            return isinstance(e, self.dbapi.Error) and '[08S01]' in e.args[1]
+        return disconnect_checker
+
 
 class MSSQLDialect_adodbapi(MSSQLDialect):
     def import_dbapi(cls):
@@ -660,6 +671,11 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
             connectors.append("Integrated Security=SSPI")
         return [[";".join (connectors)], {}]
 
+    def get_disconnect_checker(self):
+        def disconnect_checker(e):
+            return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
+        return disconnect_checker
+
 dialect_mapping = {
     'pymssql':  MSSQLDialect_pymssql,
     'pyodbc':   MSSQLDialect_pyodbc,
index 65ccb6af19f86d818c030be8053c65560f8701c2..7ea98e92f4e6cbce99868912dd2020fb75f7686b 100644 (file)
@@ -328,21 +328,12 @@ class MySQLDialect(ansisql.ANSIDialect):
         return MySQLIdentifierPreparer(self)
 
     def do_executemany(self, cursor, statement, parameters, context=None, **kwargs):
-        try:
-            rowcount = cursor.executemany(statement, parameters)
-            if context is not None:
-                context._rowcount = rowcount
-        except self.dbapi.OperationalError, o:
-            if o.args[0] == 2006 or o.args[0] == 2014:
-                cursor.invalidate()
-            raise o
+        rowcount = cursor.executemany(statement, parameters)
+        if context is not None:
+            context._rowcount = rowcount
+            
     def do_execute(self, cursor, statement, parameters, **kwargs):
-        try:
-            cursor.execute(statement, parameters)
-        except self.dbapi.OperationalError, o:
-            if o.args[0] == 2006 or o.args[0] == 2014:
-                cursor.invalidate()
-            raise o
+        cursor.execute(statement, parameters)
 
     def do_rollback(self, connection):
         # MySQL without InnoDB doesnt support rollback()
@@ -351,6 +342,12 @@ class MySQLDialect(ansisql.ANSIDialect):
         except:
             pass
 
+    def get_disconnect_checker(self):
+        def disconnect_checker(e):
+            return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2014)
+        return disconnect_checker            
+
+
     def get_default_schema_name(self):
         if not hasattr(self, '_default_schema_name'):
             self._default_schema_name = text("select database()", self).scalar()
index 2943d163e5960d39941ab7c02c41f620242d386a..a26ef76b6f5c1825bff367aeb1980b71bf642ceb 100644 (file)
@@ -338,6 +338,19 @@ class PGDialect(ansisql.ANSIDialect):
         cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name})
         return bool(not not cursor.rowcount)
 
+    def get_disconnect_checker(self):
+        def disconnect_checker(e):
+            if isinstance(e, self.dbapi.OperationalError):
+                return 'closed the connection' in str(e) or 'connection not open' in str(e)
+            elif isinstance(e, self.dbapi.InterfaceError):
+                return 'connection already closed' in str(e)
+            elif isinstance(e, self.dbapi.ProgrammingError):
+                # yes, it really says "losed", not "closed"
+                return "losed the connection unexpectedly" in str(e)
+            else:
+                return False
+        return disconnect_checker
+
     def reflecttable(self, connection, table):
         if self.version == 2:
             ischema_names = pg2_ischema_names
index d8a9c52998e5bee12ef046f8497cee4cb1dd6f02..80d93e61cc43ec1980eff14b73ad820c3c91710c 100644 (file)
@@ -246,6 +246,11 @@ class Dialect(sql.AbstractDialect):
 
         return clauseelement.compile(dialect=self, parameters=parameters)
 
+    def get_disconnect_checker(self):
+        """Return a callable that determines if an SQLError is caused by a database disconnection."""
+
+        return lambda x: False
+
 
 class ExecutionContext(object):
     """A messenger object for a Dialect that corresponds to a single execution.
@@ -440,18 +445,30 @@ class Connection(Connectable):
         return self.__transaction is not None
 
     def _begin_impl(self):
-        self.__engine.logger.info("BEGIN")
-        self.__engine.dialect.do_begin(self.connection)
+        if self.__connection.is_valid:
+            self.__engine.logger.info("BEGIN")
+            try:
+                self.__engine.dialect.do_begin(self.connection)
+            except Exception, e:
+                raise exceptions.SQLError(None, None, e)
 
     def _rollback_impl(self):
-        self.__engine.logger.info("ROLLBACK")
-        self.__engine.dialect.do_rollback(self.connection)
-        self.__connection.close_open_cursors()
+        if self.__connection.is_valid:
+            self.__engine.logger.info("ROLLBACK")
+            try:
+                self.__engine.dialect.do_rollback(self.connection)
+            except Exception, e:
+                raise exceptions.SQLError(None, None, e)
+            self.__connection.close_open_cursors()
         self.__transaction = None
 
     def _commit_impl(self):
-        self.__engine.logger.info("COMMIT")
-        self.__engine.dialect.do_commit(self.connection)
+        if self.__connection.is_valid:
+            self.__engine.logger.info("COMMIT")
+            try:
+                self.__engine.dialect.do_commit(self.connection)
+            except Exception, e:
+                raise exceptions.SQLError(None, None, e)
         self.__transaction = None
 
     def _autocommit(self, statement):
@@ -560,7 +577,6 @@ class Connection(Connectable):
             context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
         except Exception, e:
             self._autorollback()
-            #self._rollback_impl()
             if self.__close_with_result:
                 self.close()
             raise exceptions.SQLError(context.statement, context.parameters, e)
@@ -570,7 +586,6 @@ class Connection(Connectable):
             context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
         except Exception, e:
             self._autorollback()
-            #self._rollback_impl()
             if self.__close_with_result:
                 self.close()
             raise exceptions.SQLError(context.statement, context.parameters, e)
index 1b760fca8b2f28d6ccc2174b8cee2bff5caeb654..2f3b4519976844f57d5ee1f2cf9d5adb90bf0c0b 100644 (file)
@@ -86,7 +86,7 @@ class DefaultEngineStrategy(EngineStrategy):
                 if tk in kwargs:
                     pool_args[k] = kwargs.pop(tk)
             pool_args['use_threadlocal'] = self.pool_threadlocal()
-            pool = poolclass(creator, **pool_args)
+            pool = poolclass(creator, disconnect_checker=dialect.get_disconnect_checker(), **pool_args)
         else:
             if isinstance(pool, poollib._DBProxy):
                 pool = pool.get_pool(*cargs, **cparams)
index 8d559aff52f92143c6d4a5d6ef3298dbebf323cb..0b1ac2630f31f9b56160c1596e4278371abde518 100644 (file)
@@ -125,7 +125,8 @@ class Pool(object):
     False, then no cursor processing occurs upon checkin.
     """
 
-    def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True, disallow_open_cursors=False):
+    def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True,
+                 disallow_open_cursors=False, disconnect_checker=None):
         self.logger = logging.instance_logger(self)
         self._threadconns = weakref.WeakValueDictionary()
         self._creator = creator
@@ -134,6 +135,10 @@ class Pool(object):
         self.auto_close_cursors = auto_close_cursors
         self.disallow_open_cursors = disallow_open_cursors
         self.echo = echo
+        if disconnect_checker:
+            self.disconnect_checker = disconnect_checker
+        else:
+            self.disconnect_checker = lambda x: False
     echo = logging.echo_property()
 
     def unique_connection(self):
@@ -183,8 +188,11 @@ class _ConnectionRecord(object):
         self.__pool.log("Closing connection %s" % repr(self.connection))
         self.connection.close()
 
-    def invalidate(self):
-        self.__pool.log("Invalidate connection %s" % repr(self.connection))
+    def invalidate(self, e=None):
+        if e is not None:
+            self.__pool.log("Invalidate connection %s (reason: %s:%s)" % (repr(self.connection), e.__class__.__name__, str(e)))
+        else:
+            self.__pool.log("Invalidate connection %s" % repr(self.connection))
         self.__close()
         self.connection = None
 
@@ -226,7 +234,7 @@ class _ConnectionFairy(object):
     def __init__(self, pool):
         self._threadfairy = _ThreadFairy(self)
         self._cursors = weakref.WeakKeyDictionary()
-        self.__pool = pool
+        self._pool = pool
         self.__counter = 0
         try:
             self._connection_record = pool.get()
@@ -235,15 +243,17 @@ class _ConnectionFairy(object):
             self.connection = None # helps with endless __getattr__ loops later on
             self._connection_record = None
             raise
-        if self.__pool.echo:
-            self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
+        if self._pool.echo:
+            self._pool.log("Connection %s checked out from pool" % repr(self.connection))
+    
+    _logger = property(lambda self: self._pool.logger)
     
-    _logger = property(lambda self: self.__pool.logger)
-         
-    def invalidate(self):
+    is_valid = property(lambda self:self.connection is not None)
+    
+    def invalidate(self, e=None):
         if self.connection is None:
             raise exceptions.InvalidRequestError("This connection is closed")
-        self._connection_record.invalidate()
+        self._connection_record.invalidate(e=e)
         self.connection = None
         self._cursors = None
         self._close()
@@ -253,7 +263,7 @@ class _ConnectionFairy(object):
             c = self.connection.cursor(*args, **kwargs)
             return _CursorFairy(self, c)
         except Exception, e:
-            self.invalidate()
+            self.invalidate(e=e)
             raise
 
     def __getattr__(self, key):
@@ -282,21 +292,21 @@ class _ConnectionFairy(object):
         if self._cursors is not None:
             # cursors should be closed before connection is returned to the pool.  some dbapis like
             # mysql have real issues if they are not.
-            if self.__pool.auto_close_cursors:
+            if self._pool.auto_close_cursors:
                 self.close_open_cursors()
-            elif self.__pool.disallow_open_cursors:
+            elif self._pool.disallow_open_cursors:
                 if len(self._cursors):
                     raise exceptions.InvalidRequestError("This connection still has %d open cursors" % len(self._cursors))
         if self.connection is not None:
             try:
                 self.connection.rollback()
-            except:
+            except Exception, e:
                 if self._connection_record is not None:
-                    self._connection_record.invalidate()
+                    self._connection_record.invalidate(e=e)
         if self._connection_record is not None:
-            if self.__pool.echo:
-                self.__pool.log("Connection %s being returned to pool" % repr(self.connection))
-            self.__pool.return_conn(self)
+            if self._pool.echo:
+                self._pool.log("Connection %s being returned to pool" % repr(self.connection))
+            self._pool.return_conn(self)
         self.connection = None
         self._connection_record = None
         self._threadfairy = None
@@ -305,11 +315,27 @@ class _ConnectionFairy(object):
 class _CursorFairy(object):
     def __init__(self, parent, cursor):
         self.__parent = parent
-        self.__parent._cursors[self]=True
+        self.__parent._cursors[self] = True
         self.cursor = cursor
 
-    def invalidate(self):
-        self.__parent.invalidate()
+    def execute(self, *args, **kwargs):
+        try:
+            self.cursor.execute(*args, **kwargs)
+        except Exception, e:
+            if self.__parent._pool.disconnect_checker(e):
+                self.invalidate(e=e)
+            raise
+
+    def executemany(self, *args, **kwargs):
+        try:
+            self.cursor.executemany(*args, **kwargs)
+        except Exception, e:
+            if self.__parent._pool.disconnect_checker(e):
+                self.invalidate(e=e)
+            raise
+
+    def invalidate(self, e=None):
+        self.__parent.invalidate(e=e)
     
     def close(self):
         if self in self.__parent._cursors: