]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- changed "invalidate" semantics with pooled connection; will
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Aug 2006 21:32:11 +0000 (21:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Aug 2006 21:32:11 +0000 (21:32 +0000)
instruct the underlying connection record to reconnect the next
time its called.  "invalidate" will also automatically be called
if any error is thrown in the underlying call to connection.cursor().
this will hopefully allow the connection pool to reconnect to a
database that had been stopped and started without restarting
the connecting application [ticket:121]

CHANGES
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/pool.py
test/engine/pool.py

diff --git a/CHANGES b/CHANGES
index 3be73d3d52f801d894623a79a863a9f962f759ae..085a5979ede82420214277b408d7f5979fbdbebd 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -7,6 +7,13 @@ function to 'create_engine'.
 defaults to 3600 seconds; connections after this age will be closed and
 replaced with a new one, to handle db's that automatically close 
 stale connections [ticket:274]
+- changed "invalidate" semantics with pooled connection; will 
+instruct the underlying connection record to reconnect the next 
+time its called.  "invalidate" will also automatically be called
+if any error is thrown in the underlying call to connection.cursor().
+this will hopefully allow the connection pool to reconnect to a
+database that had been stopped and started without restarting
+the connecting application [ticket:121] 
 - eesh !  the tutorial doctest was broken for quite some time.
 - add_property() method on mapper does a "compile all mappers"
 step in case the given property references a non-compiled mapper
index 6a96fbcfbef11c5614d157c1743b173c52c7dbda..ce6cc7d82ef365011d76d42dd2a119492992f383 100644 (file)
@@ -342,7 +342,8 @@ class Connection(Connectable):
         try:
             self.__engine.dialect.do_executemany(c, statement, parameters, context=context)
         except Exception, e:
-            self._rollback_impl()
+            self._autorollback()
+            #self._rollback_impl()
             if self.__close_with_result:
                 self.close()
             raise exceptions.SQLError(statement, parameters, e)
index 211f96070d88a6937d3aad3cd99b2a4a27c5c519..577405b0f8efa5a4db370b9cf73abdda059b92c3 100644 (file)
@@ -104,9 +104,6 @@ class Pool(object):
     def return_conn(self, agent):
         self.do_return_conn(agent._connection_record)
 
-    def return_invalid(self, agent):
-        self.do_return_invalid(agent._connection_record)
-        
     def get(self):
         return self.do_get()
     
@@ -116,9 +113,6 @@ class Pool(object):
     def do_return_conn(self, conn):
         raise NotImplementedError()
         
-    def do_return_invalid(self, conn):
-        raise NotImplementedError()
-        
     def status(self):
         raise NotImplementedError()
 
@@ -133,26 +127,35 @@ class Pool(object):
 
 class _ConnectionRecord(object):
     def __init__(self, pool):
-        self.pool = pool
+        self.__pool = pool
         self.connection = self.__connect()
     def close(self):
         self.connection.close()
+    def invalidate(self):
+        self.__pool.log("Invalidate connection %s" % repr(self.connection))
+        self.__close()
+        self.connection = None
     def get_connection(self):
-        if self.pool._recycle > -1 and time.time() - self.starttime > self.pool._recycle:
-            self.pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
-            try:
-                self.connection.close()
-            except Exception, e:
-                self.pool.log("Connection %s threw an error: %s" % (repr(self.connection), str(e)))
+        if self.connection is None:
+            self.connection = self.__connect()
+        elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle):
+            self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
+            self.__close()
             self.connection = self.__connect()
         return self.connection
+    def __close(self):
+        try:
+            self.__pool.log("Closing connection %s" % (repr(self.connection)))
+            self.connection.close()
+        except Exception, e:
+            self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e)))
     def __connect(self):
         try:
             self.starttime = time.time()
-            return self.pool._creator()
-        except:
+            return self.__pool._creator()
+        except Exception, e:
+            self.__pool.log("Error on connect(): %s" % (str(e)))
             raise
-            # TODO: reconnect support here ?
 
 class _ThreadFairy(object):
     """marks a thread identifier as owning a connection, for a thread local pool."""
@@ -171,19 +174,19 @@ class _ConnectionFairy(object):
         except:
             self.connection = None # helps with endless __getattr__ loops later on
             self._connection_record = None
-            self.__pool.return_invalid(self)
             raise
         if self.__pool.echo:
             self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
     def invalidate(self):
-        if self.__pool.echo:
-            self.__pool.log("Invalidate connection %s" % repr(self.connection))
+        self._connection_record.invalidate()
         self.connection = None
-        self._connection_record = None
-        self._threadfairy = None
-        self.__pool.return_invalid(self)
+        self._close()
     def cursor(self, *args, **kwargs):
-        return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+        try:
+            return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+        except Exception, e:
+            self.invalidate()
+            raise
     def __getattr__(self, key):
         return getattr(self.connection, key)
     def checkout(self):
@@ -199,16 +202,15 @@ class _ConnectionFairy(object):
         self._close()
     def _close(self):
         if self.connection is not None:
-            if self.__pool.echo:
-                self.__pool.log("Connection %s being returned to pool" % repr(self.connection))
             try:
                 self.connection.rollback()
             except:
                 # damn mysql -- (todo look for NotSupportedError)
                 pass
+        if self._connection_record is not None:
+            if self.__pool.echo:
+                self.__pool.log("Connection %s being returned to pool" % repr(self.connection))
             self.__pool.return_conn(self)
-        self.__pool = None
-        self.connection = None
         self._connection_record = None
         self._threadfairy = None
             
@@ -257,12 +259,6 @@ class SingletonThreadPool(Pool):
     def do_return_conn(self, conn):
         pass
         
-    def do_return_invalid(self, conn):
-        try:
-            del self._conns[thread.get_ident()]
-        except KeyError:
-            pass
-            
     def do_get(self):
         try:
             return self._conns[thread.get_ident()]
@@ -288,10 +284,6 @@ class QueuePool(Pool):
         except Queue.Full:
             self._overflow -= 1
 
-    def do_return_invalid(self, conn):
-        if conn is not None:
-            self._overflow -= 1
-        
     def do_get(self):
         try:
             return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow, self._timeout)
index cfc8f5684a7c43835d56591e0bededd2d2a646a1..c6b96813091b53bcd55c335f19afa30fb12a31b7 100644 (file)
@@ -6,7 +6,11 @@ import sqlalchemy.pool as pool
 import sqlalchemy.exceptions as exceptions
 
 class MockDBAPI(object):
+    def __init__(self):
+        self.throw_error = False
     def connect(self, argument):
+        if self.throw_error:
+            raise Exception("couldnt connect !")
         return MockConnection()
 class MockConnection(object):
     def close(self):
@@ -154,6 +158,37 @@ class PoolTest(PersistTest):
         time.sleep(3)
         c3= p.connect()
         assert id(c3.connection) != c_id
+    
+    def test_invalidate(self):
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close(); c1=None
+
+        c1 = p.connect()
+        assert id(c1.connection) == c_id
+        c1.invalidate()
+        c1 = None
+        
+        c1 = p.connect()
+        assert id(c1.connection) != c_id
+
+    def test_reconnect(self):
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close(); c1=None
+
+        c1 = p.connect()
+        assert id(c1.connection) == c_id
+        dbapi.raise_error = True
+        c1.invalidate()
+        c1 = None
+
+        c1 = p.connect()
+        assert id(c1.connection) != c_id
         
     def testthreadlocal_del(self):
         self._do_testthreadlocal(useclose=False)