From: Mike Bayer Date: Tue, 25 Oct 2022 13:10:09 +0000 (-0400) Subject: Support result.close() for all iterator patterns X-Git-Tag: rel_2_0_0b3~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b96321ae79a0366c33ca739e6e67aaf5f4420db4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support result.close() for all iterator patterns This change contains new features for 2.0 only as well as some behaviors that will be backported to 1.4. For 1.4 and 2.0: Fixed issue where the underlying DBAPI cursor would not be closed when using :class:`_orm.Query` with :meth:`_orm.Query.yield_per` and direct iteration, if a user-defined exception were raised within the iteration process, interrupting the iterator. This would lead to the usual MySQL-related issues with server side cursors out of sync. For 1.4 only: A similar scenario can occur when using :term:`2.x` executions with direct use of :class:`.Result`, in which case the end-user code has access to the :class:`.Result` itself and should call :meth:`.Result.close` directly. Version 2.0 will feature context-manager calling patterns to address this use case. However, within the 1.4 scope, ensured that ``.close()`` methods are available on all :class:`.Result` implementations including :class:`.ScalarResult` and :class:`.MappingResult`. For 2.0 only: To better support the use case of iterating :class:`.Result` and :class:`.AsyncResult` objects where user-defined exceptions may interrupt the iteration, both objects as well as variants such as :class:`.ScalarResult`, :class:`.MappingResult`, :class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support context manager usage, where the result will be closed at the end of iteration. Corrected various typing issues within the engine and async engine packages. Fixes: #8710 Change-Id: I3166328bfd3900957eb33cbf1061d0495c9df670 --- diff --git a/doc/build/changelog/unreleased_14/8710.rst b/doc/build/changelog/unreleased_14/8710.rst new file mode 100644 index 0000000000..c687d4f838 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8710.rst @@ -0,0 +1,31 @@ +.. change:: + :tags: bug, orm + :tickets: 8710 + + Fixed issue where the underlying DBAPI cursor would not be closed when + using :class:`_orm.Query` and direct iteration, if a user-defined exception + were raised within the iteration process, interrupting the iterator, + which otherwise cannot be re-used in this context. When using + :meth:`_orm.Query.yield_per` to create server-side cursors, this would lead + to the usual MySQL-related issues with server side cursors out of sync. + + To resolve, a catch for ``GeneratorExit`` is applied within the default + iterator, which applies only in those cases where the interpreter is + calling ``.close()`` on the iterator in any case. + + A similar scenario can occur when using :term:`2.x` executions with direct + use of :class:`.Result`, in which case the end-user code has access to the + :class:`.Result` itself and should call :meth:`.Result.close` directly. + Version 2.0 will feature context-manager calling patterns to address this + use case. However, within the 1.4 scope, ensured that ``.close()`` methods + are available on all :class:`.Result` implementations including + :class:`.ScalarResult` and :class:`.MappingResult`. + + +.. change:: + :tags: bug, engine + :tickets: 8710 + + Ensured all :class:`.Result` objects include a :meth:`.Result.close` method + as well as a :attr:`.Result.closed` attribute, including on + :class:`.ScalarResult` and :class:`.MappingResult`. 
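For readers on the 1.4 series, a minimal sketch of the explicit ``Result.close()`` pattern the changelog above recommends; the engine URL, table name, and the interrupting condition are hypothetical placeholders::

    from sqlalchemy import create_engine, text

    engine = create_engine("mysql+pymysql://scott:tiger@localhost/test")

    with engine.connect() as conn:
        result = conn.execution_options(yield_per=100).execute(
            text("SELECT * FROM some_table")
        )
        try:
            for row in result:
                if row[0] > 10:  # stand-in for a user-defined exception case
                    raise ValueError("interrupted")
        finally:
            # on 1.4, close the streaming result explicitly so the server
            # side cursor is released even if iteration was interrupted
            result.close()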
diff --git a/doc/build/changelog/unreleased_20/8710.rst b/doc/build/changelog/unreleased_20/8710.rst new file mode 100644 index 0000000000..0b1f166517 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8710.rst @@ -0,0 +1,28 @@ +.. change:: + :tags: feature, engine + :tickets: 8710 + + To better support the use case of iterating :class:`.Result` and + :class:`.AsyncResult` objects where user-defined exceptions may interrupt + the iteration, both objects as well as variants such as + :class:`.ScalarResult`, :class:`.MappingResult`, + :class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support + context manager usage, where the result will be closed at the end of + the context manager block. + + In addition, ensured that all the above + mentioned :class:`.Result` objects include a :meth:`.Result.close` method + as well as :attr:`.Result.closed` accessors, including + :class:`.ScalarResult` and :class:`.MappingResult`, which previously did + not have a ``.close()`` method. + + .. seealso:: + + :ref:`change_8710` + + +.. change:: + :tags: bug, typing + + Corrected various typing issues within the engine and async engine + packages. diff --git a/doc/build/changelog/whatsnew_20.rst b/doc/build/changelog/whatsnew_20.rst index 3d4eca6b2b..a2def39f18 100644 --- a/doc/build/changelog/whatsnew_20.rst +++ b/doc/build/changelog/whatsnew_20.rst @@ -1539,6 +1539,37 @@ backend:: :ticket:`7631` + +.. _change_8710: + +Context Manager Support for ``Result``, ``AsyncResult`` +------------------------------------------------------- + +The :class:`.Result` object now supports context manager use, which will +ensure the object and its underlying cursor are closed at the end of the block. +This is useful in particular with server side cursors, where it's important that +the open cursor object is closed at the end of an operation, even if user-defined +exceptions have occurred:: + + with engine.connect() as conn: + with conn.execution_options(yield_per=100).execute( + text("select * from table") + ) as result: + for row in result: + print(f"{row}") + +With asyncio use, the :class:`.AsyncResult` and :class:`.AsyncConnection` have +been altered to provide for optional async context manager use, as in:: + + async with async_engine.connect() as conn: + conn = await conn.execution_options(yield_per=100) + async with conn.stream(text("select * from table")) as result: + async for row in result: + print(f"{row}") + +:ticket:`8710` + Behavioral Changes ------------------ diff --git a/doc/build/core/connections.rst b/doc/build/core/connections.rst index 7b96c1df9a..a39c452f61 100644 --- a/doc/build/core/connections.rst +++ b/doc/build/core/connections.rst @@ -16,7 +16,7 @@ higher level management services, the :class:`_engine.Engine` and :class:`_engine.Connection` are king (and queen?) - read on. Basic Usage -=========== +----------- Recall from :doc:`/core/engines` that an :class:`_engine.Engine` is created via the :func:`_sa.create_engine` call:: @@ -82,7 +82,7 @@ in the :ref:`unified_tutorial` for a tutorial. Using Transactions -================== +------------------ .. note:: @@ -94,7 +94,7 @@ Using Transactions information. Commit As You Go ----------------- +~~~~~~~~~~~~~~~~ The :class:`~sqlalchemy.engine.Connection` object always emits SQL statements within the context of a transaction block. The first time the @@ -156,7 +156,7 @@ emitted, a new transaction begins implicitly:: mode when using a "future" style engine. 
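As a quick illustration of the "commit as you go" style that the section above introduces, a minimal sketch, with the table and parameters as placeholders::

    with engine.connect() as conn:
        conn.execute(
            text("INSERT INTO some_table (x, y) VALUES (:x, :y)"),
            [{"x": 1, "y": 1}, {"x": 2, "y": 4}],
        )
        conn.commit()  # commits the implicit transaction "as you go"

        # a new transaction begins implicitly with the next statement
        conn.execute(
            text("INSERT INTO some_table (x, y) VALUES (:x, :y)"),
            [{"x": 6, "y": 8}],
        )
        conn.commit()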
Begin Once ----------------- +~~~~~~~~~~ The :class:`_engine.Connection` object provides a more explicit transaction management style referred to as **begin once**. In contrast to "commit as @@ -184,7 +184,7 @@ once" block:: # transaction is committed Connect and Begin Once from the Engine ---------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ A convenient shorthand form for the above "begin once" block is to use the :meth:`_engine.Engine.begin` method at the level of the originating @@ -226,7 +226,7 @@ returned by the :meth:`_engine.Connection.begin` method:: further commands. Mixing Styles ------------- +~~~~~~~~~~~~~ The "commit as you go" and "begin once" styles can be freely mixed within a single :meth:`_engine.Engine.connect` block, provided that the call to @@ -262,7 +262,7 @@ When developing code that uses "begin once", the library will raise .. _dbapi_autocommit: Setting Transaction Isolation Levels including DBAPI Autocommit -================================================================= +--------------------------------------------------------------- Most DBAPIs support the concept of configurable transaction :term:`isolation` levels. These are traditionally the four levels "READ UNCOMMITTED", "READ COMMITTED", @@ -294,7 +294,7 @@ SQLAlchemy dialects should support these isolation levels as well as autocommit to as great a degree as possible. Setting Isolation Level or DBAPI Autocommit for a Connection ------------------------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ For an individual :class:`_engine.Connection` object that's acquired from :meth:`.Engine.connect`, the isolation level can be set for the duration of @@ -341,7 +341,7 @@ begin a transaction:: on a per-transaction basis. Setting Isolation Level or DBAPI Autocommit for an Engine ----------------------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The :paramref:`_engine.Connection.execution_options.isolation_level` option may also be set engine wide, as is often preferable. This may be @@ -361,7 +361,7 @@ subsequent operations. .. _dbapi_autocommit_multiple: Maintaining Multiple Isolation Levels for a Single Engine ----------------------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The isolation level may also be set per engine, with a potentially greater level of flexibility, using either the @@ -429,7 +429,7 @@ reverted when a connection is returned to the connection pool. .. _dbapi_autocommit_understanding: Understanding the DBAPI-Level Autocommit Isolation Level ---------------------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In the parent section, we introduced the concept of the :paramref:`_engine.Connection.execution_options.isolation_level` @@ -588,7 +588,7 @@ To sum up: .. _engine_stream_results: Using Server Side Cursors (a.k.a. stream results) -================================================== +------------------------------------------------- Some backends feature explicit support for the concept of "server side cursors" versus "client side cursors". A client side cursor here @@ -644,7 +644,7 @@ or per-statement basis. 
Similar options exist when using an ORM Streaming with a fixed buffer via yield_per --------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ As individual row-fetch operations with fully unbuffered server side cursors are typically more expensive than fetching batches of rows at once, the @@ -683,12 +683,13 @@ combination has includes: These three behaviors are illustrated in the example below:: with engine.connect() as conn: - result = conn.execution_options(yield_per=100).execute(text("select * from table")) - - for partition in result.partitions(): - # partition is an iterable that will be at most 100 items - for row in partition: - print(f"{row}") + with conn.execution_options(yield_per=100).execute( + text("select * from table") + ) as result: + for partition in result.partitions(): + # partition is an iterable that will be at most 100 items + for row in partition: + print(f"{row}") The above example illustrates the combination of ``yield_per=100`` along with using the :meth:`_engine.Result.partitions` method to run processing @@ -699,6 +700,13 @@ buffered for each 100 rows fetched. Calling a method such as :meth:`_engine.Result.all` should **not** be used, as this will fully fetch all remaining rows at once and defeat the purpose of using ``yield_per``. +.. tip:: + + The :class:`.Result` object may be used as a context manager as illustrated + above. When iterating with a server-side cursor, this is the best way to + ensure the :class:`.Result` object is closed, even if exceptions are + raised within the iteration process. + The :paramref:`_engine.Connection.execution_options.yield_per` option is portable to the ORM as well, used by a :class:`_orm.Session` to fetch ORM objects, where it also limits the number of ORM objects generated at once. @@ -715,7 +723,7 @@ for further background on using .. _engine_stream_results_sr: Streaming with a dynamically growing buffer using stream_results ------------------------------------------------------------------ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To enable server side cursors without a specific partition size, the :paramref:`_engine.Connection.execution_options.stream_results` option may be @@ -731,11 +739,12 @@ of 1000 rows. The maximum size of this buffer can be affected using the :paramref:`_engine.Connection.execution_options.max_row_buffer` execution option:: with engine.connect() as conn: - conn = conn.execution_options(stream_results=True, max_row_buffer=100) - result = conn.execute(text("select * from table")) + with conn.execution_options(stream_results=True, max_row_buffer=100).execute( + text("select * from table") + ) as result: - for row in result: - print(f"{row}") + for row in result: + print(f"{row}") While the :paramref:`_engine.Connection.execution_options.stream_results` option may be combined with use of the :meth:`_engine.Result.partitions` @@ -757,7 +766,7 @@ up to use the :meth:`_engine.Result.partitions` method. .. _schema_translating: Translation of Schema Names -=========================== +--------------------------- To support multi-tenancy applications that distribute common sets of tables into multiple schemas, the @@ -853,7 +862,7 @@ as the schema name is passed to these methods explicitly. SQL Compilation Caching -======================= +----------------------- .. 
versionadded:: 1.4 SQLAlchemy now has a transparent query caching system that substantially lowers the Python computational overhead involved in @@ -903,7 +912,7 @@ detail the configuration and advanced usage patterns for the cache. Configuration -------------- +~~~~~~~~~~~~~ The cache itself is a dictionary-like object called an ``LRUCache``, which is an internal SQLAlchemy dictionary subclass that tracks the usage of particular @@ -930,7 +939,7 @@ cache's behavior, described in the next section. .. _sql_caching_logging: Estimating Cache Performance Using Logging ------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The above cache size of 1200 is actually fairly large. For small applications, a size of 100 is likely sufficient. To estimate the optimal size of the cache, @@ -1148,7 +1157,7 @@ obviously an extremely small size, and the default size of 500 is fine to be left at its default. How much memory does the cache use? ------------------------------------ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The previous section detailed some techniques to check if the :paramref:`_sa.create_engine.query_cache_size` needs to be bigger. How do we know @@ -1170,7 +1179,7 @@ moderate Core statement takes up about 12K while a small ORM statement takes about .. _engine_compiled_cache: Disabling or using an alternate dictionary to cache some (or all) statements ------------------------------------------------------------------------------ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The internal cache used is known as ``LRUCache``, but this is mostly just a dictionary. Any dictionary may be used as a cache for any series of @@ -1200,7 +1209,7 @@ The cache can also be disabled with this argument by sending a value of .. _engine_thirdparty_caching: Caching for Third Party Dialects ---------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The caching feature requires that the dialect's compiler produces SQL strings that are safe to reuse for many statement invocations, given @@ -1342,7 +1351,7 @@ SELECTs with LIMIT/OFFSET are correctly rendered and cached. .. _engine_lambda_caching: Using Lambdas to add significant speed gains to statement production --------------------------------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. deepalchemy:: This technique is generally non-essential except in very performance intensive scenarios, and intended for experienced Python programmers. @@ -1764,7 +1773,7 @@ performance example. .. _engine_insertmanyvalues: "Insert Many Values" Behavior for INSERT statements -==================================================== +--------------------------------------------------- .. versionadded:: 2.0 see :ref:`change_6047` for background on the change including sample performance tests @@ -1853,7 +1862,7 @@ as follows: the same usage patterns and equivalent performance benefits. Enabling/Disabling the feature ------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To disable the "insertmanyvalues" feature for a given backend for an :class:`.Engine` overall, pass the @@ -1886,7 +1895,7 @@ such a table may need to include ``implicit_returning=False`` (see .. 
_engine_insertmanyvalues_page_size: Controlling the Batch Size ---------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~ A key characteristic of "insertmanyvalues" is that the size of the INSERT statement is limited to a fixed max number of "values" clauses as well as a @@ -1954,7 +1963,7 @@ Or configured on the statement itself:: .. _engine_insertmanyvalues_events: Logging and Events -------------------- +~~~~~~~~~~~~~~~~~~ The "insertmanyvalues" feature integrates fully with SQLAlchemy's statement logging as well as cursor events such as :meth:`.ConnectionEvents.before_cursor_execute`. @@ -1979,7 +1988,7 @@ an excerpt of this logging: [insertmanyvalues batch 10 of 10] ('d900', 900, 9000, 'd901', ... Upsert Support --------------- +~~~~~~~~~~~~~~ The PostgreSQL, SQLite, and MariaDB dialects offer backend-specific "upsert" constructs :func:`_postgresql.insert`, :func:`_sqlite.insert` @@ -1993,7 +2002,7 @@ with RETURNING to take place. .. _engine_disposal: Engine Disposal -=============== +--------------- The :class:`_engine.Engine` refers to a connection pool, which means under normal circumstances, there are open database connections present while the @@ -2070,7 +2079,7 @@ for guidelines on how to disable pooling. .. _dbapi_connections: Working with Driver SQL and Raw DBAPI Connections -================================================= +------------------------------------------------- The introduction on using :meth:`_engine.Connection.execute` made use of the :func:`_expression.text` construct in order to illustrate how textual SQL statements @@ -2082,7 +2091,7 @@ SQL in that it normalizes how bound parameters are passed, as well as that it supports datatyping behavior for parameters and result set rows. Invoking SQL strings directly to the driver --------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ For the use case where one wants to invoke textual SQL directly passed to the underlying driver (known as the :term:`DBAPI`) without any intervention @@ -2097,7 +2106,7 @@ method may be used:: .. _dbapi_connections_cursor: Working with the DBAPI cursor directly --------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ There are some cases where SQLAlchemy does not provide a genericized way of accessing some :term:`DBAPI` functions, such as calling stored procedures as well @@ -2153,7 +2162,7 @@ Some recipes for DBAPI connection use follow. .. _stored_procedures: Calling Stored Procedures and User Defined Functions ------------------------------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SQLAlchemy supports calling stored procedures and user defined functions in several ways. Please note that all DBAPIs have different practices, so you must @@ -2198,7 +2207,7 @@ situations to determine the correct syntax and patterns to use. Multiple Result Sets --------------------- +~~~~~~~~~~~~~~~~~~~~ Multiple result set support is available from a raw DBAPI cursor using the `nextset <https://legacy.python.org/dev/peps/pep-0249/#nextset>`_ method:: @@ -2215,7 +2224,7 @@ Multiple result set support is available from a raw DBAPI cursor using the connection.close() Registering New Dialects -======================== +------------------------ The :func:`_sa.create_engine` function call locates the given dialect using setuptools entrypoints. 
These entry points can be established @@ -2250,7 +2259,7 @@ The above entrypoint would then be accessed as ``create_engine("mysql+foodialect Registering Dialects In-Process -------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SQLAlchemy also allows a dialect to be registered within the current process, bypassing the need for separate installation. Use the ``register()`` function as follows:: @@ -2265,7 +2274,7 @@ The above will respond to ``create_engine("mysql+foodialect://")`` and load the Connection / Engine API -======================= +----------------------- .. autoclass:: Connection :members: @@ -2296,7 +2305,7 @@ Connection / Engine API Result Set API -================= +--------------- .. autoclass:: ChunkedIteratorResult :members: diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 05ca17063b..bcd2f0ea9d 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -928,6 +928,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): def __init__(self, cursor_metadata: ResultMetaData): self._metadata = cursor_metadata + def __enter__(self) -> Result[_TP]: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + def close(self) -> None: """close this :class:`_result.Result`. @@ -950,6 +956,19 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): """ self._soft_close(hard=True) + @property + def _soft_closed(self) -> bool: + raise NotImplementedError() + + @property + def closed(self) -> bool: + """return True if this :class:`.Result` reports .closed + + .. versionadded:: 1.4.43 + + """ + raise NotImplementedError() + @_generative def yield_per(self: SelfResult, num: int) -> SelfResult: """Configure the row-fetching strategy to fetch ``num`` rows at a time. @@ -1574,6 +1593,12 @@ class FilterResult(ResultInternal[_R]): _real_result: Result[Any] + def __enter__(self: SelfFilterResult) -> SelfFilterResult: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self._real_result.__exit__(type_, value, traceback) + @_generative def yield_per(self: SelfFilterResult, num: int) -> SelfFilterResult: """Configure the row-fetching strategy to fetch ``num`` rows at a time. @@ -1599,6 +1624,27 @@ class FilterResult(ResultInternal[_R]): def _soft_close(self, hard: bool = False) -> None: self._real_result._soft_close(hard=hard) + @property + def _soft_closed(self) -> bool: + return self._real_result._soft_closed + + @property + def closed(self) -> bool: + """return True if the underlying :class:`.Result` reports .closed + + .. versionadded:: 1.4.43 + + """ + return self._real_result.closed + + def close(self) -> None: + """Close this :class:`.FilterResult`. + + .. versionadded:: 1.4.43 + + """ + self._real_result.close() + @property def _attributes(self) -> Dict[Any, Any]: return self._real_result._attributes @@ -2172,7 +2218,7 @@ class IteratorResult(Result[_TP]): self, cursor_metadata: ResultMetaData, iterator: Iterator[_InterimSupportsScalarsRowType], - raw: Optional[Any] = None, + raw: Optional[Result[Any]] = None, _source_supports_scalars: bool = False, ): self._metadata = cursor_metadata @@ -2180,6 +2226,15 @@ class IteratorResult(Result[_TP]): self.raw = raw self._source_supports_scalars = _source_supports_scalars + @property + def closed(self) -> bool: + """return True if this :class:`.IteratorResult` has been closed + + .. 
versionadded:: 1.4.43 + + """ + return self._hard_closed + def _soft_close(self, hard: bool = False, **kw: Any) -> None: if hard: self._hard_closed = True @@ -2262,7 +2317,7 @@ class ChunkedIteratorResult(IteratorResult[_TP]): [Optional[int]], Iterator[Sequence[_InterimRowType[_R]]] ], source_supports_scalars: bool = False, - raw: Optional[Any] = None, + raw: Optional[Result[Any]] = None, dynamic_yield_per: bool = False, ): self._metadata = cursor_metadata diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 7fdd2d7e06..13d5e40b24 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -10,8 +10,13 @@ from __future__ import annotations import abc import functools from typing import Any +from typing import AsyncGenerator +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable from typing import ClassVar from typing import Dict +from typing import Generator from typing import Generic from typing import NoReturn from typing import Optional @@ -25,6 +30,7 @@ from ... import util from ...util.typing import Literal _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _PT = TypeVar("_PT", bound=Any) @@ -114,27 +120,29 @@ class ReversibleProxy(Generic[_PT]): SelfStartableContext = TypeVar( - "SelfStartableContext", bound="StartableContext" + "SelfStartableContext", bound="StartableContext[Any]" ) -class StartableContext(abc.ABC): +class StartableContext(Awaitable[_T_co], abc.ABC): __slots__ = () @abc.abstractmethod async def start( self: SelfStartableContext, is_ctxmanager: bool = False - ) -> Any: + ) -> _T_co: raise NotImplementedError() - def __await__(self) -> Any: + def __await__(self) -> Generator[Any, Any, _T_co]: return self.start().__await__() - async def __aenter__(self: SelfStartableContext) -> Any: - return await self.start(is_ctxmanager=True) + async def __aenter__(self: SelfStartableContext) -> _T_co: + return await self.start(is_ctxmanager=True) # type: ignore @abc.abstractmethod - async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + async def __aexit__( + self, type_: Any, value: Any, traceback: Any + ) -> Optional[bool]: pass def _raise_for_not_started(self) -> NoReturn: @@ -144,6 +152,129 @@ class StartableContext(abc.ABC): ) +class GeneratorStartableContext(StartableContext[_T_co]): + __slots__ = ("gen",) + + gen: AsyncGenerator[_T_co, Any] + + def __init__( + self, + func: Callable[..., AsyncIterator[_T_co]], + args: tuple[Any, ...], + kwds: dict[str, Any], + ): + self.gen = func(*args, **kwds) # type: ignore + + async def start(self, is_ctxmanager: bool = False) -> _T_co: + try: + start_value = await util.anext_(self.gen) + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + # if not a context manager, then interrupt the generator, don't + # let it complete. this step is technically not needed, as the + # generator will close in any case at gc time. 
not clear if having + this here is a good idea or not (though it helps for clarity IMO) + if not is_ctxmanager: + await self.gen.aclose() + + return start_value + + async def __aexit__( + self, typ: Any, value: Any, traceback: Any + ) -> Optional[bool]: + # vendored from contextlib.py + if typ is None: + try: + await util.anext_(self.gen) + except StopAsyncIteration: + return False + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + # Need to force instantiation so we can reliably + # tell if we get the same exception back + value = typ() + try: + await self.gen.athrow(typ, value, traceback) + except StopAsyncIteration as exc: + # Suppress StopIteration *unless* it's the same exception that + # was passed to throw(). This prevents a StopIteration + # raised inside the "with" statement from being suppressed. + return exc is not value + except RuntimeError as exc: + # Don't re-raise the passed in exception. (issue27122) + if exc is value: + return False + # Avoid suppressing if a Stop(Async)Iteration exception + # was passed to athrow() and later wrapped into a RuntimeError + # (see PEP 479 for sync generators; async generators also + # have this behavior). But do this only if the exception + # wrapped + # by the RuntimeError is actually Stop(Async)Iteration (see + # issue29692). + if ( + isinstance(value, (StopIteration, StopAsyncIteration)) + and exc.__cause__ is value + ): + return False + raise + except BaseException as exc: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + if exc is not value: + raise + return False + raise RuntimeError("generator didn't stop after athrow()") + + +def asyncstartablecontext( + func: Callable[..., AsyncIterator[_T_co]] +) -> Callable[..., GeneratorStartableContext[_T_co]]: + """@asyncstartablecontext decorator. + + the decorated function can be called either as ``async with fn()``, **or** + ``await fn()``. This is decidedly different from what + ``@contextlib.asynccontextmanager`` supports, and the usage pattern + is different as well. + + Typical usage:: + + @asyncstartablecontext + async def some_async_generator(<arguments>): + <setup> + try: + yield <value> + except GeneratorExit: + # return value was awaited, no context manager is present + # and caller will .close() the resource explicitly + pass + else: + <context manager cleanup> + + Above, ``GeneratorExit`` is caught if the function was used as an + ``await``. In this case, it's essential that the cleanup does **not** + occur, so there should not be a ``finally`` block. + + If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__`` + and we were invoked as a context manager, and cleanup should proceed. 
+ + + """ + + @functools.wraps(func) + def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]: + return GeneratorStartableContext(func, args, kwds) + + return helper + + class ProxyComparable(ReversibleProxy[_PT]): __slots__ = () diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 3890888754..854039c982 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -7,7 +7,9 @@ from __future__ import annotations import asyncio +import contextlib from typing import Any +from typing import AsyncIterator from typing import Callable from typing import Dict from typing import Generator @@ -21,6 +23,8 @@ from typing import TypeVar from typing import Union from . import exc as async_exc +from .base import asyncstartablecontext +from .base import GeneratorStartableContext from .base import ProxyComparable from .base import StartableContext from .result import _ensure_sync_result @@ -133,7 +137,9 @@ class AsyncConnectable: ], ) class AsyncConnection( - ProxyComparable[Connection], StartableContext, AsyncConnectable + ProxyComparable[Connection], + StartableContext["AsyncConnection"], + AsyncConnectable, ): """An asyncio proxy for a :class:`_engine.Connection`. @@ -446,34 +452,67 @@ class AsyncConnection( return await _ensure_sync_result(result, self.exec_driver_sql) @overload - async def stream( + def stream( self, statement: TypedReturnsRows[_T], parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[_T]: + ) -> GeneratorStartableContext[AsyncResult[_T]]: ... @overload - async def stream( + def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[Any]: + ) -> GeneratorStartableContext[AsyncResult[Any]]: ... + @asyncstartablecontext async def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[Any]: - """Execute a statement and return a streaming - :class:`_asyncio.AsyncResult` object.""" + ) -> AsyncIterator[AsyncResult[Any]]: + """Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncResult` object. + + E.g.:: + + result = await conn.stream(stmt) + async for row in result: + print(f"{row}") + + The :meth:`.AsyncConnection.stream` + method supports optional context manager use against the + :class:`.AsyncResult` object, as in:: + + async with conn.stream(stmt) as result: + async for row in result: + print(f"{row}") + + In the above pattern, the :meth:`.AsyncResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncResult` object. + + .. 
seealso:: + + :meth:`.AsyncConnection.stream_scalars` + + """ result = await greenlet_spawn( self._proxied.execute, @@ -484,10 +523,15 @@ class AsyncConnection( ), _require_await=True, ) - if not result.context._is_server_side: - # TODO: real exception here - assert False, "server side result expected" - return AsyncResult(result) + assert result.context._is_server_side + ar = AsyncResult(result) + try: + yield ar + except GeneratorExit: + pass + else: + task = asyncio.create_task(ar.close()) + await asyncio.shield(task) @overload async def execute( @@ -642,48 +686,77 @@ class AsyncConnection( return result.scalars() @overload - async def stream_scalars( + def stream_scalars( self, statement: TypedReturnsRows[Tuple[_T]], parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[_T]: + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... @overload - async def stream_scalars( + def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[Any]: + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... + @asyncstartablecontext async def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[Any]: - r"""Executes a SQL statement and returns a streaming scalar result - object. + ) -> AsyncIterator[AsyncScalarResult[Any]]: + r"""Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncScalarResult` object. + + E.g.:: + + result = await conn.stream_scalars(stmt) + async for scalar in result: + print(f"{scalar}") This method is shorthand for invoking the :meth:`_engine.AsyncResult.scalars` method after invoking the :meth:`_engine.Connection.stream` method. Parameters are equivalent. - :return: an :class:`_asyncio.AsyncScalarResult` object. + The :meth:`.AsyncConnection.stream_scalars` + method supports optional context manager use against the + :class:`.AsyncScalarResult` object, as in:: + + async with conn.stream_scalars(stmt) as result: + async for scalar in result: + print(f"{scalar}") + + In the above pattern, the :meth:`.AsyncScalarResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncScalarResult` object. .. versionadded:: 1.4.24 + .. seealso:: + + :meth:`.AsyncConnection.stream` + + """ - result = await self.stream( + + async with self.stream( statement, parameters, execution_options=execution_options - ) - return result.scalars() + ) as result: + yield result.scalars() async def run_sync( self, fn: Callable[..., Any], *arg: Any, **kw: Any @@ -856,37 +929,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): :ref:`asyncio_events` """ - class _trans_ctx(StartableContext): - __slots__ = ("conn", "transaction") - - conn: AsyncConnection - transaction: AsyncTransaction - - def __init__(self, conn: AsyncConnection): - self.conn = conn - - if TYPE_CHECKING: - - async def __aenter__(self) -> AsyncConnection: - ... 
- - async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: - await self.conn.start(is_ctxmanager=is_ctxmanager) - self.transaction = self.conn.begin() - await self.transaction.__aenter__() - - return self.conn - - async def __aexit__( - self, type_: Any, value: Any, traceback: Any - ) -> None: - async def go() -> None: - await self.transaction.__aexit__(type_, value, traceback) - await self.conn.close() - - task = asyncio.create_task(go()) - await asyncio.shield(task) - def __init__(self, sync_engine: Engine): if not sync_engine.dialect.is_async: raise exc.InvalidRequestError( @@ -903,7 +945,8 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: return AsyncEngine(target) - def begin(self) -> AsyncEngine._trans_ctx: + @contextlib.asynccontextmanager + async def begin(self) -> AsyncIterator[AsyncConnection]: """Return a context manager which when entered will deliver an :class:`_asyncio.AsyncConnection` with an :class:`_asyncio.AsyncTransaction` established. @@ -919,7 +962,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): """ conn = self.connect() - return self._trans_ctx(conn) + + async with conn: + async with conn.begin(): + yield conn def connect(self) -> AsyncConnection: """Return an :class:`_asyncio.AsyncConnection` object. @@ -1185,7 +1231,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): # END PROXY METHODS AsyncEngine -class AsyncTransaction(ProxyComparable[Transaction], StartableContext): +class AsyncTransaction( + ProxyComparable[Transaction], StartableContext["AsyncTransaction"] +): """An asyncio proxy for a :class:`_engine.Transaction`.""" __slots__ = ("connection", "sync_transaction", "nested") diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index 8f1c07fd8c..41ead5ee20 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -47,11 +47,21 @@ class AsyncCommon(FilterResult[_R]): _real_result: Result[Any] _metadata: ResultMetaData - async def close(self) -> None: + async def close(self) -> None: # type: ignore[override] """Close this result.""" await greenlet_spawn(self._real_result.close) + @property + def closed(self) -> bool: + """proxies the .closed attribute of the underlying result object, + if any, else raises ``AttributeError``. + + .. versionadded:: 2.0.0b3 + + """ + return self._real_result.closed # type: ignore + SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]") @@ -95,6 +105,16 @@ class AsyncResult(AsyncCommon[Row[_TP]]): "_row_getter", real_result.__dict__["_row_getter"] ) + @property + def closed(self) -> bool: + """proxies the .closed attribute of the underlying result object, + if any, else raises ``AttributeError``. + + .. versionadded:: 2.0.0b3 + + """ + return self._real_result.closed # type: ignore + @property def t(self) -> AsyncTupleResult[_TP]: """Apply a "typed tuple" typing filter to returned rows. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 0aa9661e9f..1956ea588a 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -1547,7 +1547,8 @@ class _AsyncSessionContextManager(Generic[_AS]): class AsyncSessionTransaction( - ReversibleProxy[SessionTransaction], StartableContext + ReversibleProxy[SessionTransaction], + StartableContext["AsyncSessionTransaction"], ): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. 
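The ``Query.__iter__`` change in the next diff leans on a core Python behavior: when a ``for`` loop over a generator is abandoned via ``break`` or an exception, the interpreter calls ``.close()`` on the generator, which raises ``GeneratorExit`` at the suspended ``yield``. A self-contained sketch of the idea; the class and names are illustrative only, not SQLAlchemy's API::

    class FakeResult:
        """Stand-in for a result object wrapping a closeable cursor."""

        def __init__(self, data):
            self.data = data
            self.closed = False

        def __iter__(self):
            return iter(self.data)

        def close(self):
            self.closed = True


    def iterate(result):
        try:
            yield from result
        except GeneratorExit:
            # raised when the caller abandons iteration and the
            # interpreter finalizes this generator
            result.close()
            raise


    r = FakeResult([1, 2, 3])
    for row in iterate(r):
        break  # CPython finalizes the temporary generator immediately here

    assert r.closed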
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 63035f585b..9ac6d07dae 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2713,7 +2713,14 @@ class Query( return None def __iter__(self) -> Iterable[_T]: - return self._iter().__iter__() # type: ignore + result = self._iter() + try: + yield from result + except GeneratorExit: + # issue #8710 - direct iteration is not re-usable after + # an iterable block is broken, so close the result + result._soft_close() + raise def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: # new style execution. diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index cd7e0fd81a..4952cb5011 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -46,6 +46,7 @@ from ._collections import UniqueAppender as UniqueAppender from ._collections import update_copy as update_copy from ._collections import WeakPopulateDict as WeakPopulateDict from ._collections import WeakSequence as WeakSequence +from .compat import anext_ as anext_ from .compat import arm as arm from .compat import b as b from .compat import b64decode as b64decode diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 4ce1e7ff32..cda5ab6c12 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -111,6 +111,29 @@ else: return a +if py310: + anext_ = anext +else: + + _NOT_PROVIDED = object() + from collections.abc import AsyncIterator + + async def anext_(async_iterator, default=_NOT_PROVIDED): + """vendored from https://github.com/python/cpython/pull/8895""" + + if not isinstance(async_iterator, AsyncIterator): + raise TypeError( + f"anext expected an AsyncIterator, got {type(async_iterator)}" + ) + anxt = type(async_iterator).__anext__ + try: + return await anxt(async_iterator) + except StopAsyncIteration: + if default is _NOT_PROVIDED: + raise + return default + + def importlib_metadata_get(group): ep = importlib_metadata.entry_points() if not typing.TYPE_CHECKING and hasattr(ep, "select"): diff --git a/test/base/test_result.py b/test/base/test_result.py index 90938263f5..3e6444daa1 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -253,6 +253,21 @@ class ResultTest(fixtures.TestBase): return res + def test_close_attributes(self): + """test #8710""" + r1 = self._fixture() + + is_false(r1.closed) + is_false(r1._soft_closed) + + r1._soft_close() + is_false(r1.closed) + is_true(r1._soft_closed) + + r1.close() + is_true(r1.closed) + is_true(r1._soft_closed) + def test_class_presented(self): """To support different kinds of objects returned vs. rows, there are two wrapper classes for Result. 
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index cdf70ca678..2eebb433db 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture): ): await conn.exec_driver_sql("SELECT * FROM users") + @async_test + async def test_stream_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream(select(self.tables.users)) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for row in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + + @async_test + async def test_stream_scalars_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream_scalars( + select(self.tables.users) + ) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for scalar in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + @testing.combinations( (None,), ("scalars",), ("mappings",), argnames="filter_" ) @@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture): eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) @testing.combinations( - (None,), ("scalars",), ("mappings",), argnames="filter_" + (None,), + ("scalars",), + ("stream_scalars",), + ("mappings",), + argnames="filter_", ) @async_test async def test_aiter(self, async_engine, filter_): users = self.tables.users async with async_engine.connect() as conn: - result = await conn.stream(select(users)) + if filter_ == "stream_scalars": + result = await conn.stream_scalars(select(users.c.user_name)) + else: + result = await conn.stream(select(users)) if filter_ == "mappings": result = result.mappings() @@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture): for i in range(1, 20) ], ) - elif filter_ == "scalars": + elif filter_ in ("scalars", "stream_scalars"): eq_( rows, ["name%d" % i for i in range(1, 20)], diff --git a/test/ext/mypy/plain_files/engines.py b/test/ext/mypy/plain_files/engines.py new file mode 100644 index 0000000000..c920ad55dc --- /dev/null +++ b/test/ext/mypy/plain_files/engines.py @@ -0,0 +1,86 @@ +from sqlalchemy import create_engine +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine + + +def regular() -> None: + + e = create_engine("sqlite://") + + # EXPECTED_TYPE: Engine + reveal_type(e) + + with e.connect() as conn: + + # EXPECTED_TYPE: Connection + reveal_type(conn) + + result = conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + with e.begin() as conn: + + # EXPECTED_TYPE: Connection + reveal_type(conn) + + result = conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + +async def asyncio() -> None: + e = create_async_engine("sqlite://") + + # EXPECTED_TYPE: AsyncEngine + reveal_type(e) + + async with e.connect() as conn: + + # EXPECTED_TYPE: AsyncConnection + reveal_type(conn) + + result = await conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + # stream with direct await + async_result = await 
conn.stream(text("select * from table")) + + # EXPECTED_TYPE: AsyncResult[Any] + reveal_type(async_result) + + # stream with context manager + async with conn.stream( + text("select * from table") + ) as ctx_async_result: + # EXPECTED_TYPE: AsyncResult[Any] + reveal_type(ctx_async_result) + + # stream_scalars with direct await + async_scalar_result = await conn.stream_scalars( + text("select * from table") + ) + + # EXPECTED_TYPE: AsyncScalarResult[Any] + reveal_type(async_scalar_result) + + # stream_scalars with context manager + async with conn.stream_scalars( + text("select * from table") + ) as ctx_async_scalar_result: + # EXPECTED_TYPE: AsyncScalarResult[Any] + reveal_type(ctx_async_scalar_result) + + async with e.begin() as conn: + + # EXPECTED_TYPE: AsyncConnection + reveal_type(conn) + + result = await conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) diff --git a/test/orm/test_loading.py b/test/orm/test_loading.py index cc3c3f4942..d0b5c9d8f9 100644 --- a/test/orm/test_loading.py +++ b/test/orm/test_loading.py @@ -6,6 +6,7 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.orm import loading from sqlalchemy.orm import relationship +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -152,6 +153,24 @@ class InstancesTest(_fixtures.FixtureTest): def setup_mappers(cls): cls._setup_stock_mapping() + def test_cursor_close_exception_raised_in_iteration(self): + """test #8710""" + + User = self.classes.User + s = fixture_session() + + stmt = select(User).execution_options(yield_per=1) + + result = s.execute(stmt) + raw_cursor = result.raw + + for row in result: + with expect_raises_message(Exception, "whoops"): + for row in result: + raise Exception("whoops") + + is_true(raw_cursor._soft_closed) + def test_cursor_close_w_failed_rowproc(self): User = self.classes.User s = fixture_session() diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 4630090656..c05fdaf4fc 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -5417,6 +5417,58 @@ class YieldTest(_fixtures.FixtureTest): result.close() assert_raises(sa.exc.ResourceClosedError, result.all) + def test_yield_per_close_on_interrupted_iteration_legacy(self): + """test #8710""" + + self._eagerload_mappings() + + User = self.classes.User + + asserted_result = None + + class _Query(Query): + def _iter(self): + nonlocal asserted_result + asserted_result = super(_Query, self)._iter() + return asserted_result + + sess = fixture_session(query_cls=_Query) + + with expect_raises_message(Exception, "hi"): + for i, row in enumerate(sess.query(User).yield_per(1)): + assert not asserted_result._soft_closed + assert not asserted_result.closed + + if i > 1: + raise Exception("hi") + + assert asserted_result._soft_closed + assert not asserted_result.closed + + def test_yield_per_close_on_interrupted_iteration(self): + """test #8710""" + + self._eagerload_mappings() + + User = self.classes.User + + sess = fixture_session() + + with expect_raises_message(Exception, "hi"): + result = sess.execute(select(User).execution_options(yield_per=1)) + for i, row in enumerate(result): + assert not result._soft_closed + assert not result.closed + + if i > 1: + raise Exception("hi") + + assert not result._soft_closed + assert not result.closed + result.close() + assert result._soft_closed + assert result.closed + def 
test_yield_per_and_execution_options_legacy(self): self._eagerload_mappings() diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 4f776e3003..fa86d75ee8 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -53,6 +53,7 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import le_ from sqlalchemy.testing import mock @@ -2033,6 +2034,89 @@ class CursorResultTest(fixtures.TablesTest): partition = next(result.partitions()) eq_(len(partition), value) + @testing.fixture + def autoclose_row_fixture(self, connection): + users = self.tables.users + connection.execute( + users.insert(), + [ + {"user_id": 1, "name": "u1"}, + {"user_id": 2, "name": "u2"}, + {"user_id": 3, "name": "u3"}, + {"user_id": 4, "name": "u4"}, + {"user_id": 5, "name": "u5"}, + ], + ) + + @testing.fixture(params=["plain", "scalars", "mapping"]) + def result_fixture(self, request, connection): + users = self.tables.users + + result_type = request.param + + if result_type == "plain": + result = connection.execute(select(users)) + elif result_type == "scalars": + result = connection.scalars(select(users)) + elif result_type == "mapping": + result = connection.execute(select(users)).mappings() + else: + assert False + + return result + + def test_results_can_close(self, autoclose_row_fixture, result_fixture): + """test #8710""" + + r1 = result_fixture + + is_false(r1.closed) + is_false(r1._soft_closed) + + r1._soft_close() + is_false(r1.closed) + is_true(r1._soft_closed) + + r1.close() + is_true(r1.closed) + is_true(r1._soft_closed) + + def test_autoclose_rows_exhausted_plain( + self, connection, autoclose_row_fixture, result_fixture + ): + result = result_fixture + + assert not result._soft_closed + assert not result.closed + + read_iterator = list(result) + eq_(len(read_iterator), 5) + + assert result._soft_closed + assert not result.closed + + result.close() + assert result.closed + + def test_result_ctxmanager( + self, connection, autoclose_row_fixture, result_fixture + ): + """test #8710""" + + result = result_fixture + + with expect_raises_message(Exception, "hi"): + with result: + assert not result._soft_closed + assert not result.closed + + for i, obj in enumerate(result): + if i > 2: + raise Exception("hi") + + assert result._soft_closed + assert result.closed + class KeyTargetingTest(fixtures.TablesTest): run_inserts = "once" @@ -3113,6 +3197,47 @@ class AlternateCursorResultTest(fixtures.TablesTest): # buffer of 98, plus buffer of 99 - 89, 10 rows eq_(len(result.cursor_strategy._rowbuffer), 10) + for i, row in enumerate(result): + if i == 206: + break + + eq_(i, 206) + + def test_iterator_remains_unbroken(self, connection): + """test related to #8710. + + demonstrate that we can't close the cursor by catching + GeneratorExit inside of our iteration. Leaving the iterable + block using break, then picking up again, would be directly + impacted by this. So this provides a clear rationale for + providing context manager support for result objects. + + """ + table = self.tables.test + + connection.execute( + table.insert(), + [{"x": i, "y": "t_%d" % i} for i in range(15, 250)], + ) + + result = connection.execute(table.select()) + result = result.yield_per(100) + for i, row in enumerate(result): + if i == 188: + # this will raise GeneratorExit inside the iterator. 
+ # so we can't close the DBAPI cursor here, we have plenty + # more rows to yield + break + + eq_(i, 188) + + # demonstrate getting more rows + for i, row in enumerate(result, 188): + if i == 206: + break + + eq_(i, 206) + @testing.combinations(True, False, argnames="close_on_init") @testing.combinations( "fetchone", "fetchmany", "fetchall", argnames="fetch_style"
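To summarize the calling patterns that the tests above exercise, the :meth:`.AsyncConnection.stream` method added in this commit supports both styles; a sketch, where ``conn`` is assumed to be an :class:`.AsyncConnection` and ``stmt`` an executable statement::

    # awaitable style, as before - the caller owns cleanup
    result = await conn.stream(stmt)
    async for row in result:
        ...
    await result.close()

    # context manager style, new in 2.0.0b3 - the result is closed
    # even if iteration is interrupted by a raised exception
    async with conn.stream(stmt) as result:
        async for row in result:
            ...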