]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support result.close() for all iterator patterns
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 25 Oct 2022 13:10:09 +0000 (09:10 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Nov 2022 22:42:52 +0000 (18:42 -0400)
This change contains new features for 2.0 only as well as some
behaviors that will be backported to 1.4.

For 1.4 and 2.0:

Fixed issue where the underlying DBAPI cursor would not be closed when
using :class:`_orm.Query` with :meth:`_orm.Query.yield_per` and direct
iteration, if a user-defined exception case were raised within the
iteration process, interrupting the iterator. This would lead to the usual
MySQL-related issues with server side cursors out of sync.

For 1.4 only:

A similar scenario can occur when using :term:`2.x` executions with direct
use of :class:`.Result`, in that case the end-user code has access to the
:class:`.Result` itself and should call :meth:`.Result.close` directly.
Version 2.0 will feature context-manager calling patterns to address this
use case.  However within the 1.4 scope, ensured that ``.close()`` methods
are available on all :class:`.Result` implementations including
:class:`.ScalarResult`, :class:`.MappingResult`.

For 2.0 only:

To better support the use case of iterating :class:`.Result` and
:class:`.AsyncResult` objects where user-defined exceptions may interrupt
the iteration, both objects as well as variants such as
:class:`.ScalarResult`, :class:`.MappingResult`,
:class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support
context manager usage, where the result will be closed at the end of
iteration.

Corrected various typing issues within the engine and async engine
packages.

Fixes: #8710
Change-Id: I3166328bfd3900957eb33cbf1061d0495c9df670

18 files changed:
doc/build/changelog/unreleased_14/8710.rst [new file with mode: 0644]
doc/build/changelog/unreleased_20/8710.rst [new file with mode: 0644]
doc/build/changelog/whatsnew_20.rst
doc/build/core/connections.rst
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/ext/asyncio/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/result.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
test/base/test_result.py
test/ext/asyncio/test_engine_py3k.py
test/ext/mypy/plain_files/engines.py [new file with mode: 0644]
test/orm/test_loading.py
test/orm/test_query.py
test/sql/test_resultset.py

diff --git a/doc/build/changelog/unreleased_14/8710.rst b/doc/build/changelog/unreleased_14/8710.rst
new file mode 100644 (file)
index 0000000..c687d4f
--- /dev/null
@@ -0,0 +1,31 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8710
+
+    Fixed issue where the underlying DBAPI cursor would not be closed when
+    using :class:`_orm.Query` and direct iteration, if a user-defined exception
+    case were raised within the iteration process, interrupting the iterator
+    which otherwise is not possible to re-use in this context. When using
+    :meth:`_orm.Query.yield_per` to create server-side cursors, this would lead
+    to the usual MySQL-related issues with server side cursors out of sync.
+
+    To resolve, a catch for ``GeneratorExit`` is applied within the default
+    iterator, which applies only in those cases where the interpreter is
+    calling ``.close()`` on the iterator in any case.
+
+    A similar scenario can occur when using :term:`2.x` executions with direct
+    use of :class:`.Result`, in that case the end-user code has access to the
+    :class:`.Result` itself and should call :meth:`.Result.close` directly.
+    Version 2.0 will feature context-manager calling patterns to address this
+    use case.  However within the 1.4 scope, ensured that ``.close()`` methods
+    are available on all :class:`.Result` implementations including
+    :class:`.ScalarResult`, :class:`.MappingResult`.
+
+
+.. change::
+    :tags: bug, engine
+    :tickets: 8710
+
+    Ensured all :class:`.Result` objects include a :meth:`.Result.close` method
+    as well as a :attr:`.Result.closed` attribute, including on
+    :class:`.ScalarResult` and :class:`.MappingResult`.
diff --git a/doc/build/changelog/unreleased_20/8710.rst b/doc/build/changelog/unreleased_20/8710.rst
new file mode 100644 (file)
index 0000000..0b1f166
--- /dev/null
@@ -0,0 +1,28 @@
+.. change::
+    :tags: feature, engine
+    :tickets: 8710
+
+    To better support the use case of iterating :class:`.Result` and
+    :class:`.AsyncResult` objects where user-defined exceptions may interrupt
+    the iteration, both objects as well as variants such as
+    :class:`.ScalarResult`, :class:`.MappingResult`,
+    :class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support
+    context manager usage, where the result will be closed at the end of
+    the context manager block.
+
+    In addition, ensured that all the above
+    mentioned :class:`.Result` objects include a :meth:`.Result.close` method
+    as well as :attr:`.Result.closed` accessors, including
+    :class:`.ScalarResult` and :class:`.MappingResult` which previously did
+    not have a ``.close()`` method.
+
+    .. seealso::
+
+        :ref:`change_8710`
+
+
+.. change::
+    :tags: bug, typing
+
+    Corrected various typing issues within the engine and async engine
+    packages.
index 3d4eca6b2b5a25e9b67353ce0b8699f7a84d60d2..a2def39f1856e33f4e75a1afbd77f0a663d7478d 100644 (file)
@@ -1539,6 +1539,37 @@ backend::
 
 :ticket:`7631`
 
+
+.. _change_8710:
+
+Context Manager Support for ``Result``, ``AsyncResult``
+-------------------------------------------------------
+
+The :class:`.Result` object now supports context manager use, which will
+ensure the object and its underlying cursor is closed at the end of the block.
+This is useful in particular with server side cursors, where it's important that
+the open cursor object is closed at the end of an operation, even if user-defined
+exceptions have occurred::
+
+    with engine.connect() as conn:
+        with conn.execution_options(yield_per=100).execute(
+            text("select * from table")
+        ) as result:
+            for row in result:
+                print(f"{row}")
+
+With asyncio use, the :class:`.AsyncResult` and :class:`.AsyncConnection` have
+been altered to provide for optional async context manager use, as in::
+
+    async with async_engine.connect() as conn:
+        async with conn.execution_options(yield_per=100).execute(
+            text("select * from table")
+        ) as result:
+            for row in result:
+                print(f"{row}")
+
+:ticket:`8710`
+
 Behavioral Changes
 ------------------
 
index 7b96c1df9a9a86fee63e4e5de0d42d5688f5974f..a39c452f61b6fd56ece9e89894de23ae99251eb9 100644 (file)
@@ -16,7 +16,7 @@ higher level management services, the :class:`_engine.Engine` and
 :class:`_engine.Connection` are king (and queen?) - read on.
 
 Basic Usage
-===========
+-----------
 
 Recall from :doc:`/core/engines` that an :class:`_engine.Engine` is created via
 the :func:`_sa.create_engine` call::
@@ -82,7 +82,7 @@ in the :ref:`unified_tutorial` for a tutorial.
 
 
 Using Transactions
-==================
+------------------
 
 .. note::
 
@@ -94,7 +94,7 @@ Using Transactions
   information.
 
 Commit As You Go
-----------------
+~~~~~~~~~~~~~~~~
 
 The :class:`~sqlalchemy.engine.Connection` object always emits SQL statements
 within the context of a transaction block.   The first time the
@@ -156,7 +156,7 @@ emitted, a new transaction begins implicitly::
    mode when using a "future" style engine.
 
 Begin Once
-----------------
+~~~~~~~~~~
 
 The :class:`_engine.Connection` object provides a more explicit transaction
 management style referred towards as **begin once**. In contrast to "commit as
@@ -184,7 +184,7 @@ once" block::
         # transaction is committed
 
 Connect and Begin Once from the Engine
----------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 A convenient shorthand form for the above "begin once" block is to use
 the :meth:`_engine.Engine.begin` method at the level of the originating
@@ -226,7 +226,7 @@ returned by the :meth:`_engine.Connection.begin` method::
         further commands.
 
 Mixing Styles
--------------
+~~~~~~~~~~~~~
 
 The "commit as you go" and "begin once" styles can be freely mixed within
 a single :meth:`_engine.Engine.connect` block, provided that the call to
@@ -262,7 +262,7 @@ When developing code that uses "begin once", the library will raise
 .. _dbapi_autocommit:
 
 Setting Transaction Isolation Levels including DBAPI Autocommit
-=================================================================
+---------------------------------------------------------------
 
 Most DBAPIs support the concept of configurable transaction :term:`isolation` levels.
 These are traditionally the four levels "READ UNCOMMITTED", "READ COMMITTED",
@@ -294,7 +294,7 @@ SQLAlchemy dialects should support these isolation levels as well as autocommit
 to as great a degree as possible.
 
 Setting Isolation Level or DBAPI Autocommit for a Connection
-------------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 For an individual :class:`_engine.Connection` object that's acquired from
 :meth:`.Engine.connect`, the isolation level can be set for the duration of
@@ -341,7 +341,7 @@ begin a transaction::
    on a per-transaction basis.
 
 Setting Isolation Level or DBAPI Autocommit for an Engine
-----------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The :paramref:`_engine.Connection.execution_options.isolation_level` option may
 also be set engine wide, as is often preferable.  This may be
@@ -361,7 +361,7 @@ subsequent operations.
 .. _dbapi_autocommit_multiple:
 
 Maintaining Multiple Isolation Levels for a Single Engine
-----------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The isolation level may also be set per engine, with a potentially greater
 level of flexibility, using either the
@@ -429,7 +429,7 @@ reverted when a connection is returned to the connection pool.
 .. _dbapi_autocommit_understanding:
 
 Understanding the DBAPI-Level Autocommit Isolation Level
----------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 In the parent section, we introduced the concept of the
 :paramref:`_engine.Connection.execution_options.isolation_level`
@@ -588,7 +588,7 @@ To sum up:
 .. _engine_stream_results:
 
 Using Server Side Cursors (a.k.a. stream results)
-==================================================
+-------------------------------------------------
 
 Some backends feature explicit support for the concept of "server
 side cursors" versus "client side cursors".  A client side cursor here
@@ -644,7 +644,7 @@ or per-statement basis.    Similar options exist when using an ORM
 
 
 Streaming with a fixed buffer via yield_per
---------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 As individual row-fetch operations with fully unbuffered server side cursors
 are typically more expensive than fetching batches of rows at once, The
@@ -683,12 +683,13 @@ combination has includes:
 These three behaviors are illustrated in the example below::
 
     with engine.connect() as conn:
-        result = conn.execution_options(yield_per=100).execute(text("select * from table"))
-
-        for partition in result.partitions():
-            # partition is an iterable that will be at most 100 items
-            for row in partition:
-                print(f"{row}")
+        with conn.execution_options(yield_per=100).execute(
+            text("select * from table")
+        ) as result:
+            for partition in result.partitions():
+                # partition is an iterable that will be at most 100 items
+                for row in partition:
+                    print(f"{row}")
 
 The above example illustrates the combination of ``yield_per=100`` along
 with using the :meth:`_engine.Result.partitions` method to run processing
@@ -699,6 +700,13 @@ buffered for each 100 rows fetched.    Calling a method such as
 :meth:`_engine.Result.all` should **not** be used, as this will fully
 fetch all remaining rows at once and defeat the purpose of using ``yield_per``.
 
+.. tip::
+
+    The :class:`.Result` object may be used as a context manager as illustrated
+    above.  When iterating with a server-side cursor, this is the best way to
+    ensure the :class:`.Result` object is closed, even if exceptions are
+    raised within the iteration process.
+
 The :paramref:`_engine.Connection.execution_options.yield_per` option
 is portable to the ORM as well, used by a :class:`_orm.Session` to fetch
 ORM objects, where it also limits the amount of ORM objects generated at once.
@@ -715,7 +723,7 @@ for further background on using
 .. _engine_stream_results_sr:
 
 Streaming with a dynamically growing buffer using stream_results
------------------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 To enable server side cursors without a specific partition size, the
 :paramref:`_engine.Connection.execution_options.stream_results` option may be
@@ -731,11 +739,12 @@ of 1000 rows.   The maximum size of this buffer can be affected using the
 :paramref:`_engine.Connection.execution_options.max_row_buffer` execution option::
 
     with engine.connect() as conn:
-        conn = conn.execution_options(stream_results=True, max_row_buffer=100)
-        result = conn.execute(text("select * from table"))
+        with conn.execution_options(stream_results=True, max_row_buffer=100).execute(
+            text("select * from table")
+        ) as result:
 
-        for row in result:
-            print(f"{row}")
+            for row in result:
+                print(f"{row}")
 
 While the :paramref:`_engine.Connection.execution_options.stream_results`
 option may be combined with use of the :meth:`_engine.Result.partitions`
@@ -757,7 +766,7 @@ up to use the :meth:`_engine.Result.partitions` method.
 .. _schema_translating:
 
 Translation of Schema Names
-===========================
+---------------------------
 
 To support multi-tenancy applications that distribute common sets of tables
 into multiple schemas, the
@@ -853,7 +862,7 @@ as the schema name is passed to these methods explicitly.
 
 
 SQL Compilation Caching
-=======================
+-----------------------
 
 .. versionadded:: 1.4  SQLAlchemy now has a transparent query caching system
    that substantially lowers the Python computational overhead involved in
@@ -903,7 +912,7 @@ detail the configuration and advanced usage patterns for the cache.
 
 
 Configuration
--------------
+~~~~~~~~~~~~~
 
 The cache itself is a dictionary-like object called an ``LRUCache``, which is
 an internal SQLAlchemy dictionary subclass that tracks the usage of particular
@@ -930,7 +939,7 @@ cache's behavior, described in the next section.
 .. _sql_caching_logging:
 
 Estimating Cache Performance Using Logging
-------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The above cache size of 1200 is actually fairly large.   For small applications,
 a size of 100 is likely sufficient.  To estimate the optimal size of the cache,
@@ -1148,7 +1157,7 @@ obviously an extremely small size, and the default size of 500 is fine to be lef
 at its default.
 
 How much memory does the cache use?
------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The previous section detailed some techniques to check if the
 :paramref:`_sa.create_engine.query_cache_size` needs to be bigger.   How do we know
@@ -1170,7 +1179,7 @@ moderate Core statement takes up about 12K while a small ORM statement takes abo
 .. _engine_compiled_cache:
 
 Disabling or using an alternate dictionary to cache some (or all) statements
------------------------------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The internal cache used is known as ``LRUCache``, but this is mostly just
 a dictionary.  Any dictionary may be used as a cache for any series of
@@ -1200,7 +1209,7 @@ The cache can also be disabled with this argument by sending a value of
 .. _engine_thirdparty_caching:
 
 Caching for Third Party Dialects
----------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The caching feature requires that the dialect's compiler produces SQL
 strings that are safe to reuse for many statement invocations, given
@@ -1342,7 +1351,7 @@ SELECTs with LIMIT/OFFSET are correctly rendered and cached.
 .. _engine_lambda_caching:
 
 Using Lambdas to add significant speed gains to statement production
---------------------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. deepalchemy:: This technique is generally non-essential except in very performance
    intensive scenarios, and intended for experienced Python programmers.
@@ -1764,7 +1773,7 @@ performance example.
 .. _engine_insertmanyvalues:
 
 "Insert Many Values" Behavior for INSERT statements
-====================================================
+---------------------------------------------------
 
 .. versionadded:: 2.0 see :ref:`change_6047` for background on the change
    including sample performance tests
@@ -1853,7 +1862,7 @@ as follows:
   the same usage patterns and equivalent performance benefits.
 
 Enabling/Disabling the feature
-------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 To disable the "insertmanyvalues" feature for a given backend for an
 :class:`.Engine` overall, pass the
@@ -1886,7 +1895,7 @@ such a table may need to include ``implicit_returning=False`` (see
 .. _engine_insertmanyvalues_page_size:
 
 Controlling the Batch Size
----------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 A key characteristic of "insertmanyvalues" is that the size of the INSERT
 statement is limited on a fixed max number of "values" clauses as well as a
@@ -1954,7 +1963,7 @@ Or configured on the statement itself::
 .. _engine_insertmanyvalues_events:
 
 Logging and Events
--------------------
+~~~~~~~~~~~~~~~~~~
 
 The "insertmanyvalues" feature integrates fully with SQLAlchemy's statement
 logging as well as cursor events such as :meth:`.ConnectionEvents.before_cursor_execute`.
@@ -1979,7 +1988,7 @@ an excerpt of this logging:
   [insertmanyvalues batch 10 of 10] ('d900', 900, 9000, 'd901', ...
 
 Upsert Support
---------------
+~~~~~~~~~~~~~~
 
 The PostgreSQL, SQLite, and MariaDB dialects offer backend-specific
 "upsert" constructs :func:`_postgresql.insert`, :func:`_sqlite.insert`
@@ -1993,7 +2002,7 @@ with RETURNING to take place.
 .. _engine_disposal:
 
 Engine Disposal
-===============
+---------------
 
 The :class:`_engine.Engine` refers to a connection pool, which means under normal
 circumstances, there are open database connections present while the
@@ -2070,7 +2079,7 @@ for guidelines on how to disable pooling.
 .. _dbapi_connections:
 
 Working with Driver SQL and Raw DBAPI Connections
-=================================================
+-------------------------------------------------
 
 The introduction on using :meth:`_engine.Connection.execute` made use of the
 :func:`_expression.text` construct in order to illustrate how textual SQL statements
@@ -2082,7 +2091,7 @@ SQL in that it normalizes how bound parameters are passed, as well as that
 it supports datatyping behavior for parameters and result set rows.
 
 Invoking SQL strings directly to the driver
---------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 For the use case where one wants to invoke textual SQL directly passed to the
 underlying driver (known as the :term:`DBAPI`) without any intervention
@@ -2097,7 +2106,7 @@ method may be used::
 .. _dbapi_connections_cursor:
 
 Working with the DBAPI cursor directly
---------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 There are some cases where SQLAlchemy does not provide a genericized way
 at accessing some :term:`DBAPI` functions, such as calling stored procedures as well
@@ -2153,7 +2162,7 @@ Some recipes for DBAPI connection use follow.
 .. _stored_procedures:
 
 Calling Stored Procedures and User Defined Functions
-------------------------------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 SQLAlchemy supports calling stored procedures and user defined functions
 several ways. Please note that all DBAPIs have different practices, so you must
@@ -2198,7 +2207,7 @@ situations to determine the correct syntax and patterns to use.
 
 
 Multiple Result Sets
---------------------
+~~~~~~~~~~~~~~~~~~~~
 
 Multiple result set support is available from a raw DBAPI cursor using the
 `nextset <https://legacy.python.org/dev/peps/pep-0249/#nextset>`_ method::
@@ -2215,7 +2224,7 @@ Multiple result set support is available from a raw DBAPI cursor using the
         connection.close()
 
 Registering New Dialects
-========================
+------------------------
 
 The :func:`_sa.create_engine` function call locates the given dialect
 using setuptools entrypoints.   These entry points can be established
@@ -2250,7 +2259,7 @@ The above entrypoint would then be accessed as ``create_engine("mysql+foodialect
 
 
 Registering Dialects In-Process
--------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 SQLAlchemy also allows a dialect to be registered within the current process, bypassing
 the need for separate installation.   Use the ``register()`` function as follows::
@@ -2265,7 +2274,7 @@ The above will respond to ``create_engine("mysql+foodialect://")`` and load the
 
 
 Connection / Engine API
-=======================
+-----------------------
 
 .. autoclass:: Connection
    :members:
@@ -2296,7 +2305,7 @@ Connection / Engine API
 
 
 Result Set  API
-=================
+---------------
 
 .. autoclass:: ChunkedIteratorResult
     :members:
index 05ca17063b028244b744bfaf55948a849a465c0f..bcd2f0ea9d9beddbb1247af501a0acd65cd4d97c 100644 (file)
@@ -928,6 +928,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
     def __init__(self, cursor_metadata: ResultMetaData):
         self._metadata = cursor_metadata
 
+    def __enter__(self) -> Result[_TP]:
+        return self
+
+    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
+        self.close()
+
     def close(self) -> None:
         """close this :class:`_result.Result`.
 
@@ -950,6 +956,19 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
         """
         self._soft_close(hard=True)
 
+    @property
+    def _soft_closed(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    def closed(self) -> bool:
+        """return True if this :class:`.Result` reports .closed
+
+        .. versionadded:: 1.4.43
+
+        """
+        raise NotImplementedError()
+
     @_generative
     def yield_per(self: SelfResult, num: int) -> SelfResult:
         """Configure the row-fetching strategy to fetch ``num`` rows at a time.
@@ -1574,6 +1593,12 @@ class FilterResult(ResultInternal[_R]):
 
     _real_result: Result[Any]
 
+    def __enter__(self: SelfFilterResult) -> SelfFilterResult:
+        return self
+
+    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
+        self._real_result.__exit__(type_, value, traceback)
+
     @_generative
     def yield_per(self: SelfFilterResult, num: int) -> SelfFilterResult:
         """Configure the row-fetching strategy to fetch ``num`` rows at a time.
@@ -1599,6 +1624,27 @@ class FilterResult(ResultInternal[_R]):
     def _soft_close(self, hard: bool = False) -> None:
         self._real_result._soft_close(hard=hard)
 
+    @property
+    def _soft_closed(self) -> bool:
+        return self._real_result._soft_closed
+
+    @property
+    def closed(self) -> bool:
+        """return True if the underlying :class:`.Result` reports .closed
+
+        .. versionadded:: 1.4.43
+
+        """
+        return self._real_result.closed
+
+    def close(self) -> None:
+        """Close this :class:`.FilterResult`.
+
+        .. versionadded:: 1.4.43
+
+        """
+        self._real_result.close()
+
     @property
     def _attributes(self) -> Dict[Any, Any]:
         return self._real_result._attributes
@@ -2172,7 +2218,7 @@ class IteratorResult(Result[_TP]):
         self,
         cursor_metadata: ResultMetaData,
         iterator: Iterator[_InterimSupportsScalarsRowType],
-        raw: Optional[Any] = None,
+        raw: Optional[Result[Any]] = None,
         _source_supports_scalars: bool = False,
     ):
         self._metadata = cursor_metadata
@@ -2180,6 +2226,15 @@ class IteratorResult(Result[_TP]):
         self.raw = raw
         self._source_supports_scalars = _source_supports_scalars
 
+    @property
+    def closed(self) -> bool:
+        """return True if this :class:`.IteratorResult` has been closed
+
+        .. versionadded:: 1.4.43
+
+        """
+        return self._hard_closed
+
     def _soft_close(self, hard: bool = False, **kw: Any) -> None:
         if hard:
             self._hard_closed = True
@@ -2262,7 +2317,7 @@ class ChunkedIteratorResult(IteratorResult[_TP]):
             [Optional[int]], Iterator[Sequence[_InterimRowType[_R]]]
         ],
         source_supports_scalars: bool = False,
-        raw: Optional[Any] = None,
+        raw: Optional[Result[Any]] = None,
         dynamic_yield_per: bool = False,
     ):
         self._metadata = cursor_metadata
index 7fdd2d7e064314f1b3b5b55b2304ea28cc1c8d3a..13d5e40b24d3cf8283e4d692890ef867a748b98f 100644 (file)
@@ -10,8 +10,13 @@ from __future__ import annotations
 import abc
 import functools
 from typing import Any
+from typing import AsyncGenerator
+from typing import AsyncIterator
+from typing import Awaitable
+from typing import Callable
 from typing import ClassVar
 from typing import Dict
+from typing import Generator
 from typing import Generic
 from typing import NoReturn
 from typing import Optional
@@ -25,6 +30,7 @@ from ... import util
 from ...util.typing import Literal
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 
 _PT = TypeVar("_PT", bound=Any)
@@ -114,27 +120,29 @@ class ReversibleProxy(Generic[_PT]):
 
 
 SelfStartableContext = TypeVar(
-    "SelfStartableContext", bound="StartableContext"
+    "SelfStartableContext", bound="StartableContext[Any]"
 )
 
 
-class StartableContext(abc.ABC):
+class StartableContext(Awaitable[_T_co], abc.ABC):
     __slots__ = ()
 
     @abc.abstractmethod
     async def start(
         self: SelfStartableContext, is_ctxmanager: bool = False
-    ) -> Any:
+    ) -> _T_co:
         raise NotImplementedError()
 
-    def __await__(self) -> Any:
+    def __await__(self) -> Generator[Any, Any, _T_co]:
         return self.start().__await__()
 
-    async def __aenter__(self: SelfStartableContext) -> Any:
-        return await self.start(is_ctxmanager=True)
+    async def __aenter__(self: SelfStartableContext) -> _T_co:
+        return await self.start(is_ctxmanager=True)  # type: ignore
 
     @abc.abstractmethod
-    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
+    async def __aexit__(
+        self, type_: Any, value: Any, traceback: Any
+    ) -> Optional[bool]:
         pass
 
     def _raise_for_not_started(self) -> NoReturn:
@@ -144,6 +152,129 @@ class StartableContext(abc.ABC):
         )
 
 
+class GeneratorStartableContext(StartableContext[_T_co]):
+    __slots__ = ("gen",)
+
+    gen: AsyncGenerator[_T_co, Any]
+
+    def __init__(
+        self,
+        func: Callable[..., AsyncIterator[_T_co]],
+        args: tuple[Any, ...],
+        kwds: dict[str, Any],
+    ):
+        self.gen = func(*args, **kwds)  # type: ignore
+
+    async def start(self, is_ctxmanager: bool = False) -> _T_co:
+        try:
+            start_value = await util.anext_(self.gen)
+        except StopAsyncIteration:
+            raise RuntimeError("generator didn't yield") from None
+
+        # if not a context manager, then interrupt the generator, don't
+        # let it complete.   this step is technically not needed, as the
+        # generator will close in any case at gc time.  not clear if having
+        # this here is a good idea or not (though it helps for clarity IMO)
+        if not is_ctxmanager:
+            await self.gen.aclose()
+
+        return start_value
+
+    async def __aexit__(
+        self, typ: Any, value: Any, traceback: Any
+    ) -> Optional[bool]:
+        # vendored from contextlib.py
+        if typ is None:
+            try:
+                await util.anext_(self.gen)
+            except StopAsyncIteration:
+                return False
+            else:
+                raise RuntimeError("generator didn't stop")
+        else:
+            if value is None:
+                # Need to force instantiation so we can reliably
+                # tell if we get the same exception back
+                value = typ()
+            try:
+                await self.gen.athrow(typ, value, traceback)
+            except StopAsyncIteration as exc:
+                # Suppress StopIteration *unless* it's the same exception that
+                # was passed to throw().  This prevents a StopIteration
+                # raised inside the "with" statement from being suppressed.
+                return exc is not value
+            except RuntimeError as exc:
+                # Don't re-raise the passed in exception. (issue27122)
+                if exc is value:
+                    return False
+                # Avoid suppressing if a Stop(Async)Iteration exception
+                # was passed to athrow() and later wrapped into a RuntimeError
+                # (see PEP 479 for sync generators; async generators also
+                # have this behavior). But do this only if the exception
+                # wrapped
+                # by the RuntimeError is actully Stop(Async)Iteration (see
+                # issue29692).
+                if (
+                    isinstance(value, (StopIteration, StopAsyncIteration))
+                    and exc.__cause__ is value
+                ):
+                    return False
+                raise
+            except BaseException as exc:
+                # only re-raise if it's *not* the exception that was
+                # passed to throw(), because __exit__() must not raise
+                # an exception unless __exit__() itself failed.  But throw()
+                # has to raise the exception to signal propagation, so this
+                # fixes the impedance mismatch between the throw() protocol
+                # and the __exit__() protocol.
+                if exc is not value:
+                    raise
+                return False
+            raise RuntimeError("generator didn't stop after athrow()")
+
+
+def asyncstartablecontext(
+    func: Callable[..., AsyncIterator[_T_co]]
+) -> Callable[..., GeneratorStartableContext[_T_co]]:
+    """@asyncstartablecontext decorator.
+
+    the decorated function can be called either as ``async with fn()``, **or**
+    ``await fn()``.   This is decidedly different from what
+    ``@contextlib.asynccontextmanager`` supports, and the usage pattern
+    is different as well.
+
+    Typical usage::
+
+        @asyncstartablecontext
+        async def some_async_generator(<arguments>):
+            <setup>
+            try:
+                yield <value>
+            except GeneratorExit:
+                # return value was awaited, no context manager is present
+                # and caller will .close() the resource explicitly
+                pass
+            else:
+                <context manager cleanup>
+
+
+    Above, ``GeneratorExit`` is caught if the function were used as an
+    ``await``.  In this case, it's essential that the cleanup does **not**
+    occur, so there should not be a ``finally`` block.
+
+    If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__``
+    and we were invoked as a context manager, and cleanup should proceed.
+
+
+    """
+
+    @functools.wraps(func)
+    def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]:
+        return GeneratorStartableContext(func, args, kwds)
+
+    return helper
+
+
 class ProxyComparable(ReversibleProxy[_PT]):
     __slots__ = ()
 
index 3890888754aeef13b2429f2549986d14e26bbeef..854039c9827e3dfde90a0d758c8e0277aaa9139c 100644 (file)
@@ -7,7 +7,9 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 from typing import Any
+from typing import AsyncIterator
 from typing import Callable
 from typing import Dict
 from typing import Generator
@@ -21,6 +23,8 @@ from typing import TypeVar
 from typing import Union
 
 from . import exc as async_exc
+from .base import asyncstartablecontext
+from .base import GeneratorStartableContext
 from .base import ProxyComparable
 from .base import StartableContext
 from .result import _ensure_sync_result
@@ -133,7 +137,9 @@ class AsyncConnectable:
     ],
 )
 class AsyncConnection(
-    ProxyComparable[Connection], StartableContext, AsyncConnectable
+    ProxyComparable[Connection],
+    StartableContext["AsyncConnection"],
+    AsyncConnectable,
 ):
     """An asyncio proxy for a :class:`_engine.Connection`.
 
@@ -446,34 +452,67 @@ class AsyncConnection(
         return await _ensure_sync_result(result, self.exec_driver_sql)
 
     @overload
-    async def stream(
+    def stream(
         self,
         statement: TypedReturnsRows[_T],
         parameters: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: Optional[CoreExecuteOptionsParameter] = None,
-    ) -> AsyncResult[_T]:
+    ) -> GeneratorStartableContext[AsyncResult[_T]]:
         ...
 
     @overload
-    async def stream(
+    def stream(
         self,
         statement: Executable,
         parameters: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: Optional[CoreExecuteOptionsParameter] = None,
-    ) -> AsyncResult[Any]:
+    ) -> GeneratorStartableContext[AsyncResult[Any]]:
         ...
 
+    @asyncstartablecontext
     async def stream(
         self,
         statement: Executable,
         parameters: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: Optional[CoreExecuteOptionsParameter] = None,
-    ) -> AsyncResult[Any]:
-        """Execute a statement and return a streaming
-        :class:`_asyncio.AsyncResult` object."""
+    ) -> AsyncIterator[AsyncResult[Any]]:
+        """Execute a statement and return an awaitable yielding a
+        :class:`_asyncio.AsyncResult` object.
+
+        E.g.::
+
+            result = await conn.stream(stmt):
+            async for row in result:
+                print(f"{row}")
+
+        The :meth:`.AsyncConnection.stream`
+        method supports optional context manager use against the
+        :class:`.AsyncResult` object, as in::
+
+            async with conn.stream(stmt) as result:
+                async for row in result:
+                    print(f"{row}")
+
+        In the above pattern, the :meth:`.AsyncResult.close` method is
+        invoked unconditionally, even if the iterator is interrupted by an
+        exception throw.   Context manager use remains optional, however,
+        and the function may be called in either an ``async with fn():`` or
+        ``await fn()`` style.
+
+        .. versionadded:: 2.0.0b3 added context manager support
+
+
+        :return: an awaitable object that will yield an
+         :class:`_asyncio.AsyncResult` object.
+
+        .. seealso::
+
+            :meth:`.AsyncConnection.stream_scalars`
+
+        """
 
         result = await greenlet_spawn(
             self._proxied.execute,
@@ -484,10 +523,15 @@ class AsyncConnection(
             ),
             _require_await=True,
         )
-        if not result.context._is_server_side:
-            # TODO: real exception here
-            assert False, "server side result expected"
-        return AsyncResult(result)
+        assert result.context._is_server_side
+        ar = AsyncResult(result)
+        try:
+            yield ar
+        except GeneratorExit:
+            pass
+        else:
+            task = asyncio.create_task(ar.close())
+            await asyncio.shield(task)
 
     @overload
     async def execute(
@@ -642,48 +686,77 @@ class AsyncConnection(
         return result.scalars()
 
     @overload
-    async def stream_scalars(
+    def stream_scalars(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
         parameters: Optional[_CoreSingleExecuteParams] = None,
         *,
         execution_options: Optional[CoreExecuteOptionsParameter] = None,
-    ) -> AsyncScalarResult[_T]:
+    ) -> GeneratorStartableContext[AsyncScalarResult[_T]]:
         ...
 
     @overload
-    async def stream_scalars(
+    def stream_scalars(
         self,
         statement: Executable,
         parameters: Optional[_CoreSingleExecuteParams] = None,
         *,
         execution_options: Optional[CoreExecuteOptionsParameter] = None,
-    ) -> AsyncScalarResult[Any]:
+    ) -> GeneratorStartableContext[AsyncScalarResult[Any]]:
         ...
 
+    @asyncstartablecontext
     async def stream_scalars(
         self,
         statement: Executable,
         parameters: Optional[_CoreSingleExecuteParams] = None,
         *,
         execution_options: Optional[CoreExecuteOptionsParameter] = None,
-    ) -> AsyncScalarResult[Any]:
-        r"""Executes a SQL statement and returns a streaming scalar result
-        object.
+    ) -> AsyncIterator[AsyncScalarResult[Any]]:
+        r"""Execute a statement and return an awaitable yielding a
+        :class:`_asyncio.AsyncScalarResult` object.
+
+        E.g.::
+
+            result = await conn.stream_scalars(stmt):
+            async for scalar in result:
+                print(f"{scalar}")
 
         This method is shorthand for invoking the
         :meth:`_engine.AsyncResult.scalars` method after invoking the
         :meth:`_engine.Connection.stream` method.  Parameters are equivalent.
 
-        :return: an :class:`_asyncio.AsyncScalarResult` object.
+        The :meth:`.AsyncConnection.stream_scalars`
+        method supports optional context manager use against the
+        :class:`.AsyncScalarResult` object, as in::
+
+            async with conn.stream_scalars(stmt) as result:
+                async for scalar in result:
+                    print(f"{scalar}")
+
+        In the above pattern, the :meth:`.AsyncScalarResult.close` method is
+        invoked unconditionally, even if the iterator is interrupted by an
+        exception throw.  Context manager use remains optional, however,
+        and the function may be called in either an ``async with fn():`` or
+        ``await fn()`` style.
+
+        .. versionadded:: 2.0.0b3 added context manager support
+
+        :return: an awaitable object that will yield an
+         :class:`_asyncio.AsyncScalarResult` object.
 
         .. versionadded:: 1.4.24
 
+        .. seealso::
+
+            :meth:`.AsyncConnection.stream`
+
         """
-        result = await self.stream(
+
+        async with self.stream(
             statement, parameters, execution_options=execution_options
-        )
-        return result.scalars()
+        ) as result:
+            yield result.scalars()
 
     async def run_sync(
         self, fn: Callable[..., Any], *arg: Any, **kw: Any
@@ -856,37 +929,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
         :ref:`asyncio_events`
     """
 
-    class _trans_ctx(StartableContext):
-        __slots__ = ("conn", "transaction")
-
-        conn: AsyncConnection
-        transaction: AsyncTransaction
-
-        def __init__(self, conn: AsyncConnection):
-            self.conn = conn
-
-        if TYPE_CHECKING:
-
-            async def __aenter__(self) -> AsyncConnection:
-                ...
-
-        async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
-            await self.conn.start(is_ctxmanager=is_ctxmanager)
-            self.transaction = self.conn.begin()
-            await self.transaction.__aenter__()
-
-            return self.conn
-
-        async def __aexit__(
-            self, type_: Any, value: Any, traceback: Any
-        ) -> None:
-            async def go() -> None:
-                await self.transaction.__aexit__(type_, value, traceback)
-                await self.conn.close()
-
-            task = asyncio.create_task(go())
-            await asyncio.shield(task)
-
     def __init__(self, sync_engine: Engine):
         if not sync_engine.dialect.is_async:
             raise exc.InvalidRequestError(
@@ -903,7 +945,8 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
     def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
         return AsyncEngine(target)
 
-    def begin(self) -> AsyncEngine._trans_ctx:
+    @contextlib.asynccontextmanager
+    async def begin(self) -> AsyncIterator[AsyncConnection]:
         """Return a context manager which when entered will deliver an
         :class:`_asyncio.AsyncConnection` with an
         :class:`_asyncio.AsyncTransaction` established.
@@ -919,7 +962,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
 
         """
         conn = self.connect()
-        return self._trans_ctx(conn)
+
+        async with conn:
+            async with conn.begin():
+                yield conn
 
     def connect(self) -> AsyncConnection:
         """Return an :class:`_asyncio.AsyncConnection` object.
@@ -1185,7 +1231,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
     # END PROXY METHODS AsyncEngine
 
 
-class AsyncTransaction(ProxyComparable[Transaction], StartableContext):
+class AsyncTransaction(
+    ProxyComparable[Transaction], StartableContext["AsyncTransaction"]
+):
     """An asyncio proxy for a :class:`_engine.Transaction`."""
 
     __slots__ = ("connection", "sync_transaction", "nested")
index 8f1c07fd8c89d39a9eea8b56ccdd979034cf1532..41ead5ee20b01e9d69f1ab922aa12a46fb0aa3b9 100644 (file)
@@ -47,11 +47,21 @@ class AsyncCommon(FilterResult[_R]):
     _real_result: Result[Any]
     _metadata: ResultMetaData
 
-    async def close(self) -> None:
+    async def close(self) -> None:  # type: ignore[override]
         """Close this result."""
 
         await greenlet_spawn(self._real_result.close)
 
+    @property
+    def closed(self) -> bool:
+        """proxies the .closed attribute of the underlying result object,
+        if any, else raises ``AttributeError``.
+
+        .. versionadded:: 2.0.0b3
+
+        """
+        return self._real_result.closed  # type: ignore
+
 
 SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]")
 
@@ -95,6 +105,16 @@ class AsyncResult(AsyncCommon[Row[_TP]]):
                 "_row_getter", real_result.__dict__["_row_getter"]
             )
 
+    @property
+    def closed(self) -> bool:
+        """proxies the .closed attribute of the underlying result object,
+        if any, else raises ``AttributeError``.
+
+        .. versionadded:: 2.0.0b3
+
+        """
+        return self._real_result.closed  # type: ignore
+
     @property
     def t(self) -> AsyncTupleResult[_TP]:
         """Apply a "typed tuple" typing filter to returned rows.
index 0aa9661e9f67bc716d8a458f8ecf2c4fef7de2c5..1956ea588ade60ff42c405b6b81adc65faa253e0 100644 (file)
@@ -1547,7 +1547,8 @@ class _AsyncSessionContextManager(Generic[_AS]):
 
 
 class AsyncSessionTransaction(
-    ReversibleProxy[SessionTransaction], StartableContext
+    ReversibleProxy[SessionTransaction],
+    StartableContext["AsyncSessionTransaction"],
 ):
     """A wrapper for the ORM :class:`_orm.SessionTransaction` object.
 
index 63035f585bdc412233cf1088dce5bc54ba71ae61..9ac6d07daeff635fb11008367dd7e150e3038060 100644 (file)
@@ -2713,7 +2713,14 @@ class Query(
             return None
 
     def __iter__(self) -> Iterable[_T]:
-        return self._iter().__iter__()  # type: ignore
+        result = self._iter()
+        try:
+            yield from result
+        except GeneratorExit:
+            # issue #8710 - direct iteration is not re-usable after
+            # an iterable block is broken, so close the result
+            result._soft_close()
+            raise
 
     def _iter(self) -> Union[ScalarResult[_T], Result[_T]]:
         # new style execution.
index cd7e0fd81a96be2463d799f702259087889ccc03..4952cb5011e96c35ab99c2b926f56030980856e5 100644 (file)
@@ -46,6 +46,7 @@ from ._collections import UniqueAppender as UniqueAppender
 from ._collections import update_copy as update_copy
 from ._collections import WeakPopulateDict as WeakPopulateDict
 from ._collections import WeakSequence as WeakSequence
+from .compat import anext_ as anext_
 from .compat import arm as arm
 from .compat import b as b
 from .compat import b64decode as b64decode
index 4ce1e7ff32c063f3c1370144297b8a65ba0c8a1a..cda5ab6c12551b85c2a7f820f5de474092168b22 100644 (file)
@@ -111,6 +111,29 @@ else:
         return a
 
 
+if py310:
+    anext_ = anext
+else:
+
+    _NOT_PROVIDED = object()
+    from collections.abc import AsyncIterator
+
+    async def anext_(async_iterator, default=_NOT_PROVIDED):
+        """vendored from https://github.com/python/cpython/pull/8895"""
+
+        if not isinstance(async_iterator, AsyncIterator):
+            raise TypeError(
+                f"anext expected an AsyncIterator, got {type(async_iterator)}"
+            )
+        anxt = type(async_iterator).__anext__
+        try:
+            return await anxt(async_iterator)
+        except StopAsyncIteration:
+            if default is _NOT_PROVIDED:
+                raise
+            return default
+
+
 def importlib_metadata_get(group):
     ep = importlib_metadata.entry_points()
     if not typing.TYPE_CHECKING and hasattr(ep, "select"):
index 90938263f517a64463566a6c711be7ccc15e51a1..3e6444daa1ffe1e840c23ae1085dc90cbe0f99f7 100644 (file)
@@ -253,6 +253,21 @@ class ResultTest(fixtures.TestBase):
 
         return res
 
+    def test_close_attributes(self):
+        """test #8710"""
+        r1 = self._fixture()
+
+        is_false(r1.closed)
+        is_false(r1._soft_closed)
+
+        r1._soft_close()
+        is_false(r1.closed)
+        is_true(r1._soft_closed)
+
+        r1.close()
+        is_true(r1.closed)
+        is_true(r1._soft_closed)
+
     def test_class_presented(self):
         """To support different kinds of objects returned vs. rows,
         there are two wrapper classes for Result.
index cdf70ca678592c5f7a0540989b9f059533c0a9e5..2eebb433dbb7af53739ba824d97a3a641420842c 100644 (file)
@@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture):
             ):
                 await conn.exec_driver_sql("SELECT * FROM users")
 
+    @async_test
+    async def test_stream_ctxmanager(self, async_engine):
+        async with async_engine.connect() as conn:
+            conn = await conn.execution_options(stream_results=True)
+
+            async with conn.stream(select(self.tables.users)) as result:
+                assert not result._real_result._soft_closed
+                assert not result.closed
+                with expect_raises_message(Exception, "hi"):
+                    i = 0
+                    async for row in result:
+                        if i > 2:
+                            raise Exception("hi")
+                        i += 1
+            assert result._real_result._soft_closed
+            assert result.closed
+
+    @async_test
+    async def test_stream_scalars_ctxmanager(self, async_engine):
+        async with async_engine.connect() as conn:
+            conn = await conn.execution_options(stream_results=True)
+
+            async with conn.stream_scalars(
+                select(self.tables.users)
+            ) as result:
+                assert not result._real_result._soft_closed
+                assert not result.closed
+                with expect_raises_message(Exception, "hi"):
+                    i = 0
+                    async for scalar in result:
+                        if i > 2:
+                            raise Exception("hi")
+                        i += 1
+            assert result._real_result._soft_closed
+            assert result.closed
+
     @testing.combinations(
         (None,), ("scalars",), ("mappings",), argnames="filter_"
     )
@@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture):
                 eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
 
     @testing.combinations(
-        (None,), ("scalars",), ("mappings",), argnames="filter_"
+        (None,),
+        ("scalars",),
+        ("stream_scalars",),
+        ("mappings",),
+        argnames="filter_",
     )
     @async_test
     async def test_aiter(self, async_engine, filter_):
         users = self.tables.users
         async with async_engine.connect() as conn:
-            result = await conn.stream(select(users))
+            if filter_ == "stream_scalars":
+                result = await conn.stream_scalars(select(users.c.user_name))
+            else:
+                result = await conn.stream(select(users))
 
             if filter_ == "mappings":
                 result = result.mappings()
@@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture):
                         for i in range(1, 20)
                     ],
                 )
-            elif filter_ == "scalars":
+            elif filter_ in ("scalars", "stream_scalars"):
                 eq_(
                     rows,
                     ["name%d" % i for i in range(1, 20)],
diff --git a/test/ext/mypy/plain_files/engines.py b/test/ext/mypy/plain_files/engines.py
new file mode 100644 (file)
index 0000000..c920ad5
--- /dev/null
@@ -0,0 +1,86 @@
+from sqlalchemy import create_engine
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import create_async_engine
+
+
+def regular() -> None:
+
+    e = create_engine("sqlite://")
+
+    # EXPECTED_TYPE: Engine
+    reveal_type(e)
+
+    with e.connect() as conn:
+
+        # EXPECTED_TYPE: Connection
+        reveal_type(conn)
+
+        result = conn.execute(text("select * from table"))
+
+        # EXPECTED_TYPE: CursorResult[Any]
+        reveal_type(result)
+
+    with e.begin() as conn:
+
+        # EXPECTED_TYPE: Connection
+        reveal_type(conn)
+
+        result = conn.execute(text("select * from table"))
+
+        # EXPECTED_TYPE: CursorResult[Any]
+        reveal_type(result)
+
+
+async def asyncio() -> None:
+    e = create_async_engine("sqlite://")
+
+    # EXPECTED_TYPE: AsyncEngine
+    reveal_type(e)
+
+    async with e.connect() as conn:
+
+        # EXPECTED_TYPE: AsyncConnection
+        reveal_type(conn)
+
+        result = await conn.execute(text("select * from table"))
+
+        # EXPECTED_TYPE: CursorResult[Any]
+        reveal_type(result)
+
+        # stream with direct await
+        async_result = await conn.stream(text("select * from table"))
+
+        # EXPECTED_TYPE: AsyncResult[Any]
+        reveal_type(async_result)
+
+        # stream with context manager
+        async with conn.stream(
+            text("select * from table")
+        ) as ctx_async_result:
+            # EXPECTED_TYPE: AsyncResult[Any]
+            reveal_type(ctx_async_result)
+
+        # stream_scalars with direct await
+        async_scalar_result = await conn.stream_scalars(
+            text("select * from table")
+        )
+
+        # EXPECTED_TYPE: AsyncScalarResult[Any]
+        reveal_type(async_scalar_result)
+
+        # stream_scalars with context manager
+        async with conn.stream_scalars(
+            text("select * from table")
+        ) as ctx_async_scalar_result:
+            # EXPECTED_TYPE: AsyncScalarResult[Any]
+            reveal_type(ctx_async_scalar_result)
+
+    async with e.begin() as conn:
+
+        # EXPECTED_TYPE: AsyncConnection
+        reveal_type(conn)
+
+        result = await conn.execute(text("select * from table"))
+
+        # EXPECTED_TYPE: CursorResult[Any]
+        reveal_type(result)
index cc3c3f49424f5e9c9356bac58448c38329b15707..d0b5c9d8f9c40c026063feb33f829eb395f08f6c 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy.orm import loading
 from sqlalchemy.orm import relationship
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
@@ -152,6 +153,24 @@ class InstancesTest(_fixtures.FixtureTest):
     def setup_mappers(cls):
         cls._setup_stock_mapping()
 
+    def test_cursor_close_exception_raised_in_iteration(self):
+        """test #8710"""
+
+        User = self.classes.User
+        s = fixture_session()
+
+        stmt = select(User).execution_options(yield_per=1)
+
+        result = s.execute(stmt)
+        raw_cursor = result.raw
+
+        for row in result:
+            with expect_raises_message(Exception, "whoops"):
+                for row in result:
+                    raise Exception("whoops")
+
+        is_true(raw_cursor._soft_closed)
+
     def test_cursor_close_w_failed_rowproc(self):
         User = self.classes.User
         s = fixture_session()
index 463009065666b2429d26cc48f4f0902ca0f80b60..c05fdaf4fc456b07c175e4e281556b55eefc8748 100644 (file)
@@ -5417,6 +5417,58 @@ class YieldTest(_fixtures.FixtureTest):
         result.close()
         assert_raises(sa.exc.ResourceClosedError, result.all)
 
+    def test_yield_per_close_on_interrupted_iteration_legacy(self):
+        """test #8710"""
+
+        self._eagerload_mappings()
+
+        User = self.classes.User
+
+        asserted_result = None
+
+        class _Query(Query):
+            def _iter(self):
+                nonlocal asserted_result
+                asserted_result = super(_Query, self)._iter()
+                return asserted_result
+
+        sess = fixture_session(query_cls=_Query)
+
+        with expect_raises_message(Exception, "hi"):
+            for i, row in enumerate(sess.query(User).yield_per(1)):
+                assert not asserted_result._soft_closed
+                assert not asserted_result.closed
+
+                if i > 1:
+                    raise Exception("hi")
+
+        assert asserted_result._soft_closed
+        assert not asserted_result.closed
+
+    def test_yield_per_close_on_interrupted_iteration(self):
+        """test #8710"""
+
+        self._eagerload_mappings()
+
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        with expect_raises_message(Exception, "hi"):
+            result = sess.execute(select(User).execution_options(yield_per=1))
+            for i, row in enumerate(result):
+                assert not result._soft_closed
+                assert not result.closed
+
+                if i > 1:
+                    raise Exception("hi")
+
+        assert not result._soft_closed
+        assert not result.closed
+        result.close()
+        assert result._soft_closed
+        assert result.closed
+
     def test_yield_per_and_execution_options_legacy(self):
         self._eagerload_mappings()
 
index 4f776e30033e4658d233f438777fbbce7552e9c3..fa86d75ee837e5b509c399da38a11ea8e59b4da7 100644 (file)
@@ -53,6 +53,7 @@ from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import le_
 from sqlalchemy.testing import mock
@@ -2033,6 +2034,89 @@ class CursorResultTest(fixtures.TablesTest):
             partition = next(result.partitions())
             eq_(len(partition), value)
 
+    @testing.fixture
+    def autoclose_row_fixture(self, connection):
+        users = self.tables.users
+        connection.execute(
+            users.insert(),
+            [
+                {"user_id": 1, "name": "u1"},
+                {"user_id": 2, "name": "u2"},
+                {"user_id": 3, "name": "u3"},
+                {"user_id": 4, "name": "u4"},
+                {"user_id": 5, "name": "u5"},
+            ],
+        )
+
+    @testing.fixture(params=["plain", "scalars", "mapping"])
+    def result_fixture(self, request, connection):
+        users = self.tables.users
+
+        result_type = request.param
+
+        if result_type == "plain":
+            result = connection.execute(select(users))
+        elif result_type == "scalars":
+            result = connection.scalars(select(users))
+        elif result_type == "mapping":
+            result = connection.execute(select(users)).mappings()
+        else:
+            assert False
+
+        return result
+
+    def test_results_can_close(self, autoclose_row_fixture, result_fixture):
+        """test #8710"""
+
+        r1 = result_fixture
+
+        is_false(r1.closed)
+        is_false(r1._soft_closed)
+
+        r1._soft_close()
+        is_false(r1.closed)
+        is_true(r1._soft_closed)
+
+        r1.close()
+        is_true(r1.closed)
+        is_true(r1._soft_closed)
+
+    def test_autoclose_rows_exhausted_plain(
+        self, connection, autoclose_row_fixture, result_fixture
+    ):
+        result = result_fixture
+
+        assert not result._soft_closed
+        assert not result.closed
+
+        read_iterator = list(result)
+        eq_(len(read_iterator), 5)
+
+        assert result._soft_closed
+        assert not result.closed
+
+        result.close()
+        assert result.closed
+
+    def test_result_ctxmanager(
+        self, connection, autoclose_row_fixture, result_fixture
+    ):
+        """test #8710"""
+
+        result = result_fixture
+
+        with expect_raises_message(Exception, "hi"):
+            with result:
+                assert not result._soft_closed
+                assert not result.closed
+
+                for i, obj in enumerate(result):
+                    if i > 2:
+                        raise Exception("hi")
+
+        assert result._soft_closed
+        assert result.closed
+
 
 class KeyTargetingTest(fixtures.TablesTest):
     run_inserts = "once"
@@ -3113,6 +3197,47 @@ class AlternateCursorResultTest(fixtures.TablesTest):
         # buffer of 98, plus buffer of 99 - 89, 10 rows
         eq_(len(result.cursor_strategy._rowbuffer), 10)
 
+        for i, row in enumerate(result):
+            if i == 206:
+                break
+
+        eq_(i, 206)
+
+    def test_iterator_remains_unbroken(self, connection):
+        """test related to #8710.
+
+        demonstrate that we can't close the cursor by catching
+        GeneratorExit inside of our iteration.  Leaving the iterable
+        block using break, then picking up again, would be directly
+        impacted by this.  So this provides a clear rationale for
+        providing context manager support for result objects.
+
+        """
+        table = self.tables.test
+
+        connection.execute(
+            table.insert(),
+            [{"x": i, "y": "t_%d" % i} for i in range(15, 250)],
+        )
+
+        result = connection.execute(table.select())
+        result = result.yield_per(100)
+        for i, row in enumerate(result):
+            if i == 188:
+                # this will raise GeneratorExit inside the iterator.
+                # so we can't close the DBAPI cursor here, we have plenty
+                # more rows to yield
+                break
+
+        eq_(i, 188)
+
+        # demonstrate getting more rows
+        for i, row in enumerate(result, 188):
+            if i == 206:
+                break
+
+        eq_(i, 206)
+
     @testing.combinations(True, False, argnames="close_on_init")
     @testing.combinations(
         "fetchone", "fetchmany", "fetchall", argnames="fetch_style"