]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve error raise for dialect/pool events w/ async engine
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Mar 2022 02:43:53 +0000 (21:43 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Mar 2022 02:47:01 +0000 (21:47 -0500)
Fixed issues where a descriptive error message was not raised for some
classes of event listening with an async engine, which should instead be a
sync engine instance.

Change-Id: I00b9f4fe9373ef5fd5464fac10651cc4024f648e

doc/build/changelog/unreleased_14/async_no_event.rst [new file with mode: 0644]
lib/sqlalchemy/engine/events.py
lib/sqlalchemy/ext/asyncio/events.py
lib/sqlalchemy/pool/events.py
test/base/test_events.py
test/ext/asyncio/test_engine_py3k.py

diff --git a/doc/build/changelog/unreleased_14/async_no_event.rst b/doc/build/changelog/unreleased_14/async_no_event.rst
new file mode 100644 (file)
index 0000000..8deda89
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, asyncio
+
+    Fixed issues where a descriptive error message was not raised for some
+    classes of event listening with an async engine, which should instead be a
+    sync engine instance.
\ No newline at end of file
index 0cbf56a6d500448c4024489cf314aa0e2c85a907..699faf48975516e70247e2a6e7e20ba3f4cdc940 100644 (file)
@@ -779,7 +779,7 @@ class DialectEvents(event.Events[Dialect]):
     @classmethod
     def _accept_with(
         cls, target: Union[Engine, Type[Engine], Dialect, Type[Dialect]]
-    ) -> Union[Dialect, Type[Dialect]]:
+    ) -> Optional[Union[Dialect, Type[Dialect]]]:
         if isinstance(target, type):
             if issubclass(target, Engine):
                 return Dialect
@@ -787,8 +787,14 @@ class DialectEvents(event.Events[Dialect]):
                 return target
         elif isinstance(target, Engine):
             return target.dialect
-        else:
+        elif isinstance(target, Dialect):
             return target
+        elif hasattr(target, "dispatch") and hasattr(
+            target.dispatch._events, "_no_async_engine_events"
+        ):
+            target.dispatch._events._no_async_engine_events()
+        else:
+            return None
 
     def do_connect(
         self,
index a059b93e6b916c2a01723f84aa5d5adaa8515ae8..c5d5e0126e9fa481ec0dcf7a4e5ea0100f8b9c46 100644 (file)
@@ -16,21 +16,29 @@ class AsyncConnectionEvents(engine_event.ConnectionEvents):
     _dispatch_target = AsyncConnectable
 
     @classmethod
-    def _listen(cls, event_key, retval=False):
+    def _no_async_engine_events(cls):
         raise NotImplementedError(
             "asynchronous events are not implemented at this time.  Apply "
             "synchronous listeners to the AsyncEngine.sync_engine or "
             "AsyncConnection.sync_connection attributes."
         )
 
+    @classmethod
+    def _listen(cls, event_key, retval=False):
+        cls._no_async_engine_events()
+
 
 class AsyncSessionEvents(orm_event.SessionEvents):
     _target_class_doc = "SomeSession"
     _dispatch_target = AsyncSession
 
     @classmethod
-    def _listen(cls, event_key, retval=False):
+    def _no_async_engine_events(cls):
         raise NotImplementedError(
             "asynchronous events are not implemented at this time.  Apply "
             "synchronous listeners to the AsyncSession.sync_session."
         )
+
+    @classmethod
+    def _listen(cls, event_key, retval=False):
+        cls._no_async_engine_events()
index d0d89291bc1ed76e95c96342d19cef93c8e54659..be2b406a344b31a67eba457cd7210d4a10a8902b 100644 (file)
@@ -59,7 +59,7 @@ class PoolEvents(event.Events[Pool]):
     @classmethod
     def _accept_with(
         cls, target: Union[Pool, Type[Pool], Engine, Type[Engine]]
-    ) -> Union[Pool, Type[Pool]]:
+    ) -> Optional[Union[Pool, Type[Pool]]]:
         if not typing.TYPE_CHECKING:
             Engine = util.preloaded.engine.Engine
 
@@ -71,9 +71,14 @@ class PoolEvents(event.Events[Pool]):
                 return target
         elif isinstance(target, Engine):
             return target.pool
-        else:
-            assert isinstance(target, Pool)
+        elif isinstance(target, Pool):
             return target
+        elif hasattr(target, "dispatch") and hasattr(
+            target.dispatch._events, "_no_async_engine_events"
+        ):
+            target.dispatch._events._no_async_engine_events()
+        else:
+            return None
 
     @classmethod
     def _listen(  # type: ignore[override]   # would rather keep **kw
index 2b785ba0b3a56c897b115dcd9b2b70344c1d4ed0..7de245c6421026d9a70a84979b656f523a09f8d3 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_deprecated
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
@@ -198,6 +199,50 @@ class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
 
         eq_(m1.mock_calls, [call(5, 6), call(9, 10)])
 
+    def test_real_name_wrong_dispatch(self):
+        m1 = Mock()
+
+        class E1(event.Events):
+            @classmethod
+            def _accept_with(cls, target):
+                if isinstance(target, T1):
+                    return target
+                else:
+                    m1.yup()
+                    return None
+
+            def event_one(self, x, y):
+                pass
+
+            def event_two(self, x):
+                pass
+
+            def event_three(self, x):
+                pass
+
+        class T1:
+            dispatch = event.dispatcher(E1)
+
+        class T2:
+            pass
+
+        class E2(event.Events):
+
+            _dispatch_target = T2
+
+            def event_four(self, x):
+                pass
+
+        with expect_raises_message(
+            exc.InvalidRequestError, "No such event 'event_three'"
+        ):
+
+            @event.listens_for(E2, "event_three")
+            def go(*arg):
+                pass
+
+        eq_(m1.mock_calls, [call.yup()])
+
     def test_exec_once_exception(self):
         m1 = Mock()
         m1.side_effect = ValueError
index 1f40cbdecf0a1ede37e35e1639c9fe76fcd8a088..0fe14dc9216929b7dec0f02e11f8a0a81b6c4d3f 100644 (file)
@@ -650,6 +650,28 @@ class AsyncEventTest(EngineFixture):
             ):
                 event.listen(conn, "before_cursor_execute", mock.Mock())
 
+    @async_test
+    async def test_no_async_listeners_dialect_event(self, async_engine):
+        with testing.expect_raises_message(
+            NotImplementedError,
+            "asynchronous events are not implemented "
+            "at this time.  Apply synchronous listeners to the "
+            "AsyncEngine.sync_engine or "
+            "AsyncConnection.sync_connection attributes.",
+        ):
+            event.listen(async_engine, "do_execute", mock.Mock())
+
+    @async_test
+    async def test_no_async_listeners_pool_event(self, async_engine):
+        with testing.expect_raises_message(
+            NotImplementedError,
+            "asynchronous events are not implemented "
+            "at this time.  Apply synchronous listeners to the "
+            "AsyncEngine.sync_engine or "
+            "AsyncConnection.sync_connection attributes.",
+        ):
+            event.listen(async_engine, "checkout", mock.Mock())
+
     @async_test
     async def test_sync_before_cursor_execute_engine(self, async_engine):
         canary = mock.Mock()