--- /dev/null
+.. change::
+ :tags: bug, pool
+ :tickets: 4585
+
+ Fixed behavioral regression as a result of deprecating the "use_threadlocal"
+ flag for :class:`.Pool`, where the :class:`.SingletonThreadPool` no longer
+ makes use of this option which causes the "rollback on return" logic to take
+ place when the same :class:`.Engine` is used multiple times in the context
+ of a transaction to connect or implicitly execute, thereby cancelling the
+ transaction. While this is not the recommended way to work with engines
+ and connections, it is nonetheless a confusing behavioral change as when
+ using :class:`.SingletonThreadPool`, the transaction should stay open
+ regardless of what else is done with the same engine in the same thread.
+ The ``use_threadlocal`` flag remains deprecated however the
+ :class:`.SingletonThreadPool` now implements its own version of the same
+ logic.
+
import traceback
import weakref
+from .base import _ConnectionFairy
from .base import _ConnectionRecord
from .base import Pool
from .. import exc
def __init__(self, creator, pool_size=5, **kw):
Pool.__init__(self, creator, **kw)
self._conn = threading.local()
+ self._fairy = threading.local()
self._all_conns = set()
self.size = pool_size
self._all_conns.add(c)
return c
+ def connect(self):
+ # vendored from Pool to include use_threadlocal behavior
+ try:
+ rec = self._fairy.current()
+ except AttributeError:
+ pass
+ else:
+ if rec is not None:
+ return rec._checkout_existing()
+
+ return _ConnectionFairy._checkout(self, self._fairy)
+
+ def _return_conn(self, record):
+ try:
+ del self._fairy.current
+ except AttributeError:
+ pass
+ self._do_return_conn(record)
+
class StaticPool(Pool):
still_opened = len([c for c in sr if not c.close.call_count])
eq_(still_opened, 3)
+ def test_no_rollback_from_nested_connections(self):
+ dbapi = MockDBAPI()
+
+ lock = threading.Lock()
+
+ def creator():
+ # the mock iterator isn't threadsafe...
+ with lock:
+ return dbapi.connect()
+
+ p = pool.SingletonThreadPool(creator=creator, pool_size=3)
+
+ c1 = p.connect()
+ mock_conn = c1.connection
+
+ c2 = p.connect()
+ is_(c1, c2)
+
+ c2.close()
+
+ eq_(mock_conn.mock_calls, [])
+ c1.close()
+
+ eq_(mock_conn.mock_calls, [call.rollback()])
+
class AssertionPoolTest(PoolTestBase):
def test_connect_error(self):