]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Engine and TLEngine assume "threadlocal" behavior on Pool; both use connect()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Aug 2007 19:07:07 +0000 (19:07 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Aug 2007 19:07:07 +0000 (19:07 +0000)
for contextual connection, unique_connection() for non-contextual.
- Pool use_threadlocal defaults to True, can be set to false at create_engine()
level with pool_threadlocal=False
- made all logger statements in pool conditional based on a flag calcualted once.
- chagned WeakValueDictionary() used for "threadlocal" pool to be a regular dict
referencing weakref objects.  WVD had a lot of overhead, apparently.  *CAUTION* -
im pretty confident about this change, as the threadlocal dict gets explicitly managed
anyway, tests pass with PG etc., but keep a close eye on this one regardless.

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/pool.py
test/engine/pool.py

index 298264362cb9cf8c9a96ab2bed532e614dc1db58..2e75d358c4ae6f414dad301dc0522aebf1327ef6 100644 (file)
@@ -844,7 +844,7 @@ class Connection(Connectable):
         return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
 
     def __execute_raw(self, context):
-        if logging.is_info_enabled(self.__engine.logger):
+        if self.__engine._should_log:
             self.__engine.logger.info(context.statement)
             self.__engine.logger.info(repr(context.parameters))
         if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
@@ -1023,6 +1023,7 @@ class Engine(Connectable):
         self._dialect=dialect
         self.echo = echo
         self.logger = logging.instance_logger(self)
+        self._should_log = logging.is_info_enabled(self.logger)
 
     name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'], doc="String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``.")
     engine = property(lambda s:s)
@@ -1136,7 +1137,7 @@ class Engine(Connectable):
         This Connection is meant to be used by the various "auto-connecting" operations.
         """
 
-        return Connection(self, close_with_result=close_with_result, **kwargs)
+        return Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs)
     
     def table_names(self, schema=None, connection=None):
         """Return a list of all table names available in the database.
@@ -1183,7 +1184,7 @@ class Engine(Connectable):
     def raw_connection(self):
         """Return a DB-API connection."""
 
-        return self.pool.connect()
+        return self.pool.unique_connection()
 
     def log(self, msg):
         """Log a message using this SQLEngine's logger stream."""
@@ -1223,7 +1224,7 @@ class ResultProxy(object):
         self.dialect = context.dialect
         self.closed = False
         self.cursor = context.cursor
-        self.__echo = logging.is_debug_enabled(context.engine.logger)
+        self.__echo = context.engine._should_log
         self._process_row = self._row_processor()
         if context.is_select():
             self._init_metadata()
index a0a6445fd05c51bef50e2ecfd1720399f0b1ee66..259ba55c56a65c953ca2f2a0297dfca2089f1cb1 100644 (file)
@@ -85,12 +85,13 @@ class DefaultEngineStrategy(EngineStrategy):
             # the arguments
             translate = {'echo': 'echo_pool',
                          'timeout': 'pool_timeout',
-                         'recycle': 'pool_recycle'}
+                         'recycle': 'pool_recycle',
+                         'use_threadlocal':'pool_threadlocal'}
             for k in util.get_cls_kwargs(poolclass):
                 tk = translate.get(k, k)
                 if tk in kwargs:
                     pool_args[k] = kwargs.pop(tk)
-            pool_args['use_threadlocal'] = self.pool_threadlocal()
+            pool_args.setdefault('use_threadlocal', self.pool_threadlocal())
             pool = poolclass(creator, **pool_args)
         else:
             if isinstance(pool, poollib._DBProxy):
index e9843ea2e318aeb377b46003be2c7e4c71fdfada..dc6b6007fb2344532e38d32dfab0da92b0b23436 100644 (file)
@@ -16,7 +16,7 @@ class TLSession(object):
         try:
             return self.__transaction._increment_connect()
         except AttributeError:
-            return TLConnection(self, close_with_result=close_with_result)
+            return TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result)
 
     def reset(self):
         try:
@@ -82,9 +82,8 @@ class TLSession(object):
 
 
 class TLConnection(base.Connection):
-    def __init__(self, session, close_with_result):
-        base.Connection.__init__(self, session.engine,
-                                 close_with_result=close_with_result)
+    def __init__(self, session, connection, close_with_result):
+        base.Connection.__init__(self, session.engine, connection, close_with_result=close_with_result)
         self.__session = session
         self.__opencount = 1
 
@@ -160,20 +159,6 @@ class TLEngine(base.Engine):
         super(TLEngine, self).__init__(*args, **kwargs)
         self.context = util.ThreadLocal()
 
-    def raw_connection(self):
-        """Return a DB-API connection."""
-
-        return self.pool.connect()
-
-    def connect(self, **kwargs):
-        """Return a Connection that is not thread-locally scoped.
-
-        This is the equivalent to calling ``connect()`` on a
-        base.Engine.
-        """
-
-        return base.Connection(self, self.pool.unique_connection())
-
     def _session(self):
         if not hasattr(self.context, 'session'):
             self.context.session = TLSession(self)
index 55fe85ddc4ac33f277eaf408b6afac251874c17e..b44c7ee416056a2eccca1d47ab1d30b5c83b0340 100644 (file)
@@ -112,10 +112,10 @@ class Pool(object):
       the pool.
     """
 
-    def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False,
+    def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=True,
                  listeners=None):
         self.logger = logging.instance_logger(self)
-        self._threadconns = weakref.WeakValueDictionary()
+        self._threadconns = {}
         self._creator = creator
         self._recycle = recycle
         self._use_threadlocal = use_threadlocal
@@ -124,6 +124,8 @@ class Pool(object):
         self._on_connect = []
         self._on_checkout = []
         self._on_checkin = []
+        self._should_log = logging.is_info_enabled(self.logger)
+        
         if listeners:
             for l in listeners:
                 self.add_listener(l)
@@ -153,10 +155,10 @@ class Pool(object):
             return _ConnectionFairy(self).checkout()
 
         try:
-            return self._threadconns[thread.get_ident()].checkout()
+            return self._threadconns[thread.get_ident()]().checkout()
         except KeyError:
             agent = _ConnectionFairy(self)
-            self._threadconns[thread.get_ident()] = agent
+            self._threadconns[thread.get_ident()] = weakref.ref(agent)
             return agent.checkout()
 
     def return_conn(self, agent):
@@ -201,14 +203,16 @@ class _ConnectionRecord(object):
 
     def close(self):
         if self.connection is not None:
-            self.__pool.log("Closing connection %s" % repr(self.connection))
+            if self.__pool._should_log:
+                self.__pool.log("Closing connection %s" % repr(self.connection))
             self.connection.close()
 
     def invalidate(self, e=None):
-        if e is not None:
-            self.__pool.log("Invalidate connection %s (reason: %s:%s)" % (repr(self.connection), e.__class__.__name__, str(e)))
-        else:
-            self.__pool.log("Invalidate connection %s" % repr(self.connection))
+        if self.__pool._should_log:
+            if e is not None:
+                self.__pool.log("Invalidate connection %s (reason: %s:%s)" % (repr(self.connection), e.__class__.__name__, str(e)))
+            else:
+                self.__pool.log("Invalidate connection %s" % repr(self.connection))
         self.__close()
         self.connection = None
 
@@ -220,7 +224,8 @@ class _ConnectionRecord(object):
                 for l in self.__pool._on_connect:
                     l.connect(self.connection, self)
         elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle):
-            self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
+            if self.__pool._should_log:
+                self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
             self.__close()
             self.connection = self.__connect()
             self.properties.clear()
@@ -231,10 +236,12 @@ class _ConnectionRecord(object):
 
     def __close(self):
         try:
-            self.__pool.log("Closing connection %s" % (repr(self.connection)))
+            if self.__pool._should_log:
+                self.__pool.log("Closing connection %s" % (repr(self.connection)))
             self.connection.close()
         except Exception, e:
-            self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e)))
+            if self.__pool._should_log:
+                self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e)))
             if isinstance(e, (SystemExit, KeyboardInterrupt)):
                 raise
 
@@ -242,10 +249,12 @@ class _ConnectionRecord(object):
         try:
             self.starttime = time.time()
             connection = self.__pool._creator()
-            self.__pool.log("Created new connection %s" % repr(connection))
+            if self.__pool._should_log:
+                self.__pool.log("Created new connection %s" % repr(connection))
             return connection
         except Exception, e:
-            self.__pool.log("Error on connect(): %s" % (str(e)))
+            if self.__pool._should_log:
+                self.__pool.log("Error on connect(): %s" % (str(e)))
             raise
 
 class _ConnectionFairy(object):
@@ -261,7 +270,7 @@ class _ConnectionFairy(object):
             self.connection = None # helps with endless __getattr__ loops later on
             self._connection_record = None
             raise
-        if self._pool.echo:
+        if self._pool._should_log:
             self._pool.log("Connection %s checked out from pool" % repr(self.connection))
     
     _logger = property(lambda self: self._pool.logger)
@@ -323,13 +332,15 @@ class _ConnectionFairy(object):
                     l.checkout(self.connection, self._connection_record, self)
                 return self
             except exceptions.DisconnectionError, e:
-                self._pool.log(
+                if self._pool._should_log:
+                    self._pool.log(
                     "Disconnection detected on checkout: %s" % (str(e)))
                 self._connection_record.invalidate(e)
                 self.connection = self._connection_record.get_connection()
                 attempts -= 1
 
-        self._pool.log("Reconnection attempts exhausted on checkout")
+        if self._pool._should_log:
+            self._pool.log("Reconnection attempts exhausted on checkout")
         self.invalidate()
         raise exceptions.InvalidRequestError("This connection is closed")
 
@@ -375,7 +386,7 @@ class _ConnectionFairy(object):
                 if isinstance(e, (SystemExit, KeyboardInterrupt)):
                     raise
         if self._connection_record is not None:
-            if self._pool.echo:
+            if self._pool._should_log:
                 self._pool.log("Connection %s being returned to pool" % repr(self.connection))
             if self._pool._on_checkin:
                 for l in self._pool._on_checkin:
@@ -572,7 +583,8 @@ class QueuePool(Pool):
                 break
 
         self._overflow = 0 - self.size()
-        self.log("Pool disposed. " + self.status())
+        if self._should_log:
+            self.log("Pool disposed. " + self.status())
 
     def status(self):
         tup = (self.size(), self.checkedin(), self.overflow(), self.checkedout())
index 98c01343723885fa5224996569846e310a2c8a3c..443dfd3fcaa8020799d40955e9f9d76a8c1e82c9 100644 (file)
@@ -207,6 +207,19 @@ class PoolTest(PersistTest):
         assert p.checkedout() == 1
         c1 = None
         assert p.checkedout() == 0
+
+    def test_weakref_kaboom(self):
+        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)
+        c1 = p.connect()
+        c2 = p.connect()
+        c1.close()
+        c2 = None
+        del c1
+        del c2
+        gc.collect()
+        assert p.checkedout() == 0
+        c3 = p.connect()
+        assert c3 is not None
     
     def test_trick_the_counter(self):
         """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
@@ -363,7 +376,7 @@ class PoolTest(PersistTest):
     def test_properties(self):
         dbapi = MockDBAPI()
         p = pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
-                           pool_size=1, max_overflow=0)
+                           pool_size=1, max_overflow=0, use_threadlocal=False)
 
         c = p.connect()
         self.assert_(not c.properties)
@@ -443,7 +456,7 @@ class PoolTest(PersistTest):
                 pass
 
         def _pool(**kw):
-            return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw)
+            return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), use_threadlocal=False, **kw)
             #, pool_size=1, max_overflow=0, **kw)
 
         def assert_listeners(p, total, conn, cout, cin):