From f02349336fa4470dbb5ca8e4d16031b8aa86a74a Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Thu, 26 Aug 2021 22:00:33 +0200 Subject: [PATCH] Handle mappings passed to ``execution_options``. Fixed a bug in :meth:`_asyncio.AsyncSession.execute` and :meth:`_asyncio.AsyncSession.stream` that required ``execution_options`` to be an instance of ``immutabledict`` when defined. It now correctly accepts any mapping. Fixes: #6943 Change-Id: Ic09de480dc2da1b0bdce25acb60b8f01371971f9 --- doc/build/changelog/unreleased_14/6943.rst | 8 ++++++++ lib/sqlalchemy/ext/asyncio/session.py | 17 +++++++++++++++-- lib/sqlalchemy/orm/persistence.py | 2 +- lib/sqlalchemy/orm/session.py | 2 +- test/ext/asyncio/test_session_py3k.py | 14 ++++++++++---- 5 files changed, 35 insertions(+), 8 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6943.rst diff --git a/doc/build/changelog/unreleased_14/6943.rst b/doc/build/changelog/unreleased_14/6943.rst new file mode 100644 index 0000000000..4b980d0edd --- /dev/null +++ b/doc/build/changelog/unreleased_14/6943.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, asyncio + :tickets: 6943 + + Fixed a bug in :meth:`_asyncio.AsyncSession.execute` and + :meth:`_asyncio.AsyncSession.stream` that required ``execution_options`` + to be an instance of ``immutabledict`` when defined. It now + correctly accepts any mapping. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index a10621eef3..5c6e7f5a7c 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -14,6 +14,9 @@ from ...orm import Session from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) +_STREAM_OPTIONS = util.immutabledict({"stream_results": True}) + @util.create_proxy_methods( Session, @@ -140,7 +143,12 @@ class AsyncSession(ReversibleProxy): """Execute a statement and return a buffered :class:`_engine.Result` object.""" - execution_options = execution_options.union({"prebuffer_rows": True}) + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS return await greenlet_spawn( self.sync_session.execute, @@ -205,7 +213,12 @@ class AsyncSession(ReversibleProxy): """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object.""" - execution_options = execution_options.union({"stream_results": True}) + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _STREAM_OPTIONS + ) + else: + execution_options = _STREAM_OPTIONS result = await greenlet_spawn( self.sync_session.execute, diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 4747d0bbac..fd484b52b3 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1833,7 +1833,7 @@ class BulkUDCompileState(CompileState): return ( statement, util.immutabledict(execution_options).union( - dict(_sa_orm_update_options=update_options) + {"_sa_orm_update_options": update_options} ), ) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index af803a1b03..0bdd5cc959 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1581,7 +1581,7 @@ class Session(_SessionClassMethods): :param execution_options: optional dictionary of execution options, which will be associated with the statement execution. This dictionary can provide a subset of the options that are accepted - by :meth:`_future.Connection.execution_options`, and may also + by :meth:`_engine.Connection.execution_options`, and may also provide additional options understood only in an ORM context. :param bind_arguments: dictionary of additional arguments to determine diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 0883cb026d..ebedfedbfb 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -65,7 +65,10 @@ class AsyncSessionTest(AsyncFixture): class AsyncSessionQueryTest(AsyncFixture): @async_test - async def test_execute(self, async_session): + @testing.combinations( + {}, dict(execution_options={"logging_token": "test"}), argnames="kw" + ) + async def test_execute(self, async_session, kw): User = self.classes.User stmt = ( @@ -74,7 +77,7 @@ class AsyncSessionQueryTest(AsyncFixture): .order_by(User.id) ) - result = await async_session.execute(stmt) + result = await async_session.execute(stmt, **kw) eq_(result.scalars().all(), self.static.user_address_result) @async_test @@ -103,7 +106,10 @@ class AsyncSessionQueryTest(AsyncFixture): @async_test @testing.requires.independent_cursors - async def test_stream_partitions(self, async_session): + @testing.combinations( + {}, dict(execution_options={"logging_token": "test"}), argnames="kw" + ) + async def test_stream_partitions(self, async_session, kw): User = self.classes.User stmt = ( @@ -112,7 +118,7 @@ class AsyncSessionQueryTest(AsyncFixture): .order_by(User.id) ) - result = await async_session.stream(stmt) + result = await async_session.stream(stmt, **kw) assert_result = [] async for partition in result.scalars().partitions(3): -- 2.47.2