]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed another occasional race condition which could occur
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Aug 2007 23:57:30 +0000 (23:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Aug 2007 23:57:30 +0000 (23:57 +0000)
when using pool with threadlocal setting

CHANGES
lib/sqlalchemy/pool.py
test/engine/pool.py

diff --git a/CHANGES b/CHANGES
index 9d75e2dfbdd5aae837ecf71fd8c8bca88ef79363..43b8213dbae1524c815fbf8fbcffa43a1662a1f6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -3,6 +3,9 @@
     - added a check for joining from A->B using join(), along two
       different m2m tables.  this raises an error in 0.3 but is 
       possible in 0.4 when aliases are used. [ticket:687]
+- engine
+    - fixed another occasional race condition which could occur
+      when using pool with threadlocal setting
 - mssql
     - added support for TIME columns (simulated using DATETIME) [ticket:679]
     - index names are now quoted when dropping from reflected tables [ticket:684]
index 8670464a05f8facfe0c77af12ed3255f1d1a94ce..bb6997e904e25fbeb417670d252da0f079140693 100644 (file)
@@ -162,13 +162,15 @@ class Pool(object):
             return _ConnectionFairy(self).checkout()
 
         try:
-            return self._threadconns[thread.get_ident()].connfairy().checkout()
+            return self._threadconns[thread.get_ident()].checkout()
         except KeyError:
-            agent = _ConnectionFairy(self).checkout()
-            self._threadconns[thread.get_ident()] = agent._threadfairy
-            return agent
+            agent = _ConnectionFairy(self)
+            self._threadconns[thread.get_ident()] = agent
+            return agent.checkout()
 
     def return_conn(self, agent):
+        if self._use_threadlocal and thread.get_ident() in self._threadconns:
+            del self._threadconns[thread.get_ident()]
         self.do_return_conn(agent._connection_record)
 
     def get(self):
@@ -230,17 +232,10 @@ class _ConnectionRecord(object):
             self.__pool.log("Error on connect(): %s" % (str(e)))
             raise
 
-class _ThreadFairy(object):
-    """Mark a thread identifier as owning a connection, for a thread local pool."""
-
-    def __init__(self, connfairy):
-        self.connfairy = weakref.ref(connfairy)
-
 class _ConnectionFairy(object):
     """Proxy a DBAPI connection object and provides return-on-dereference support."""
 
     def __init__(self, pool):
-        self._threadfairy = _ThreadFairy(self)
         self._cursors = weakref.WeakKeyDictionary()
         self._pool = pool
         self.__counter = 0
@@ -340,7 +335,6 @@ class _ConnectionFairy(object):
             self._pool.return_conn(self)
         self.connection = None
         self._connection_record = None
-        self._threadfairy = None
         self._cursors = None
 
 class _CursorFairy(object):
index 85e9d59fd8720dade0098d0c2d690d013d638507..023ab70b7ec08b5c022808e34a0b976fc85d7b85 100644 (file)
@@ -1,6 +1,6 @@
 import testbase
 from testbase import PersistTest
-import unittest, sys, os, time
+import unittest, sys, os, time, gc
 import threading, thread
 
 import sqlalchemy.pool as pool
@@ -300,7 +300,14 @@ class PoolTest(PersistTest):
         assert not con.closed
         c1.close()
         assert con.closed
-        
+    
+    def test_threadfairy(self):
+        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)
+        c1 = p.connect()
+        c1.close()
+        c2 = p.connect()
+        assert c2.connection is not None
+
     def testthreadlocal_del(self):
         self._do_testthreadlocal(useclose=False)
 
@@ -328,7 +335,7 @@ class PoolTest(PersistTest):
                 c2.close()
             else:
                 c2 = None
-        
+
             if useclose:
                 c1 = p.connect()
                 c2 = p.connect()
@@ -340,7 +347,7 @@ class PoolTest(PersistTest):
 
             c1 = c2 = c3 = None
             
-            # extra tests with QueuePool to insure connections get __del__()ed when dereferenced
+            # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced
             if isinstance(p, pool.QueuePool):
                 self.assert_(p.checkedout() == 0)
                 c1 = p.connect()