]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- work the wrapping of the "creator" to be as resilient to
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 30 Apr 2015 23:44:16 +0000 (19:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 30 Apr 2015 23:55:28 +0000 (19:55 -0400)
old / new style, direct access, and ad-hoc patching and
unpatching as possible

lib/sqlalchemy/pool.py
test/engine/test_execute.py
test/engine/test_pool.py

index 902309d75977857b06c473068f160f51f6afb593..8eb9d796dba48ab5883d3f824302bdf2a42f4b9e 100644 (file)
@@ -219,7 +219,7 @@ class Pool(log.Identified):
         log.instance_logger(self, echoflag=echo)
         self._threadconns = threading.local()
         self._creator = creator
-        self._wrapped_creator = self._maybe_wrap_callable(creator)
+        self._set_should_wrap_creator()
         self._recycle = recycle
         self._invalidate_time = 0
         self._use_threadlocal = use_threadlocal
@@ -250,16 +250,17 @@ class Pool(log.Identified):
             for l in listeners:
                 self.add_listener(l)
 
-    def _maybe_wrap_callable(self, fn):
+    def _set_should_wrap_creator(self):
         """Detect if creator accepts a single argument, or is sent
         as a legacy style no-arg function.
 
         """
 
         try:
-            argspec = util.get_callable_argspec(fn, no_self=True)
+            argspec = util.get_callable_argspec(self._creator, no_self=True)
         except TypeError:
-            return lambda ctx: fn()
+            self._should_wrap_creator = (True, self._creator)
+            return
 
         defaulted = argspec[3] is not None and len(argspec[3]) or 0
         positionals = len(argspec[0]) - defaulted
@@ -267,14 +268,36 @@ class Pool(log.Identified):
         # look for the exact arg signature that DefaultStrategy
         # sends us
         if (argspec[0], argspec[3]) == (['connection_record'], (None,)):
-            return fn
+            self._should_wrap_creator = (False, self._creator)
         # or just a single positional
         elif positionals == 1:
-            return fn
+            self._should_wrap_creator = (False, self._creator)
         # all other cases, just wrap and assume legacy "creator" callable
         # thing
         else:
-            return lambda ctx: fn()
+            self._should_wrap_creator = (True, self._creator)
+
+    def _invoke_creator(self, connection_record):
+        """adjust for old or new style "creator" callable.
+
+        This function is spending extra effort in order to accommodate
+        any degree of manipulation of the _creator callable by end-user
+        applications, including ad-hoc patching in test suites.
+
+        """
+
+        should_wrap, against_creator = self._should_wrap_creator
+        creator = self._creator
+
+        if creator is not against_creator:
+            # check if the _creator function has been patched since
+            # we last looked at it
+            self._set_should_wrap_creator()
+            return self._invoke_creator(connection_record)
+        elif should_wrap:
+            return self._creator()
+        else:
+            return self._creator(connection_record)
 
     def _close_connection(self, connection):
         self.logger.debug("Closing connection %r", connection)
@@ -591,7 +614,7 @@ class _ConnectionRecord(object):
     def __connect(self):
         try:
             self.starttime = time.time()
-            connection = self.__pool._wrapped_creator(self)
+            connection = self.__pool._invoke_creator(self)
             self.__pool.logger.debug("Created new connection %r", connection)
             return connection
         except Exception as e:
index cba3972f62684dc7d905fbc4ec05fa0bf2e12f54..761ac102a0ed41fcac3215175d537ffb5f2132b3 100644 (file)
@@ -2174,7 +2174,7 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
 
         conn.invalidate()
 
-        eng.pool._wrapped_creator = Mock(
+        eng.pool._creator = Mock(
             side_effect=self.ProgrammingError(
                 "Cannot operate on a closed database."))
 
index 3d93cda899415fb56d580b6aaaf7de585a1d35e3..912c6c3fe12e4d3e6188e945cbcbffff8f72da5e 100644 (file)
@@ -1807,3 +1807,58 @@ class StaticPoolTest(PoolTestBase):
         p = pool.StaticPool(creator)
         p2 = p.recreate()
         assert p._creator is p2._creator
+
+
+class CreatorCompatibilityTest(PoolTestBase):
+    def test_creator_callable_outside_noarg(self):
+        e = testing_engine()
+
+        creator = e.pool._creator
+        try:
+            conn = creator()
+        finally:
+            conn.close()
+
+    def test_creator_callable_outside_witharg(self):
+        e = testing_engine()
+
+        creator = e.pool._creator
+        try:
+            conn = creator(Mock())
+        finally:
+            conn.close()
+
+    def test_creator_patching_arg_to_noarg(self):
+        e = testing_engine()
+        creator = e.pool._creator
+        try:
+            # the creator is the two-arg form
+            conn = creator(Mock())
+        finally:
+            conn.close()
+
+        def mock_create():
+            return creator()
+
+        conn = e.connect()
+        conn.invalidate()
+        conn.close()
+
+        # test that the 'should_wrap_creator' memoized attribute
+        # will dynamically switch if the _creator is monkeypatched.
+
+        is_(e.pool.__dict__.get("_should_wrap_creator")[0], False)
+
+        # patch it with a zero-arg form
+        with patch.object(e.pool, "_creator", mock_create):
+            conn = e.connect()
+            conn.invalidate()
+            conn.close()
+
+            is_(e.pool.__dict__.get("_should_wrap_creator")[0], True)
+
+        conn = e.connect()
+        conn.close()
+
+        is_(e.pool.__dict__.get("_should_wrap_creator")[0], False)
+