]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure exception raised for all stream w/ sync result
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Feb 2022 14:04:49 +0000 (09:04 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Feb 2022 17:49:24 +0000 (12:49 -0500)
Fixed issue where the :meth:`_asyncio.AsyncSession.execute` method failed
to raise an informative exception if the ``stream_results`` execution
option were used, which is incompatible with a sync-style
:class:`_result.Result` object. An exception is now raised in this scenario
in the same way one is already raised when using ``stream_results`` in
conjunction with the :meth:`_asyncio.AsyncConnection.execute` method.
Additionally, for improved stability with state-sensitive dialects such as
asyncmy, the cursor is now closed when this error condition is raised;
previously with the asyncmy dialect, the connection would go into an
invalid state with unconsumed server side results remaining.

Fixes: #7667
Change-Id: I6eb7affe08584889b57423a90258295f8b7085dc

doc/build/changelog/unreleased_14/7667.rst [new file with mode: 0644]
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/result.py
lib/sqlalchemy/ext/asyncio/session.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py

diff --git a/doc/build/changelog/unreleased_14/7667.rst b/doc/build/changelog/unreleased_14/7667.rst
new file mode 100644 (file)
index 0000000..d66572f
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 7667
+
+    Fixed issue where the :meth:`_asyncio.AsyncSession.execute` method failed
+    to raise an informative exception if the ``stream_results`` execution
+    option were used, which is incompatible with a sync-style
+    :class:`_result.Result` object. An exception is now raised in this scenario
+    in the same way one is already raised when using ``stream_results`` in
+    conjunction with the :meth:`_asyncio.AsyncConnection.execute` method.
+    Additionally, for improved stability with state-sensitive dialects such as
+    asyncmy, the cursor is now closed when this error condition is raised;
+    previously with the asyncmy dialect, the connection would go into an
+    invalid state with unconsumed server side results remaining.
+
index abb3650d2b3974bb535258c1d34ef85fae2b4550..f372b88985dc7c98ceb5a5dac2d3de5f5111e538 100644 (file)
@@ -1645,6 +1645,7 @@ class CursorResult(BaseCursorResult, Result):
     _cursor_metadata = CursorResultMetaData
     _cursor_strategy_cls = CursorFetchStrategy
     _no_result_metadata = _NO_RESULT_METADATA
+    _is_cursor = True
 
     def _fetchiter_impl(self):
         fetchone = self.cursor_strategy.fetchone
index 3f916fea0742d06e39fdec99dca4d7fcf3e5f560..5970e2448fac73a3b8d79104c8f6aa7c5c3f182e 100644 (file)
@@ -268,6 +268,7 @@ class ResultInternal(InPlaceGenerative):
     _generate_rows = True
     _unique_filter_state = None
     _post_creational_filter = None
+    _is_cursor = False
 
     @HasMemoized.memoized_attribute
     def _row_getter(self):
index 9bbc04e773da2cfd3471575d444150f5eddb930e..fcf3b974db8cad7cdfb4ad6409808e9e943bc2f4 100644 (file)
@@ -7,6 +7,7 @@
 from . import exc as async_exc
 from .base import ProxyComparable
 from .base import StartableContext
+from .result import _ensure_sync_result
 from .result import AsyncResult
 from ... import exc
 from ... import inspection
@@ -381,15 +382,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
             execution_options,
             _require_await=True,
         )
-        if result.context._is_server_side:
-            raise async_exc.AsyncMethodRequired(
-                "Can't use the connection.exec_driver_sql() method with a "
-                "server-side cursor."
-                "Use the connection.stream() method for an async "
-                "streaming result set."
-            )
 
-        return result
+        return await _ensure_sync_result(result, self.exec_driver_sql)
 
     async def stream(
         self,
@@ -462,14 +456,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
             execution_options,
             _require_await=True,
         )
-        if result.context._is_server_side:
-            raise async_exc.AsyncMethodRequired(
-                "Can't use the connection.execute() method with a "
-                "server-side cursor."
-                "Use the connection.stream() method for an async "
-                "streaming result set."
-            )
-        return result
+        return await _ensure_sync_result(result, self.execute)
 
     async def scalar(
         self,
index 81ef9915c528cded03c0278ac9daff4e3f1eab59..62e4a9a0e545b3464654fa77a5d2882955d92706 100644 (file)
@@ -7,6 +7,7 @@
 
 import operator
 
+from . import exc as async_exc
 from ...engine.result import _NO_ROW
 from ...engine.result import FilterResult
 from ...engine.result import FrozenResult
@@ -646,3 +647,24 @@ class AsyncMappingResult(AsyncCommon):
 
         """
         return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+async def _ensure_sync_result(result, calling_method):
+    if not result._is_cursor:
+        cursor_result = getattr(result, "raw", None)
+    else:
+        cursor_result = result
+    if cursor_result and cursor_result.context._is_server_side:
+        await greenlet_spawn(cursor_result.close)
+        raise async_exc.AsyncMethodRequired(
+            "Can't use the %s.%s() method with a "
+            "server-side cursor. "
+            "Use the %s.stream() method for an async "
+            "streaming result set."
+            % (
+                calling_method.__self__.__class__.__name__,
+                calling_method.__name__,
+                calling_method.__self__.__class__.__name__,
+            )
+        )
+    return result
index 0840a0d7d98bab5284b0675673f4d1ccd7001477..22de2cab136eb2b38ace733fbb8070e54d96b3ca 100644 (file)
@@ -8,6 +8,7 @@ from . import engine
 from . import result as _result
 from .base import ReversibleProxy
 from .base import StartableContext
+from .result import _ensure_sync_result
 from ... import util
 from ...orm import object_session
 from ...orm import Session
@@ -208,7 +209,7 @@ class AsyncSession(ReversibleProxy):
         else:
             execution_options = _EXECUTE_OPTIONS
 
-        return await greenlet_spawn(
+        result = await greenlet_spawn(
             self.sync_session.execute,
             statement,
             params=params,
@@ -216,6 +217,7 @@ class AsyncSession(ReversibleProxy):
             bind_arguments=bind_arguments,
             **kw,
         )
+        return await _ensure_sync_result(result, self.execute)
 
     async def scalar(
         self,
index 0fdbc28dfb846e56f9241c073af377a65f103a27..1f40cbdecf0a1ede37e35e1639c9fe76fcd8a088 100644 (file)
@@ -18,6 +18,7 @@ from sqlalchemy import union_all
 from sqlalchemy.ext.asyncio import async_engine_from_config
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.asyncio import engine as _async_engine
+from sqlalchemy.ext.asyncio import exc as async_exc
 from sqlalchemy.ext.asyncio import exc as asyncio_exc
 from sqlalchemy.ext.asyncio.base import ReversibleProxy
 from sqlalchemy.ext.asyncio.engine import AsyncConnection
@@ -724,6 +725,32 @@ class AsyncInspection(EngineFixture):
 
 
 class AsyncResultTest(EngineFixture):
+    @async_test
+    async def test_no_ss_cursor_w_execute(self, async_engine):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            conn = await conn.execution_options(stream_results=True)
+            with expect_raises_message(
+                async_exc.AsyncMethodRequired,
+                r"Can't use the AsyncConnection.execute\(\) method with a "
+                r"server-side cursor. Use the AsyncConnection.stream\(\) "
+                r"method for an async streaming result set.",
+            ):
+                await conn.execute(select(users))
+
+    @async_test
+    async def test_no_ss_cursor_w_exec_driver_sql(self, async_engine):
+        async with async_engine.connect() as conn:
+            conn = await conn.execution_options(stream_results=True)
+            with expect_raises_message(
+                async_exc.AsyncMethodRequired,
+                r"Can't use the AsyncConnection.exec_driver_sql\(\) "
+                r"method with a "
+                r"server-side cursor. Use the AsyncConnection.stream\(\) "
+                r"method for an async streaming result set.",
+            ):
+                await conn.exec_driver_sql("SELECT * FROM users")
+
     @testing.combinations(
         (None,), ("scalars",), ("mappings",), argnames="filter_"
     )
index bcaea05e53f1e33183273dc17fe9b787e13b8294..f04b87f3718f1301fd3d6c794c0ef169c68560f5 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy import testing
 from sqlalchemy import update
 from sqlalchemy.ext.asyncio import async_object_session
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import exc as async_exc
 from sqlalchemy.ext.asyncio.base import ReversibleProxy
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
@@ -19,6 +20,7 @@ from sqlalchemy.orm import sessionmaker
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
@@ -165,6 +167,28 @@ class AsyncSessionQueryTest(AsyncFixture):
             ],
         )
 
+    @testing.combinations("statement", "execute", argnames="location")
+    @async_test
+    async def test_no_ss_cursor_w_execute(self, async_session, location):
+        User = self.classes.User
+
+        stmt = select(User)
+        if location == "statement":
+            stmt = stmt.execution_options(stream_results=True)
+
+        with expect_raises_message(
+            async_exc.AsyncMethodRequired,
+            r"Can't use the AsyncSession.execute\(\) method with a "
+            r"server-side cursor. Use the AsyncSession.stream\(\) "
+            r"method for an async streaming result set.",
+        ):
+            if location == "execute":
+                await async_session.execute(
+                    stmt, execution_options={"stream_results": True}
+                )
+            else:
+                await async_session.execute(stmt)
+
 
 class AsyncSessionTransactionTest(AsyncFixture):
     run_inserts = None