From: Mike Bayer Date: Fri, 4 Feb 2022 14:04:49 +0000 (-0500) Subject: ensure exception raised for all stream w/ sync result X-Git-Tag: rel_2_0_0b1~499^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=faa9ef2cff53bde291df5ac3b5c4ed8f665ecd8c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure exception raised for all stream w/ sync result Fixed issue where the :meth:`_asyncio.AsyncSession.execute` method failed to raise an informative exception if the ``stream_results`` execution option were used, which is incompatible with a sync-style :class:`_result.Result` object. An exception is now raised in this scenario in the same way one is already raised when using ``stream_results`` in conjunction with the :meth:`_asyncio.AsyncConnection.execute` method. Additionally, for improved stability with state-sensitive dialects such as asyncmy, the cursor is now closed when this error condition is raised; previously with the asyncmy dialect, the connection would go into an invalid state with unconsumed server side results remaining. Fixes: #7667 Change-Id: I6eb7affe08584889b57423a90258295f8b7085dc --- diff --git a/doc/build/changelog/unreleased_14/7667.rst b/doc/build/changelog/unreleased_14/7667.rst new file mode 100644 index 0000000000..d66572feb0 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7667.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: bug, asyncio + :tickets: 7667 + + Fixed issue where the :meth:`_asyncio.AsyncSession.execute` method failed + to raise an informative exception if the ``stream_results`` execution + option were used, which is incompatible with a sync-style + :class:`_result.Result` object. An exception is now raised in this scenario + in the same way one is already raised when using ``stream_results`` in + conjunction with the :meth:`_asyncio.AsyncConnection.execute` method. + Additionally, for improved stability with state-sensitive dialects such as + asyncmy, the cursor is now closed when this error condition is raised; + previously with the asyncmy dialect, the connection would go into an + invalid state with unconsumed server side results remaining. + diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index abb3650d2b..f372b88985 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1645,6 +1645,7 @@ class CursorResult(BaseCursorResult, Result): _cursor_metadata = CursorResultMetaData _cursor_strategy_cls = CursorFetchStrategy _no_result_metadata = _NO_RESULT_METADATA + _is_cursor = True def _fetchiter_impl(self): fetchone = self.cursor_strategy.fetchone diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 3f916fea07..5970e2448f 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -268,6 +268,7 @@ class ResultInternal(InPlaceGenerative): _generate_rows = True _unique_filter_state = None _post_creational_filter = None + _is_cursor = False @HasMemoized.memoized_attribute def _row_getter(self): diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 9bbc04e773..fcf3b974db 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -7,6 +7,7 @@ from . import exc as async_exc from .base import ProxyComparable from .base import StartableContext +from .result import _ensure_sync_result from .result import AsyncResult from ... import exc from ... import inspection @@ -381,15 +382,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): execution_options, _require_await=True, ) - if result.context._is_server_side: - raise async_exc.AsyncMethodRequired( - "Can't use the connection.exec_driver_sql() method with a " - "server-side cursor." - "Use the connection.stream() method for an async " - "streaming result set." - ) - return result + return await _ensure_sync_result(result, self.exec_driver_sql) async def stream( self, @@ -462,14 +456,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): execution_options, _require_await=True, ) - if result.context._is_server_side: - raise async_exc.AsyncMethodRequired( - "Can't use the connection.execute() method with a " - "server-side cursor." - "Use the connection.stream() method for an async " - "streaming result set." - ) - return result + return await _ensure_sync_result(result, self.execute) async def scalar( self, diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index 81ef9915c5..62e4a9a0e5 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -7,6 +7,7 @@ import operator +from . import exc as async_exc from ...engine.result import _NO_ROW from ...engine.result import FilterResult from ...engine.result import FrozenResult @@ -646,3 +647,24 @@ class AsyncMappingResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, True, False) + + +async def _ensure_sync_result(result, calling_method): + if not result._is_cursor: + cursor_result = getattr(result, "raw", None) + else: + cursor_result = result + if cursor_result and cursor_result.context._is_server_side: + await greenlet_spawn(cursor_result.close) + raise async_exc.AsyncMethodRequired( + "Can't use the %s.%s() method with a " + "server-side cursor. " + "Use the %s.stream() method for an async " + "streaming result set." + % ( + calling_method.__self__.__class__.__name__, + calling_method.__name__, + calling_method.__self__.__class__.__name__, + ) + ) + return result diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 0840a0d7d9..22de2cab13 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -8,6 +8,7 @@ from . import engine from . import result as _result from .base import ReversibleProxy from .base import StartableContext +from .result import _ensure_sync_result from ... import util from ...orm import object_session from ...orm import Session @@ -208,7 +209,7 @@ class AsyncSession(ReversibleProxy): else: execution_options = _EXECUTE_OPTIONS - return await greenlet_spawn( + result = await greenlet_spawn( self.sync_session.execute, statement, params=params, @@ -216,6 +217,7 @@ class AsyncSession(ReversibleProxy): bind_arguments=bind_arguments, **kw, ) + return await _ensure_sync_result(result, self.execute) async def scalar( self, diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 0fdbc28dfb..1f40cbdecf 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -18,6 +18,7 @@ from sqlalchemy import union_all from sqlalchemy.ext.asyncio import async_engine_from_config from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import engine as _async_engine +from sqlalchemy.ext.asyncio import exc as async_exc from sqlalchemy.ext.asyncio import exc as asyncio_exc from sqlalchemy.ext.asyncio.base import ReversibleProxy from sqlalchemy.ext.asyncio.engine import AsyncConnection @@ -724,6 +725,32 @@ class AsyncInspection(EngineFixture): class AsyncResultTest(EngineFixture): + @async_test + async def test_no_ss_cursor_w_execute(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + with expect_raises_message( + async_exc.AsyncMethodRequired, + r"Can't use the AsyncConnection.execute\(\) method with a " + r"server-side cursor. Use the AsyncConnection.stream\(\) " + r"method for an async streaming result set.", + ): + await conn.execute(select(users)) + + @async_test + async def test_no_ss_cursor_w_exec_driver_sql(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + with expect_raises_message( + async_exc.AsyncMethodRequired, + r"Can't use the AsyncConnection.exec_driver_sql\(\) " + r"method with a " + r"server-side cursor. Use the AsyncConnection.stream\(\) " + r"method for an async streaming result set.", + ): + await conn.exec_driver_sql("SELECT * FROM users") + @testing.combinations( (None,), ("scalars",), ("mappings",), argnames="filter_" ) diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index bcaea05e53..f04b87f371 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -11,6 +11,7 @@ from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import exc as async_exc from sqlalchemy.ext.asyncio.base import ReversibleProxy from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload @@ -19,6 +20,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import async_test from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock @@ -165,6 +167,28 @@ class AsyncSessionQueryTest(AsyncFixture): ], ) + @testing.combinations("statement", "execute", argnames="location") + @async_test + async def test_no_ss_cursor_w_execute(self, async_session, location): + User = self.classes.User + + stmt = select(User) + if location == "statement": + stmt = stmt.execution_options(stream_results=True) + + with expect_raises_message( + async_exc.AsyncMethodRequired, + r"Can't use the AsyncSession.execute\(\) method with a " + r"server-side cursor. Use the AsyncSession.stream\(\) " + r"method for an async streaming result set.", + ): + if location == "execute": + await async_session.execute( + stmt, execution_options={"stream_results": True} + ) + else: + await async_session.execute(stmt) + class AsyncSessionTransactionTest(AsyncFixture): run_inserts = None