From 43cf4a9e5d66946a6a982ab3e1e513bb426eb35b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 2 Mar 2022 21:43:53 -0500 Subject: [PATCH] improve error raise for dialect/pool events w/ async engine Fixed issues where a descriptive error message was not raised for some classes of event listening with an async engine, which should instead be a sync engine instance. Change-Id: I00b9f4fe9373ef5fd5464fac10651cc4024f648e --- .../unreleased_14/async_no_event.rst | 6 +++ lib/sqlalchemy/engine/events.py | 10 ++++- lib/sqlalchemy/ext/asyncio/events.py | 12 ++++- lib/sqlalchemy/pool/events.py | 11 +++-- test/base/test_events.py | 45 +++++++++++++++++++ test/ext/asyncio/test_engine_py3k.py | 22 +++++++++ 6 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/async_no_event.rst diff --git a/doc/build/changelog/unreleased_14/async_no_event.rst b/doc/build/changelog/unreleased_14/async_no_event.rst new file mode 100644 index 0000000000..8deda89453 --- /dev/null +++ b/doc/build/changelog/unreleased_14/async_no_event.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, asyncio + + Fixed issues where a descriptive error message was not raised for some + classes of event listening with an async engine, which should instead be a + sync engine instance. \ No newline at end of file diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index 0cbf56a6d5..699faf4897 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -779,7 +779,7 @@ class DialectEvents(event.Events[Dialect]): @classmethod def _accept_with( cls, target: Union[Engine, Type[Engine], Dialect, Type[Dialect]] - ) -> Union[Dialect, Type[Dialect]]: + ) -> Optional[Union[Dialect, Type[Dialect]]]: if isinstance(target, type): if issubclass(target, Engine): return Dialect @@ -787,8 +787,14 @@ class DialectEvents(event.Events[Dialect]): return target elif isinstance(target, Engine): return target.dialect - else: + elif isinstance(target, Dialect): return target + elif hasattr(target, "dispatch") and hasattr( + target.dispatch._events, "_no_async_engine_events" + ): + target.dispatch._events._no_async_engine_events() + else: + return None def do_connect( self, diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py index a059b93e6b..c5d5e0126e 100644 --- a/lib/sqlalchemy/ext/asyncio/events.py +++ b/lib/sqlalchemy/ext/asyncio/events.py @@ -16,21 +16,29 @@ class AsyncConnectionEvents(engine_event.ConnectionEvents): _dispatch_target = AsyncConnectable @classmethod - def _listen(cls, event_key, retval=False): + def _no_async_engine_events(cls): raise NotImplementedError( "asynchronous events are not implemented at this time. Apply " "synchronous listeners to the AsyncEngine.sync_engine or " "AsyncConnection.sync_connection attributes." ) + @classmethod + def _listen(cls, event_key, retval=False): + cls._no_async_engine_events() + class AsyncSessionEvents(orm_event.SessionEvents): _target_class_doc = "SomeSession" _dispatch_target = AsyncSession @classmethod - def _listen(cls, event_key, retval=False): + def _no_async_engine_events(cls): raise NotImplementedError( "asynchronous events are not implemented at this time. Apply " "synchronous listeners to the AsyncSession.sync_session." ) + + @classmethod + def _listen(cls, event_key, retval=False): + cls._no_async_engine_events() diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py index d0d89291bc..be2b406a34 100644 --- a/lib/sqlalchemy/pool/events.py +++ b/lib/sqlalchemy/pool/events.py @@ -59,7 +59,7 @@ class PoolEvents(event.Events[Pool]): @classmethod def _accept_with( cls, target: Union[Pool, Type[Pool], Engine, Type[Engine]] - ) -> Union[Pool, Type[Pool]]: + ) -> Optional[Union[Pool, Type[Pool]]]: if not typing.TYPE_CHECKING: Engine = util.preloaded.engine.Engine @@ -71,9 +71,14 @@ class PoolEvents(event.Events[Pool]): return target elif isinstance(target, Engine): return target.pool - else: - assert isinstance(target, Pool) + elif isinstance(target, Pool): return target + elif hasattr(target, "dispatch") and hasattr( + target.dispatch._events, "_no_async_engine_events" + ): + target.dispatch._events._no_async_engine_events() + else: + return None @classmethod def _listen( # type: ignore[override] # would rather keep **kw diff --git a/test/base/test_events.py b/test/base/test_events.py index 2b785ba0b3..7de245c642 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -10,6 +10,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_deprecated +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not @@ -198,6 +199,50 @@ class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase): eq_(m1.mock_calls, [call(5, 6), call(9, 10)]) + def test_real_name_wrong_dispatch(self): + m1 = Mock() + + class E1(event.Events): + @classmethod + def _accept_with(cls, target): + if isinstance(target, T1): + return target + else: + m1.yup() + return None + + def event_one(self, x, y): + pass + + def event_two(self, x): + pass + + def event_three(self, x): + pass + + class T1: + dispatch = event.dispatcher(E1) + + class T2: + pass + + class E2(event.Events): + + _dispatch_target = T2 + + def event_four(self, x): + pass + + with expect_raises_message( + exc.InvalidRequestError, "No such event 'event_three'" + ): + + @event.listens_for(E2, "event_three") + def go(*arg): + pass + + eq_(m1.mock_calls, [call.yup()]) + def test_exec_once_exception(self): m1 = Mock() m1.side_effect = ValueError diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 1f40cbdecf..0fe14dc921 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -650,6 +650,28 @@ class AsyncEventTest(EngineFixture): ): event.listen(conn, "before_cursor_execute", mock.Mock()) + @async_test + async def test_no_async_listeners_dialect_event(self, async_engine): + with testing.expect_raises_message( + NotImplementedError, + "asynchronous events are not implemented " + "at this time. Apply synchronous listeners to the " + "AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes.", + ): + event.listen(async_engine, "do_execute", mock.Mock()) + + @async_test + async def test_no_async_listeners_pool_event(self, async_engine): + with testing.expect_raises_message( + NotImplementedError, + "asynchronous events are not implemented " + "at this time. Apply synchronous listeners to the " + "AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes.", + ): + event.listen(async_engine, "checkout", mock.Mock()) + @async_test async def test_sync_before_cursor_execute_engine(self, async_engine): canary = mock.Mock() -- 2.47.2