]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add new "exec_once_unless_exception" system; apply to dialect.initialize
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Aug 2019 22:07:06 +0000 (18:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Aug 2019 14:43:37 +0000 (10:43 -0400)
Fixed an issue whereby if the dialect "initialize" process which occurs on
first connect would encounter an unexpected exception, the initialize
process would fail to complete and then no longer attempt on subsequent
connection attempts, leaving the dialect in an un-initialized, or partially
initialized state, within the scope of parameters that need to be
established based on inspection of a live connection.   The "invoke once"
logic in the event system has been reworked to accommodate for this
occurrence using new, private API features that establish an "exec once"
hook that will continue to allow the initializer to fire off on subsequent
connections, until it completes without raising an exception. This does not
impact the behavior of the existing ``once=True`` flag within the event
system.

Fixes: #4807
Change-Id: Iec32999b61b6af4b38b6719e0c2651454619078c
(cherry picked from commit 2051fa2ce9e724e6e77e19067d27d2660e7cd74a)

doc/build/changelog/unreleased_13/4807.rst [new file with mode: 0644]
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/event/attr.py
lib/sqlalchemy/event/registry.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_events.py
test/engine/test_reconnect.py

diff --git a/doc/build/changelog/unreleased_13/4807.rst b/doc/build/changelog/unreleased_13/4807.rst
new file mode 100644 (file)
index 0000000..a688abb
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 4807
+
+    Fixed an issue whereby if the dialect "initialize" process which occurs on
+    first connect would encounter an unexpected exception, the initialize
+    process would fail to complete and then no longer attempt on subsequent
+    connection attempts, leaving the dialect in an un-initialized, or partially
+    initialized state, within the scope of parameters that need to be
+    established based on inspection of a live connection.   The "invoke once"
+    logic in the event system has been reworked to accommodate for this
+    occurrence using new, private API features that establish an "exec once"
+    hook that will continue to allow the initializer to fire off on subsequent
+    connections, until it completes without raising an exception. This does not
+    impact the behavior of the existing ``once=True`` flag within the event
+    system.
index d3a22e5ac8bef8ac139761befbd37286396a2752..fc1d614ca1f816e281787e2ffee66076abb34fe8 100644 (file)
@@ -199,7 +199,12 @@ class DefaultEngineStrategy(EngineStrategy):
                 dialect.initialize(c)
                 dialect.do_rollback(c.connection)
 
-            event.listen(pool, "first_connect", first_connect, once=True)
+            event.listen(
+                pool,
+                "first_connect",
+                first_connect,
+                _once_unless_exception=True,
+            )
 
         dialect_cls.engine_created(engine)
         if entrypoint is not dialect_cls:
index 9dfa89809dc985d5504055c2260f58ed93b50baf..b6c48fa6c254dd9fe5b47247f5384c3e8f2a0e3e 100644 (file)
@@ -250,7 +250,9 @@ class _EmptyListener(_InstanceLevelDispatch):
     def _needs_modify(self, *args, **kw):
         raise NotImplementedError("need to call for_modify()")
 
-    exec_once = insert = append = remove = clear = _needs_modify
+    exec_once = (
+        exec_once_unless_exception
+    ) = insert = append = remove = clear = _needs_modify
 
     def __call__(self, *args, **kw):
         """Execute this event."""
@@ -276,17 +278,40 @@ class _CompoundListener(_InstanceLevelDispatch):
     def _memoized_attr__exec_once_mutex(self):
         return threading.Lock()
 
+    def _exec_once_impl(self, retry_on_exception, *args, **kw):
+        with self._exec_once_mutex:
+            if not self._exec_once:
+                try:
+                    self(*args, **kw)
+                    exception = False
+                except:
+                    exception = True
+                    raise
+                finally:
+                    if not exception or not retry_on_exception:
+                        self._exec_once = True
+
     def exec_once(self, *args, **kw):
         """Execute this event, but only if it has not been
         executed already for this collection."""
 
         if not self._exec_once:
-            with self._exec_once_mutex:
-                if not self._exec_once:
-                    try:
-                        self(*args, **kw)
-                    finally:
-                        self._exec_once = True
+            self._exec_once_impl(False, *args, **kw)
+
+    def exec_once_unless_exception(self, *args, **kw):
+        """Execute this event, but only if it has not been
+        executed already for this collection, or was called
+        by a previous exec_once_unless_exception call and
+        raised an exception.
+
+        If exec_once was already called, then this method will never run
+        the callable regardless of whether it raised or not.
+
+        .. versionadded:: 1.3.8
+
+        """
+        if not self._exec_once:
+            self._exec_once_impl(True, *args, **kw)
 
     def __call__(self, *args, **kw):
         """Execute this event."""
index 07b961c012d781f217a39fca543a690e15df9dda..2b8619b6e2ed302fc8af36b74029cbd94a790fdf 100644 (file)
@@ -192,6 +192,7 @@ class _EventKey(object):
 
     def listen(self, *args, **kw):
         once = kw.pop("once", False)
+        once_unless_exception = kw.pop("_once_unless_exception", False)
         named = kw.pop("named", False)
 
         target, identifier, fn = (
@@ -212,10 +213,12 @@ class _EventKey(object):
         if hasattr(stub_function, "_sa_warn"):
             stub_function._sa_warn()
 
-        if once:
-            self.with_wrapper(util.only_once(self._listen_fn)).listen(
-                *args, **kw
-            )
+        if once or once_unless_exception:
+            self.with_wrapper(
+                util.only_once(
+                    self._listen_fn, retry_on_exception=once_unless_exception
+                )
+            ).listen(*args, **kw)
         else:
             self.dispatch_target.dispatch._listen(self, *args, **kw)
 
index 410df47f1ad0273279cf653c67756f85c0524a7d..f98e0374332fe8aab763457c92149477df77bb93 100644 (file)
@@ -646,7 +646,7 @@ class _ConnectionRecord(object):
             if first_connect_check:
                 pool.dispatch.first_connect.for_modify(
                     pool.dispatch
-                ).exec_once(self.connection, self)
+                ).exec_once_unless_exception(self.connection, self)
             if pool.dispatch.connect:
                 pool.dispatch.connect(self.connection, self)
 
index b9ce2ebea1dd429b13b699d215cb9df4e6fd22db..9276df332b52a2e294d228397315eeec8394616e 100644 (file)
@@ -1466,7 +1466,7 @@ def warn_limited(msg, args):
     warnings.warn(msg, exc.SAWarning, stacklevel=2)
 
 
-def only_once(fn):
+def only_once(fn, retry_on_exception):
     """Decorate the given function to be a no-op after it is called exactly
     once."""
 
@@ -1478,7 +1478,12 @@ def only_once(fn):
         strong_fn = fn  # noqa
         if once:
             once_fn = once.pop()
-            return once_fn(*arg, **kw)
+            try:
+                return once_fn(*arg, **kw)
+            except:
+                if retry_on_exception:
+                    once.insert(0, once_fn)
+                raise
 
     return go
 
index c12b3414c761a4d193ab0f4b47c904b74fbf6a73..f13137084a377ce886424801dd4662cf896afcbc 100644 (file)
@@ -171,6 +171,78 @@ class EventsTest(fixtures.TestBase):
             t2.dispatch.event_one,
         )
 
+    def test_exec_once(self):
+        m1 = Mock()
+
+        event.listen(self.Target, "event_one", m1)
+
+        t1 = self.Target()
+        t2 = self.Target()
+
+        t1.dispatch.event_one.for_modify(t1.dispatch).exec_once(5, 6)
+
+        t1.dispatch.event_one.for_modify(t1.dispatch).exec_once(7, 8)
+
+        t2.dispatch.event_one.for_modify(t2.dispatch).exec_once(9, 10)
+
+        eq_(m1.mock_calls, [call(5, 6), call(9, 10)])
+
+    def test_exec_once_exception(self):
+        m1 = Mock()
+        m1.side_effect = ValueError
+
+        event.listen(self.Target, "event_one", m1)
+
+        t1 = self.Target()
+
+        assert_raises(
+            ValueError,
+            t1.dispatch.event_one.for_modify(t1.dispatch).exec_once,
+            5,
+            6,
+        )
+
+        t1.dispatch.event_one.for_modify(t1.dispatch).exec_once(7, 8)
+
+        eq_(m1.mock_calls, [call(5, 6)])
+
+    def test_exec_once_unless_exception(self):
+        m1 = Mock()
+        m1.side_effect = ValueError
+
+        event.listen(self.Target, "event_one", m1)
+
+        t1 = self.Target()
+
+        assert_raises(
+            ValueError,
+            t1.dispatch.event_one.for_modify(
+                t1.dispatch
+            ).exec_once_unless_exception,
+            5,
+            6,
+        )
+
+        assert_raises(
+            ValueError,
+            t1.dispatch.event_one.for_modify(
+                t1.dispatch
+            ).exec_once_unless_exception,
+            7,
+            8,
+        )
+
+        m1.side_effect = None
+        t1.dispatch.event_one.for_modify(
+            t1.dispatch
+        ).exec_once_unless_exception(9, 10)
+
+        t1.dispatch.event_one.for_modify(
+            t1.dispatch
+        ).exec_once_unless_exception(11, 12)
+
+        eq_(m1.mock_calls, [call(5, 6), call(7, 8), call(9, 10)])
+
     def test_immutable_methods(self):
         t1 = self.Target()
         for meth in [
@@ -1146,6 +1218,70 @@ class RemovalTest(fixtures.TestBase):
         eq_(m3.mock_calls, [call("x")])
         eq_(m4.mock_calls, [call("z")])
 
+    def test_once_unless_exception(self):
+        Target = self._fixture()
+
+        m1 = Mock()
+        m2 = Mock()
+        m3 = Mock()
+        m4 = Mock()
+
+        m1.side_effect = ValueError
+        m2.side_effect = ValueError
+        m3.side_effect = ValueError
+
+        event.listen(Target, "event_one", m1)
+        event.listen(Target, "event_one", m2, _once_unless_exception=True)
+        event.listen(Target, "event_one", m3, _once_unless_exception=True)
+
+        t1 = Target()
+
+        # only m1 is called, raises
+        assert_raises(ValueError, t1.dispatch.event_one, "x")
+
+        # now m1 and m2 can be called but not m3
+        m1.side_effect = None
+
+        assert_raises(ValueError, t1.dispatch.event_one, "y")
+
+        # now m3 can be called
+        m2.side_effect = None
+
+        event.listen(Target, "event_one", m4, _once_unless_exception=True)
+        assert_raises(ValueError, t1.dispatch.event_one, "z")
+
+        assert_raises(ValueError, t1.dispatch.event_one, "q")
+
+        eq_(m1.mock_calls, [call("x"), call("y"), call("z"), call("q")])
+        eq_(m2.mock_calls, [call("y"), call("z")])
+        eq_(m3.mock_calls, [call("z"), call("q")])
+        eq_(m4.mock_calls, [])  # m4 never got called because m3 blocked it
+
+        # now m4 can be called
+        m3.side_effect = None
+
+        t1.dispatch.event_one("p")
+        eq_(
+            m1.mock_calls,
+            [call("x"), call("y"), call("z"), call("q"), call("p")],
+        )
+
+        # m2 already got called, so no "p"
+        eq_(m2.mock_calls, [call("y"), call("z")])
+        eq_(m3.mock_calls, [call("z"), call("q"), call("p")])
+        eq_(m4.mock_calls, [call("p")])
+
+        t1.dispatch.event_one("j")
+        eq_(
+            m1.mock_calls,
+            [call("x"), call("y"), call("z"), call("q"), call("p"), call("j")],
+        )
+
+        # nobody got "j" because they've all been successful
+        eq_(m2.mock_calls, [call("y"), call("z")])
+        eq_(m3.mock_calls, [call("z"), call("q"), call("p")])
+        eq_(m4.mock_calls, [call("p")])
+
     def test_once_doesnt_dereference_listener(self):
         # test for [ticket:4794]
 
index 14f3a7fd56e53af0a6a31a60f41d5f2ea43a319f..45d8827148c638f793b2521b2c5a7e03dadd997d 100644 (file)
@@ -18,6 +18,8 @@ from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.engines import testing_engine
@@ -550,10 +552,78 @@ class MockReconnectTest(fixtures.TestBase):
 
         engine = create_engine(MyURL("foo://"), module=dbapi)
         engine.connect()
+
+        # note that the dispose() call replaces the old pool with a new one;
+        # this is to test that even though a single pool is using
+        # dispatch.exec_once(), by replacing the pool with a new one, the event
+        # would normally fire again onless once=True is set on the original
+        # listen as well.
+
         engine.dispose()
         engine.connect()
         eq_(Dialect.initialize.call_count, 1)
 
+    def test_dialect_initialize_retry_if_exception(self):
+        from sqlalchemy.engine.url import URL
+        from sqlalchemy.engine.default import DefaultDialect
+
+        dbapi = self.dbapi
+
+        class MyURL(URL):
+            def _get_entrypoint(self):
+                return Dialect
+
+            def get_dialect(self):
+                return Dialect
+
+        class Dialect(DefaultDialect):
+            initialize = Mock()
+
+        # note that the first_connect hook is only invoked when the pool
+        # makes a new DBAPI connection, and not when it checks out an existing
+        # connection.  So there is a dependency here that if the initializer
+        # raises an exception, the pool-level connection attempt is also
+        # failed, meaning no DBAPI connection is pooled.  If the first_connect
+        # exception raise did not prevent the connection from being pooled,
+        # there could be the case where the pool could return that connection
+        # on a subsequent attempt without initialization having proceeded.
+
+        Dialect.initialize.side_effect = TypeError
+        engine = create_engine(MyURL("foo://"), module=dbapi)
+
+        assert_raises(TypeError, engine.connect)
+        eq_(Dialect.initialize.call_count, 1)
+        is_true(engine.pool._pool.empty())
+
+        assert_raises(TypeError, engine.connect)
+        eq_(Dialect.initialize.call_count, 2)
+        is_true(engine.pool._pool.empty())
+
+        engine.dispose()
+
+        assert_raises(TypeError, engine.connect)
+        eq_(Dialect.initialize.call_count, 3)
+        is_true(engine.pool._pool.empty())
+
+        Dialect.initialize.side_effect = None
+
+        conn = engine.connect()
+        eq_(Dialect.initialize.call_count, 4)
+        conn.close()
+        is_false(engine.pool._pool.empty())
+
+        conn = engine.connect()
+        eq_(Dialect.initialize.call_count, 4)
+        conn.close()
+        is_false(engine.pool._pool.empty())
+
+        engine.dispose()
+        conn = engine.connect()
+
+        eq_(Dialect.initialize.call_count, 4)
+        conn.close()
+        is_false(engine.pool._pool.empty())
+
     def test_invalidate_conn_w_contextmanager_interrupt(self):
         # test [ticket:3803]
         pool = self.db.pool