]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add new "exec_once_unless_exception" system; apply to dialect.initialize
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Aug 2019 22:07:06 +0000 (18:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Aug 2019 14:41:52 +0000 (10:41 -0400)
Fixed an issue whereby if the dialect "initialize" process which occurs on
first connect would encounter an unexpected exception, the initialize
process would fail to complete and then no longer attempt on subsequent
connection attempts, leaving the dialect in an un-initialized, or partially
initialized state, within the scope of parameters that need to be
established based on inspection of a live connection.   The "invoke once"
logic in the event system has been reworked to accommodate for this
occurrence using new, private API features that establish an "exec once"
hook that will continue to allow the initializer to fire off on subsequent
connections, until it completes without raising an exception. This does not
impact the behavior of the existing ``once=True`` flag within the event
system.

Fixes: #4807
Change-Id: Iec32999b61b6af4b38b6719e0c2651454619078c

doc/build/changelog/unreleased_13/4807.rst [new file with mode: 0644]
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/event/attr.py
lib/sqlalchemy/event/registry.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_events.py
test/engine/test_reconnect.py

diff --git a/doc/build/changelog/unreleased_13/4807.rst b/doc/build/changelog/unreleased_13/4807.rst
new file mode 100644 (file)
index 0000000..a688abb
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 4807
+
+    Fixed an issue whereby if the dialect "initialize" process which occurs on
+    first connect would encounter an unexpected exception, the initialize
+    process would fail to complete and then no longer attempt on subsequent
+    connection attempts, leaving the dialect in an un-initialized, or partially
+    initialized state, within the scope of parameters that need to be
+    established based on inspection of a live connection.   The "invoke once"
+    logic in the event system has been reworked to accommodate for this
+    occurrence using new, private API features that establish an "exec once"
+    hook that will continue to allow the initializer to fire off on subsequent
+    connections, until it completes without raising an exception. This does not
+    impact the behavior of the existing ``once=True`` flag within the event
+    system.
index cc83041319569db7e4ba09fe770ee71f168c5343..72be6009bda82d130f8ba10641c3991851d568f3 100644 (file)
@@ -529,7 +529,9 @@ def create_engine(url, **kwargs):
             dialect.initialize(c)
             dialect.do_rollback(c.connection)
 
-        event.listen(pool, "first_connect", first_connect, once=True)
+        event.listen(
+            pool, "first_connect", first_connect, _once_unless_exception=True
+        )
 
     dialect_cls.engine_created(engine)
     if entrypoint is not dialect_cls:
index 9dfa89809dc985d5504055c2260f58ed93b50baf..b6c48fa6c254dd9fe5b47247f5384c3e8f2a0e3e 100644 (file)
@@ -250,7 +250,9 @@ class _EmptyListener(_InstanceLevelDispatch):
     def _needs_modify(self, *args, **kw):
         raise NotImplementedError("need to call for_modify()")
 
-    exec_once = insert = append = remove = clear = _needs_modify
+    exec_once = (
+        exec_once_unless_exception
+    ) = insert = append = remove = clear = _needs_modify
 
     def __call__(self, *args, **kw):
         """Execute this event."""
@@ -276,17 +278,40 @@ class _CompoundListener(_InstanceLevelDispatch):
     def _memoized_attr__exec_once_mutex(self):
         return threading.Lock()
 
+    def _exec_once_impl(self, retry_on_exception, *args, **kw):
+        with self._exec_once_mutex:
+            if not self._exec_once:
+                try:
+                    self(*args, **kw)
+                    exception = False
+                except:
+                    exception = True
+                    raise
+                finally:
+                    if not exception or not retry_on_exception:
+                        self._exec_once = True
+
     def exec_once(self, *args, **kw):
         """Execute this event, but only if it has not been
         executed already for this collection."""
 
         if not self._exec_once:
-            with self._exec_once_mutex:
-                if not self._exec_once:
-                    try:
-                        self(*args, **kw)
-                    finally:
-                        self._exec_once = True
+            self._exec_once_impl(False, *args, **kw)
+
+    def exec_once_unless_exception(self, *args, **kw):
+        """Execute this event, but only if it has not been
+        executed already for this collection, or was called
+        by a previous exec_once_unless_exception call and
+        raised an exception.
+
+        If exec_once was already called, then this method will never run
+        the callable regardless of whether it raised or not.
+
+        .. versionadded:: 1.3.8
+
+        """
+        if not self._exec_once:
+            self._exec_once_impl(True, *args, **kw)
 
     def __call__(self, *args, **kw):
         """Execute this event."""
index 07b961c012d781f217a39fca543a690e15df9dda..2b8619b6e2ed302fc8af36b74029cbd94a790fdf 100644 (file)
@@ -192,6 +192,7 @@ class _EventKey(object):
 
     def listen(self, *args, **kw):
         once = kw.pop("once", False)
+        once_unless_exception = kw.pop("_once_unless_exception", False)
         named = kw.pop("named", False)
 
         target, identifier, fn = (
@@ -212,10 +213,12 @@ class _EventKey(object):
         if hasattr(stub_function, "_sa_warn"):
             stub_function._sa_warn()
 
-        if once:
-            self.with_wrapper(util.only_once(self._listen_fn)).listen(
-                *args, **kw
-            )
+        if once or once_unless_exception:
+            self.with_wrapper(
+                util.only_once(
+                    self._listen_fn, retry_on_exception=once_unless_exception
+                )
+            ).listen(*args, **kw)
         else:
             self.dispatch_target.dispatch._listen(self, *args, **kw)
 
index 2325e7faa1ffd076a8a9050765c8940f9d5c61e0..c45f836db2cef7227d574cc379f5fdb248f25420 100644 (file)
@@ -604,7 +604,7 @@ class _ConnectionRecord(object):
             if first_connect_check:
                 pool.dispatch.first_connect.for_modify(
                     pool.dispatch
-                ).exec_once(self.connection, self)
+                ).exec_once_unless_exception(self.connection, self)
             if pool.dispatch.connect:
                 pool.dispatch.connect(self.connection, self)
 
index 12fc5c0e87850e7c943fa6eaaccc60ba01e6aa40..f3f3f9ea5d78570d86b3fadb8a31db387dd7f522 100644 (file)
@@ -1487,7 +1487,7 @@ def warn_limited(msg, args):
     warnings.warn(msg, exc.SAWarning, stacklevel=2)
 
 
-def only_once(fn):
+def only_once(fn, retry_on_exception):
     """Decorate the given function to be a no-op after it is called exactly
     once."""
 
@@ -1499,7 +1499,12 @@ def only_once(fn):
         strong_fn = fn  # noqa
         if once:
             once_fn = once.pop()
-            return once_fn(*arg, **kw)
+            try:
+                return once_fn(*arg, **kw)
+            except:
+                if retry_on_exception:
+                    once.insert(0, once_fn)
+                raise
 
     return go
 
index c12b3414c761a4d193ab0f4b47c904b74fbf6a73..f13137084a377ce886424801dd4662cf896afcbc 100644 (file)
@@ -171,6 +171,78 @@ class EventsTest(fixtures.TestBase):
             t2.dispatch.event_one,
         )
 
+    def test_exec_once(self):
+        m1 = Mock()
+
+        event.listen(self.Target, "event_one", m1)
+
+        t1 = self.Target()
+        t2 = self.Target()
+
+        t1.dispatch.event_one.for_modify(t1.dispatch).exec_once(5, 6)
+
+        t1.dispatch.event_one.for_modify(t1.dispatch).exec_once(7, 8)
+
+        t2.dispatch.event_one.for_modify(t2.dispatch).exec_once(9, 10)
+
+        eq_(m1.mock_calls, [call(5, 6), call(9, 10)])
+
+    def test_exec_once_exception(self):
+        m1 = Mock()
+        m1.side_effect = ValueError
+
+        event.listen(self.Target, "event_one", m1)
+
+        t1 = self.Target()
+
+        assert_raises(
+            ValueError,
+            t1.dispatch.event_one.for_modify(t1.dispatch).exec_once,
+            5,
+            6,
+        )
+
+        t1.dispatch.event_one.for_modify(t1.dispatch).exec_once(7, 8)
+
+        eq_(m1.mock_calls, [call(5, 6)])
+
+    def test_exec_once_unless_exception(self):
+        m1 = Mock()
+        m1.side_effect = ValueError
+
+        event.listen(self.Target, "event_one", m1)
+
+        t1 = self.Target()
+
+        assert_raises(
+            ValueError,
+            t1.dispatch.event_one.for_modify(
+                t1.dispatch
+            ).exec_once_unless_exception,
+            5,
+            6,
+        )
+
+        assert_raises(
+            ValueError,
+            t1.dispatch.event_one.for_modify(
+                t1.dispatch
+            ).exec_once_unless_exception,
+            7,
+            8,
+        )
+
+        m1.side_effect = None
+        t1.dispatch.event_one.for_modify(
+            t1.dispatch
+        ).exec_once_unless_exception(9, 10)
+
+        t1.dispatch.event_one.for_modify(
+            t1.dispatch
+        ).exec_once_unless_exception(11, 12)
+
+        eq_(m1.mock_calls, [call(5, 6), call(7, 8), call(9, 10)])
+
     def test_immutable_methods(self):
         t1 = self.Target()
         for meth in [
@@ -1146,6 +1218,70 @@ class RemovalTest(fixtures.TestBase):
         eq_(m3.mock_calls, [call("x")])
         eq_(m4.mock_calls, [call("z")])
 
+    def test_once_unless_exception(self):
+        Target = self._fixture()
+
+        m1 = Mock()
+        m2 = Mock()
+        m3 = Mock()
+        m4 = Mock()
+
+        m1.side_effect = ValueError
+        m2.side_effect = ValueError
+        m3.side_effect = ValueError
+
+        event.listen(Target, "event_one", m1)
+        event.listen(Target, "event_one", m2, _once_unless_exception=True)
+        event.listen(Target, "event_one", m3, _once_unless_exception=True)
+
+        t1 = Target()
+
+        # only m1 is called, raises
+        assert_raises(ValueError, t1.dispatch.event_one, "x")
+
+        # now m1 and m2 can be called but not m3
+        m1.side_effect = None
+
+        assert_raises(ValueError, t1.dispatch.event_one, "y")
+
+        # now m3 can be called
+        m2.side_effect = None
+
+        event.listen(Target, "event_one", m4, _once_unless_exception=True)
+        assert_raises(ValueError, t1.dispatch.event_one, "z")
+
+        assert_raises(ValueError, t1.dispatch.event_one, "q")
+
+        eq_(m1.mock_calls, [call("x"), call("y"), call("z"), call("q")])
+        eq_(m2.mock_calls, [call("y"), call("z")])
+        eq_(m3.mock_calls, [call("z"), call("q")])
+        eq_(m4.mock_calls, [])  # m4 never got called because m3 blocked it
+
+        # now m4 can be called
+        m3.side_effect = None
+
+        t1.dispatch.event_one("p")
+        eq_(
+            m1.mock_calls,
+            [call("x"), call("y"), call("z"), call("q"), call("p")],
+        )
+
+        # m2 already got called, so no "p"
+        eq_(m2.mock_calls, [call("y"), call("z")])
+        eq_(m3.mock_calls, [call("z"), call("q"), call("p")])
+        eq_(m4.mock_calls, [call("p")])
+
+        t1.dispatch.event_one("j")
+        eq_(
+            m1.mock_calls,
+            [call("x"), call("y"), call("z"), call("q"), call("p"), call("j")],
+        )
+
+        # nobody got "j" because they've all been successful
+        eq_(m2.mock_calls, [call("y"), call("z")])
+        eq_(m3.mock_calls, [call("z"), call("q"), call("p")])
+        eq_(m4.mock_calls, [call("p")])
+
     def test_once_doesnt_dereference_listener(self):
         # test for [ticket:4794]
 
index 0eab8fb632583e3a89972bbfdc15e4a2dbf057bf..481700e70261b8811257852e51cf14df7a20def6 100644 (file)
@@ -18,6 +18,8 @@ from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.engines import testing_engine
@@ -550,10 +552,78 @@ class MockReconnectTest(fixtures.TestBase):
 
         engine = create_engine(MyURL("foo://"), module=dbapi)
         engine.connect()
+
+        # note that the dispose() call replaces the old pool with a new one;
+        # this is to test that even though a single pool is using
+        # dispatch.exec_once(), by replacing the pool with a new one, the event
+        # would normally fire again onless once=True is set on the original
+        # listen as well.
+
         engine.dispose()
         engine.connect()
         eq_(Dialect.initialize.call_count, 1)
 
+    def test_dialect_initialize_retry_if_exception(self):
+        from sqlalchemy.engine.url import URL
+        from sqlalchemy.engine.default import DefaultDialect
+
+        dbapi = self.dbapi
+
+        class MyURL(URL):
+            def _get_entrypoint(self):
+                return Dialect
+
+            def get_dialect(self):
+                return Dialect
+
+        class Dialect(DefaultDialect):
+            initialize = Mock()
+
+        # note that the first_connect hook is only invoked when the pool
+        # makes a new DBAPI connection, and not when it checks out an existing
+        # connection.  So there is a dependency here that if the initializer
+        # raises an exception, the pool-level connection attempt is also
+        # failed, meaning no DBAPI connection is pooled.  If the first_connect
+        # exception raise did not prevent the connection from being pooled,
+        # there could be the case where the pool could return that connection
+        # on a subsequent attempt without initialization having proceeded.
+
+        Dialect.initialize.side_effect = TypeError
+        engine = create_engine(MyURL("foo://"), module=dbapi)
+
+        assert_raises(TypeError, engine.connect)
+        eq_(Dialect.initialize.call_count, 1)
+        is_true(engine.pool._pool.empty())
+
+        assert_raises(TypeError, engine.connect)
+        eq_(Dialect.initialize.call_count, 2)
+        is_true(engine.pool._pool.empty())
+
+        engine.dispose()
+
+        assert_raises(TypeError, engine.connect)
+        eq_(Dialect.initialize.call_count, 3)
+        is_true(engine.pool._pool.empty())
+
+        Dialect.initialize.side_effect = None
+
+        conn = engine.connect()
+        eq_(Dialect.initialize.call_count, 4)
+        conn.close()
+        is_false(engine.pool._pool.empty())
+
+        conn = engine.connect()
+        eq_(Dialect.initialize.call_count, 4)
+        conn.close()
+        is_false(engine.pool._pool.empty())
+
+        engine.dispose()
+        conn = engine.connect()
+
+        eq_(Dialect.initialize.call_count, 4)
+        conn.close()
+        is_false(engine.pool._pool.empty())
+
     def test_invalidate_conn_w_contextmanager_interrupt(self):
         # test [ticket:3803]
         pool = self.db.pool