]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- auto-reconnect support improved; a Connection can now automatically
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Dec 2007 19:51:46 +0000 (19:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Dec 2007 19:51:46 +0000 (19:51 +0000)
reconnect after its underlying connection is invalidated, without
needing to connect() again from the engine.  This allows an ORM session
bound to a single Connection to not need a reconnect.
Open transactions on the Connection must be rolled back after an invalidation
of the underlying connection else an error is raised.  Also fixed
bug where disconnect detect was not being called for cursor(), rollback(),
or commit().

CHANGES
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/exceptions.py
lib/sqlalchemy/pool.py
test/engine/reconnect.py
test/sql/testtypes.py
test/testlib/engines.py

diff --git a/CHANGES b/CHANGES
index 1077b4c30d83d2a70dca5b93bc6657f2e16e1360..4178cb33d6f28e416f099a0c5011280cd8e6f7b7 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -13,6 +13,15 @@ CHANGES
       knows that its return type is a Date. We only have a few functions
       represented so far but will continue to add to the system [ticket:615]
 
+    - auto-reconnect support improved; a Connection can now automatically
+      reconnect after its underlying connection is invalidated, without
+      needing to connect() again from the engine.  This allows an ORM session
+      bound to a single Connection to not need a reconnect.
+      Open transactions on the Connection must be rolled back after an invalidation 
+      of the underlying connection else an error is raised.  Also fixed
+      bug where disconnect detect was not being called for cursor(), rollback(),
+      or commit().
+      
     - added new flag to String and create_engine(),
       assert_unicode=(True|False|'warn'|None). Defaults to `False` or `None` on
       create_engine() and String, `'warn'` on the Unicode type. When `True`,
index 122c24bff51499fdbefc12fee55de9245342b94c..a738887f47a1e2f36f4a41150d8bf3e6f1c6f291 100644 (file)
@@ -1527,8 +1527,12 @@ class MySQLDialect(default.DefaultDialect):
         connection.ping()
 
     def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.OperationalError) and \
-               e.args[0] in (2006, 2013, 2014, 2045, 2055)
+        if isinstance(e, self.dbapi.OperationalError):
+            return e.args[0] in (2006, 2013, 2014, 2045, 2055)
+        elif isinstance(e, self.dbapi.InterfaceError):  # if underlying connection is closed, this is the error you get
+            return "(0, '')" in str(e)
+        else:
+            return False
 
     def get_default_schema_name(self, connection):
         try:
index 16dd9427c0b9bc9784f871d834c2667a9980792c..e028b1c539b65417dddd5b9cfb78ea2e2900a758 100644 (file)
@@ -237,6 +237,9 @@ class SQLiteDialect(default.DefaultDialect):
     def oid_column_name(self, column):
         return "oid"
     
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
+
     def table_names(self, connection, schema):
         s = "SELECT name FROM sqlite_master WHERE type='table'"
         return [row[0] for row in connection.execute(s)]
index ff2245f39e69aefb5544aacf7e6ae3d06e41d38a..57c76e7ed3386d5a97312d8fa159aeaaafa4553a 100644 (file)
@@ -542,6 +542,7 @@ class Connection(Connectable):
         self.__close_with_result = close_with_result
         self.__savepoint_seq = 0
         self.__branch = _branch
+        self.__invalid = False
 
     def _branch(self):
         """Return a new Connection which references this Connection's 
@@ -559,12 +560,30 @@ class Connection(Connectable):
         return self.engine.dialect
     dialect = property(dialect)
     
+    def closed(self):
+        """return True if this connection is closed."""
+        
+        return not self.__invalid and '_Connection__connection' not in self.__dict__
+    closed = property(closed)
+    
+    def invalidated(self):
+        """return True if this connection was invalidated."""
+        
+        return self.__invalid
+    invalidated = property(invalidated)
+    
     def connection(self):
         "The underlying DB-API connection managed by this Connection."
 
         try:
             return self.__connection
         except AttributeError:
+            if self.__invalid:
+                if self.__transaction is not None:
+                    raise exceptions.InvalidRequestError("Can't reconnect until invalid transaction is rolled back")
+                self.__connection = self.engine.raw_connection()
+                self.__invalid = False
+                return self.__connection
             raise exceptions.InvalidRequestError("This Connection is closed")
     connection = property(connection)
     
@@ -603,16 +622,28 @@ class Connection(Connectable):
 
         return self
 
-    def invalidate(self):
-        """Invalidate and close the Connection.
+    def invalidate(self, exception=None):
+        """Invalidate the underlying DBAPI connection associated with this Connection.
 
         The underlying DB-API connection is literally closed (if
         possible), and is discarded.  Its source connection pool will
         typically lazilly create a new connection to replace it.
+        
+        Upon the next usage, this Connection will attempt to reconnect
+        to the pool with a new connection.
+
+        Transactions in progress remain in an "opened" state (even though
+        the actual transaction is gone); these must be explicitly 
+        rolled back before a reconnect on this Connection can proceed.  This
+        is to prevent applications from accidentally continuing their transactional
+        operations in a non-transactional state.
+        
         """
 
-        self.__connection.invalidate()
-        self.__connection = None
+        if self.__connection.is_valid:
+            self.__connection.invalidate(exception)
+        del self.__connection
+        self.__invalid = True
 
     def detach(self):
         """Detach the underlying DB-API connection from its connection pool.
@@ -699,29 +730,31 @@ class Connection(Connectable):
         if self.engine._should_log_info:
             self.engine.logger.info("BEGIN")
         try:
-            self.engine.dialect.do_begin(self.__connection)
+            self.engine.dialect.do_begin(self.connection)
         except Exception, e:
-            raise exceptions.DBAPIError.instance(None, None, e)
+            raise self.__handle_dbapi_exception(e, None, None, None)
 
     def _rollback_impl(self):
-        if self.__connection.is_valid:
+        if not self.closed and not self.invalidated and self.__connection.is_valid:
             if self.engine._should_log_info:
                 self.engine.logger.info("ROLLBACK")
             try:
-                self.engine.dialect.do_rollback(self.__connection)
+                self.engine.dialect.do_rollback(self.connection)
+                self.__transaction = None
             except Exception, e:
-                raise exceptions.DBAPIError.instance(None, None, e)
-        self.__transaction = None
+                raise self.__handle_dbapi_exception(e, None, None, None)
+        else:
+            self.__transaction = None
 
     def _commit_impl(self):
         if self.engine._should_log_info:
             self.engine.logger.info("COMMIT")
         try:
-            self.engine.dialect.do_commit(self.__connection)
+            self.engine.dialect.do_commit(self.connection)
+            self.__transaction = None
         except Exception, e:
-            raise exceptions.DBAPIError.instance(None, None, e)
-        self.__transaction = None
-
+            raise self.__handle_dbapi_exception(e, None, None, None)
+        
     def _savepoint_impl(self, name=None):
         if name is None:
             self.__savepoint_seq += 1
@@ -789,6 +822,7 @@ class Connection(Connectable):
         if not self.__branch:
             self.__connection.close()
         self.__connection = None
+        self.__invalid = False
         del self.__connection
 
     def scalar(self, object, *multiparams, **params):
@@ -872,15 +906,32 @@ class Connection(Connectable):
         self._autocommit(context)
         return context.result()
 
-    def __create_execution_context(self, **kwargs):
-        return self.engine.dialect.create_execution_context(connection=self, **kwargs)
-
     def __execute_raw(self, context):
         if context.executemany:
             self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
         else:
             self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
+
+    def __handle_dbapi_exception(self, e, statement, parameters, cursor):
+        if not isinstance(e, self.dialect.dbapi.Error):
+            return e
+        is_disconnect = self.dialect.is_disconnect(e)
+        if is_disconnect:
+            self.invalidate(e)
+            self.engine.dispose()
+        if cursor:
+            cursor.close()
+        self._autorollback()
+        if self.__close_with_result:
+            self.close()
+        return exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
         
+    def __create_execution_context(self, **kwargs):
+        try:
+            return self.engine.dialect.create_execution_context(connection=self, **kwargs)
+        except Exception, e:
+            raise self.__handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None)
+
     def _cursor_execute(self, cursor, statement, parameters, context=None):
         if self.engine._should_log_info:
             self.engine.logger.info(statement)
@@ -888,14 +939,7 @@ class Connection(Connectable):
         try:
             self.dialect.do_execute(cursor, statement, parameters, context=context)
         except Exception, e:
-            if self.dialect.is_disconnect(e):
-                self.__connection.invalidate(e=e)
-                self.engine.dispose()
-            cursor.close()
-            self._autorollback()
-            if self.__close_with_result:
-                self.close()
-            raise exceptions.DBAPIError.instance(statement, parameters, e)
+            raise self.__handle_dbapi_exception(e, statement, parameters, cursor)
 
     def _cursor_executemany(self, cursor, statement, parameters, context=None):
         if self.engine._should_log_info:
@@ -904,14 +948,7 @@ class Connection(Connectable):
         try:
             self.dialect.do_executemany(cursor, statement, parameters, context=context)
         except Exception, e:
-            if self.dialect.is_disconnect(e):
-                self.__connection.invalidate(e=e)
-                self.engine.dispose()
-            cursor.close()
-            self._autorollback()
-            if self.__close_with_result:
-                self.close()
-            raise exceptions.DBAPIError.instance(statement, parameters, e)
+            raise self.__handle_dbapi_exception(e, statement, parameters, cursor)
 
     # poor man's multimethod/generic function thingy
     executors = {
@@ -990,8 +1027,8 @@ class Transaction(object):
     def commit(self):
         if not self._parent._is_active:
             raise exceptions.InvalidRequestError("This transaction is inactive")
-        self._is_active = False
         self._do_commit()
+        self._is_active = False
 
     def _do_commit(self):
         pass
index 8338bc554f7e2ffe69f8d8cb5d9486abd0ff0aeb..530ce3e3a2bdf56c5c9a4edb7b6953d39f194b35 100644 (file)
@@ -62,7 +62,11 @@ class NoSuchColumnError(KeyError, SQLAlchemyError):
 
 
 class DisconnectionError(SQLAlchemyError):
-    """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection."""
+    """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection.
+    
+    This error is consumed internally by a connection pool.  It can be raised by
+    a ``PoolListener`` so that the host pool forces a disconnect.
+    """
 
 
 class DBAPIError(SQLAlchemyError):
@@ -84,7 +88,7 @@ class DBAPIError(SQLAlchemyError):
     Its type and properties are DB-API implementation specific.  
     """
 
-    def instance(cls, statement, params, orig):
+    def instance(cls, statement, params, orig, connection_invalidated=False):
         # Don't ever wrap these, just return them directly as if
         # DBAPIError didn't exist.
         if isinstance(orig, (KeyboardInterrupt, SystemExit)):
@@ -95,10 +99,10 @@ class DBAPIError(SQLAlchemyError):
             if name in glob and issubclass(glob[name], DBAPIError):
                 cls = glob[name]
             
-        return cls(statement, params, orig)
+        return cls(statement, params, orig, connection_invalidated)
     instance = classmethod(instance)
     
-    def __init__(self, statement, params, orig):
+    def __init__(self, statement, params, orig, connection_invalidated=False):
         try:
             text = str(orig)
         except (KeyboardInterrupt, SystemExit):
@@ -110,6 +114,7 @@ class DBAPIError(SQLAlchemyError):
         self.statement = statement
         self.params = params
         self.orig = orig
+        self.connection_invalidated = connection_invalidated
 
     def __str__(self):
         return ' '.join([SQLAlchemyError.__str__(self),
index ff38f21b874569e084f3390035fa22d44a1ace04..7a5c2ef0edc18fbeb843a26f3afd389a15279ad1 100644 (file)
@@ -208,7 +208,12 @@ class _ConnectionRecord(object):
         if self.connection is not None:
             if self.__pool._should_log_info:
                 self.__pool.log("Closing connection %s" % repr(self.connection))
-            self.connection.close()
+            try:
+                self.connection.close()
+            except:
+                if self.__pool._should_log_info:
+                    self.__pool.log("Exception closing connection %s" % repr(self.connection))
+                
 
     def invalidate(self, e=None):
         if self.__pool._should_log_info:
index 7c213695f2643385c9946a8480d144af93731ddd..f9d692b3d71800b894e65a7c299dbb028e553374 100644 (file)
@@ -13,37 +13,39 @@ class MockDBAPI(object):
         self.connections = weakref.WeakKeyDictionary()
     def connect(self, *args, **kwargs):
         return MockConnection(self)
-        
+    def shutdown(self):
+        for c in self.connections:
+            c.explode[0] = True
+    Error = MockDisconnect
+            
 class MockConnection(object):
     def __init__(self, dbapi):
-        self.explode = False
         dbapi.connections[self] = True
+        self.explode = [False]
     def rollback(self):
         pass
     def commit(self):
         pass
     def cursor(self):
-        return MockCursor(explode=self.explode)
+        return MockCursor(self)
     def close(self):
         pass
             
 class MockCursor(object):
-    def __init__(self, explode):
-        self.explode = explode
+    def __init__(self, parent):
+        self.explode = parent.explode
         self.description = None
     def execute(self, *args, **kwargs):
-        if self.explode:
+        if self.explode[0]:
             raise MockDisconnect("Lost the DB connection")
         else:
             return
     def close(self):
         pass
         
-class ReconnectTest(PersistTest):
-    def test_reconnect(self):
-        """test that an 'is_disconnect' condition will invalidate the connection, and additionally
-        dispose the previous connection pool and recreate."""
-        
+class MockReconnectTest(PersistTest):
+    def setUp(self):
+        global db, dbapi
         dbapi = MockDBAPI()
         
         # create engine using our current dburi
@@ -52,6 +54,11 @@ class ReconnectTest(PersistTest):
         # monkeypatch disconnect checker
         db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
         
+    def test_reconnect(self):
+        """test that an 'is_disconnect' condition will invalidate the connection, and additionally
+        dispose the previous connection pool and recreate."""
+        
+        
         pid = id(db.pool)
         
         # make a connection
@@ -68,17 +75,17 @@ class ReconnectTest(PersistTest):
         assert len(dbapi.connections) == 2
 
         # set it to fail
-        conn.connection.connection.explode = True
-        
+        dbapi.shutdown()
+
         try:
-            # execute should fail
             conn.execute("SELECT 1")
             assert False
-        except exceptions.SQLAlchemyError, e:
+        except exceptions.DBAPIError:
             pass
         
         # assert was invalidated
-        assert conn.connection.connection is None
+        assert not conn.closed
+        assert conn.invalidated
         
         # close shouldnt break
         conn.close()
@@ -92,6 +99,182 @@ class ReconnectTest(PersistTest):
         conn.execute("SELECT 1")
         conn.close()
         assert len(dbapi.connections) == 1
+    
+    def test_invalidate_trans(self):
+        conn = db.connect()
+        trans = conn.begin()
+        dbapi.shutdown()
+
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.DBAPIError:
+            pass
+            
+        # assert was invalidated
+        assert len(dbapi.connections) == 0
+        assert not conn.closed
+        assert conn.invalidated
+        assert trans.is_active
+
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+        assert trans.is_active
+
+        try:
+            trans.commit()
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+        assert trans.is_active
+
+        trans.rollback()
+        assert not trans.is_active
+        
+        conn.execute("SELECT 1")
+        assert not conn.invalidated
+        
+        assert len(dbapi.connections) == 1
+        
+    def test_conn_reusable(self):
+        conn = db.connect()
+        
+        conn.execute("SELECT 1")
+
+        assert len(dbapi.connections) == 1
+        
+        dbapi.shutdown()
+
+        # raises error
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.DBAPIError:
+            pass
+
+        assert not conn.closed
+        assert conn.invalidated
+
+        # ensure all connections closed (pool was recycled)
+        assert len(dbapi.connections) == 0
+            
+        # test reconnects
+        conn.execute("SELECT 1")
+        assert not conn.invalidated
+        assert len(dbapi.connections) == 1
+        
+
+class RealReconnectTest(PersistTest):
+    def setUp(self):
+        global engine
+        engine = engines.reconnecting_engine()
+    
+    def tearDown(self):
+        engine.dispose()
+        
+    def test_reconnect(self):
+        conn = engine.connect()
+
+        self.assertEquals(conn.execute("SELECT 1").scalar(), 1) 
+        assert not conn.closed
+
+        engine.test_shutdown()
+
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.DBAPIError, e:
+            if not e.connection_invalidated:
+                raise
+
+        assert not conn.closed
+        assert conn.invalidated
+
+        assert conn.invalidated
+        self.assertEquals(conn.execute("SELECT 1").scalar(), 1) 
+        assert not conn.invalidated
+
+        # one more time
+        engine.test_shutdown()
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.DBAPIError, e:
+            if not e.connection_invalidated:
+                raise
+        assert conn.invalidated
+        self.assertEquals(conn.execute("SELECT 1").scalar(), 1) 
+        assert not conn.invalidated
+
+        conn.close()
+    
+    def test_close(self):
+        conn = engine.connect()
+        self.assertEquals(conn.execute("SELECT 1").scalar(), 1) 
+        assert not conn.closed
+
+        engine.test_shutdown()
+
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.DBAPIError, e:
+            if not e.connection_invalidated:
+                raise
+
+        conn.close()
+        conn = engine.connect()
+        self.assertEquals(conn.execute("SELECT 1").scalar(), 1) 
+        
+    def test_with_transaction(self):
+        conn = engine.connect()
+
+        trans = conn.begin()
+
+        self.assertEquals(conn.execute("SELECT 1").scalar(), 1) 
+        assert not conn.closed
+
+        engine.test_shutdown()
+
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.DBAPIError, e:
+            if not e.connection_invalidated:
+                raise
+
+        assert not conn.closed
+        assert conn.invalidated
+        assert trans.is_active
+
+        try:
+            conn.execute("SELECT 1")
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+        assert trans.is_active
+
+        try:
+            trans.commit()
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+        assert trans.is_active
+        
+        trans.rollback()
+        assert not trans.is_active
+
+        assert conn.invalidated
+        self.assertEquals(conn.execute("SELECT 1").scalar(), 1) 
+        assert not conn.invalidated
+        
         
 if __name__ == '__main__':
     testbase.main()
index 69696ec642f2b8f3559718bdd1f60841fd432887..59acaff2fff10ce300b42f55ccdd1e4e4f0197ba 100644 (file)
@@ -342,7 +342,7 @@ class UnicodeTest(AssertMixin):
                 assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'"
         finally:
             unicode_engine.dispose()
-        
+
     @testing.unsupported('oracle')
     def testblanks(self):
         unicode_table.insert().execute(unicode_varchar=u'')
index 56507618c2628da7e073b8e48ec2052559d65616..b576a15364a08ef407665bbccb310e67f061bcd6 100644 (file)
@@ -68,7 +68,31 @@ def close_open_connections(fn):
     decorated.__name__ = fn.__name__
     return decorated
 
-
+class ReconnectFixture(object):
+    def __init__(self, dbapi):
+        self.dbapi = dbapi
+        self.connections = []
+    
+    def __getattr__(self, key):
+        return getattr(self.dbapi, key)
+
+    def connect(self, *args, **kwargs):
+        conn = self.dbapi.connect(*args, **kwargs)
+        self.connections.append(conn)
+        return conn
+
+    def shutdown(self):
+        for c in list(self.connections):
+            c.close()
+        self.connections = []
+        
+def reconnecting_engine(url=None, options=None):
+    url = url or config.db_url
+    dbapi = config.db.dialect.dbapi
+    engine = testing_engine(url, {'module':ReconnectFixture(dbapi)})
+    engine.test_shutdown = engine.dialect.dbapi.shutdown
+    return engine
+    
 def testing_engine(url=None, options=None):
     """Produce an engine configured by --options with optional overrides."""
     
@@ -109,3 +133,5 @@ def utf8_engine(url=None, options=None):
             url = str(url)
 
     return testing_engine(url, options)
+
+