]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
We can't promise CursorResult from session.execute()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2025 15:01:47 +0000 (11:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2025 17:40:03 +0000 (13:40 -0400)
Fixed typing bug where the :meth:`.Session.execute` method advertised that
it would return a :class:`.CursorResult` if given an insert/update/delete
statement.  This is not the general case as several flavors of ORM
insert/update do not actually yield a :class:`.CursorResult` which cannot
be differentiated at the typing overload level, so the method now yields
:class:`.Result` in all cases.  For those cases where
:class:`.CursorResult` is known to be returned and the ``.rowcount``
attribute is required, please use ``typing.cast()``.

Fixes: #12813
Change-Id: I8a7197100db312b3898c66ceddd6638e68c6bb44
(cherry picked from commit a7275f8e06575dd6edcf84f8083361961e499f92)

doc/build/changelog/unreleased_20/12813.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
test/typing/plain_files/orm/typed_queries.py

diff --git a/doc/build/changelog/unreleased_20/12813.rst b/doc/build/changelog/unreleased_20/12813.rst
new file mode 100644 (file)
index 0000000..e478372
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 12813
+
+    Fixed typing bug where the :meth:`.Session.execute` method advertised that
+    it would return a :class:`.CursorResult` if given an insert/update/delete
+    statement.  This is not the general case as several flavors of ORM
+    insert/update do not actually yield a :class:`.CursorResult` which cannot
+    be differentiated at the typing overload level, so the method now yields
+    :class:`.Result` in all cases.  For those cases where
+    :class:`.CursorResult` is known to be returned and the ``.rowcount``
+    attribute is required, please use ``typing.cast()``.
index d2a9a51b231817a6a486ea221a45857da80a072e..107dabbe5e69915d9732ed4ef29146f15b2bfdd5 100644 (file)
@@ -38,7 +38,6 @@ if TYPE_CHECKING:
     from .result import AsyncScalarResult
     from .session import AsyncSessionTransaction
     from ...engine import Connection
-    from ...engine import CursorResult
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
@@ -55,7 +54,6 @@ if TYPE_CHECKING:
     from ...orm.session import _PKIdentityArgument
     from ...orm.session import _SessionBind
     from ...sql.base import Executable
-    from ...sql.dml import UpdateBase
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
@@ -538,18 +536,6 @@ class async_scoped_session(Generic[_AS]):
         _add_event: Optional[Any] = None,
     ) -> Result[_T]: ...
 
-    @overload
-    async def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Any]: ...
-
     @overload
     async def execute(
         self,
index 68cbb59bfd62fa9c20a930990434aac46e1b2c46..9e07e19e8d16b7b0d3ebabc40fe5bb60a67e557d 100644 (file)
@@ -46,7 +46,6 @@ if TYPE_CHECKING:
     from .engine import AsyncConnection
     from .engine import AsyncEngine
     from ...engine import Connection
-    from ...engine import CursorResult
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
@@ -67,7 +66,6 @@ if TYPE_CHECKING:
     from ...orm.session import _SessionBindKey
     from ...sql._typing import _InfoType
     from ...sql.base import Executable
-    from ...sql.dml import UpdateBase
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
@@ -411,18 +409,6 @@ class AsyncSession(ReversibleProxy[Session]):
         _add_event: Optional[Any] = None,
     ) -> Result[_T]: ...
 
-    @overload
-    async def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Any]: ...
-
     @overload
     async def execute(
         self,
index 3489c15fd6f56e236e20cbf1b03c691062a4f9d1..620cd2116cd35c5685abfab2a95e1fcfa2ba81fc 100644 (file)
@@ -3205,11 +3205,14 @@ class Query(
             delete_ = delete_.with_dialect_options(**delete_args)
 
         delete_._where_criteria = self._where_criteria
-        result: CursorResult[Any] = self.session.execute(
-            delete_,
-            self._params,
-            execution_options=self._execution_options.union(
-                {"synchronize_session": synchronize_session}
+        result = cast(
+            "CursorResult[Any]",
+            self.session.execute(
+                delete_,
+                self._params,
+                execution_options=self._execution_options.union(
+                    {"synchronize_session": synchronize_session}
+                ),
             ),
         )
         bulk_del.result = result  # type: ignore
@@ -3296,11 +3299,14 @@ class Query(
             upd = upd.with_dialect_options(**update_args)
 
         upd._where_criteria = self._where_criteria
-        result: CursorResult[Any] = self.session.execute(
-            upd,
-            self._params,
-            execution_options=self._execution_options.union(
-                {"synchronize_session": synchronize_session}
+        result = cast(
+            "CursorResult[Any]",
+            self.session.execute(
+                upd,
+                self._params,
+                execution_options=self._execution_options.union(
+                    {"synchronize_session": synchronize_session}
+                ),
             ),
         )
         bulk_ud.result = result  # type: ignore
index df5a6534dce87d9c78d283cdff56803cfc6ee86c..cffbb4561f676d106b0d7c0d992c6367c4d5550b 100644 (file)
@@ -49,7 +49,6 @@ if TYPE_CHECKING:
     from .session import sessionmaker
     from .session import SessionTransaction
     from ..engine import Connection
-    from ..engine import CursorResult
     from ..engine import Engine
     from ..engine import Result
     from ..engine import Row
@@ -69,7 +68,6 @@ if TYPE_CHECKING:
     from ..sql._typing import _T7
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
-    from ..sql.dml import UpdateBase
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import ForUpdateParameter
@@ -685,18 +683,6 @@ class scoped_session(Generic[_S]):
         _add_event: Optional[Any] = None,
     ) -> Result[_T]: ...
 
-    @overload
-    def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Any]: ...
-
     @overload
     def execute(
         self,
index 6a589f3a33529e19169270e6f927ba50b339a832..f1456753b0e8841ca544ed10a0981ee36039d802 100644 (file)
@@ -102,7 +102,6 @@ if typing.TYPE_CHECKING:
     from .mapper import Mapper
     from .path_registry import PathRegistry
     from .query import RowReturningQuery
-    from ..engine import CursorResult
     from ..engine import Result
     from ..engine import Row
     from ..engine import RowMapping
@@ -127,7 +126,6 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
     from ..sql.base import ExecutableOption
-    from ..sql.dml import UpdateBase
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import ForUpdateParameter
@@ -2278,18 +2276,6 @@ class Session(_SessionClassMethods, EventTarget):
         _add_event: Optional[Any] = None,
     ) -> Result[_T]: ...
 
-    @overload
-    def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Any]: ...
-
     @overload
     def execute(
         self,
index b1226da30fcc72155d3154c833efc5936bfdbb25..1e305b7b20f9b957d6fe8dc596ec3fcfc0e75af5 100644 (file)
@@ -444,37 +444,29 @@ def t_dml_insert() -> None:
 def t_dml_bare_insert() -> None:
     s1 = insert(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Any]
+    # EXPECTED_TYPE: Result[Any]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_bare_update() -> None:
     s1 = update(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Any]
+    # EXPECTED_TYPE: Result[Any]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_update_with_values() -> None:
     s1 = update(User).values({User.id: 123, User.data: "value"})
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Any]
+    # EXPECTED_TYPE: Result[Any]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_bare_delete() -> None:
     s1 = delete(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Any]
+    # EXPECTED_TYPE: Result[Any]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_update() -> None: