]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
reorganized SingletonThreadPool to return distinct connections in the same thread...
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Mar 2006 19:09:53 +0000 (19:09 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Mar 2006 19:09:53 +0000 (19:09 +0000)
lib/sqlalchemy/pool.py
test/pool.py

index ad65f7462a6b6a25198c587843ffa1eedf61076d..5ab8ba57888cf19b17a92129aa6e9143ba0853e2 100644 (file)
@@ -120,7 +120,7 @@ class Pool(object):
         raise NotImplementedError()
 
     def log(self, msg):
-        self.logger.write(msg)
+        self._logger.write(msg)
 
 class ConnectionFairy(object):
     def __init__(self, pool, connection=None):
@@ -155,19 +155,17 @@ class SingletonThreadPool(Pool):
     """Maintains one connection per each thread, never moving to another thread.  this is
     used for SQLite and other databases with a similar restriction."""
     def __init__(self, creator, **params):
-        params['use_threadlocal'] = False
         Pool.__init__(self, **params)
         self._conns = {}
         self._creator = creator
 
     def status(self):
-        return "SingletonThreadPool size: %d" % len(self._conns)
-
-    def unique_connection(self):
-        return ConnectionFairy(self, self._creator())
+        return "SingletonThreadPool thread:%d size: %d" % (thread.get_ident(), len(self._conns))
 
     def do_return_conn(self, conn):
-        pass
+        if self._conns.get(thread.get_ident(), None) is None:
+            self._conns[thread.get_ident()] = conn
+
     def do_return_invalid(self):
         try:
             del self._conns[thread.get_ident()]
@@ -176,9 +174,13 @@ class SingletonThreadPool(Pool):
             
     def do_get(self):
         try:
-            return self._conns[thread.get_ident()]
+            c = self._conns[thread.get_ident()]
+            if c is None:
+                return self._creator()
         except KeyError:
-            return self._conns.setdefault(thread.get_ident(), self._creator())
+            c = self._creator()
+        self._conns[thread.get_ident()] = None
+        return c
     
 class QueuePool(Pool):
     """uses Queue.Queue to maintain a fixed-size list of connections."""
index 811bda51aed265389199ca8298a6a4bcdd0ce71e..2737a33b1cba7130657f0aba1793812da225962d 100644 (file)
@@ -69,7 +69,22 @@ class PoolTest(PersistTest):
         self.assert_(status(p) == (3, 1, 0, 2))
         c2 = None
         self.assert_(status(p) == (3, 2, 0, 1))
-        
+    
+    def testthreadlocal(self):
+        for p in (
+            pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True, echo = False),
+            pool.SingletonThreadPool(creator = lambda: sqlite.connect('foo.db'), use_threadlocal = True)
+        ):    
+            c1 = p.connect()
+            c2 = p.connect()
+            self.assert_(c1 is c2)
+            c3 = p.unique_connection()
+            self.assert_(c3 is not c1)
+            c2 = None
+            c2 = p.connect()
+            self.assert_(c1 is c2)
+            self.assert_(c3 is not c1)
+
     def tearDown(self):
        pool.clear_managers()
        for file in ('foo.db', 'bar.db'):