]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
We can't promise CursorResult from session.execute()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2025 15:01:47 +0000 (11:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2025 17:34:49 +0000 (13:34 -0400)
Fixed typing bug where the :meth:`.Session.execute` method advertised that
it would return a :class:`.CursorResult` if given an insert/update/delete
statement.  This is not the general case as several flavors of ORM
insert/update do not actually yield a :class:`.CursorResult` which cannot
be differentiated at the typing overload level, so the method now yields
:class:`.Result` in all cases.  For those cases where
:class:`.CursorResult` is known to be returned and the ``.rowcount``
attribute is required, please use ``typing.cast()``.

Fixes: #12813
Change-Id: I8a7197100db312b3898c66ceddd6638e68c6bb44

doc/build/changelog/unreleased_20/12813.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
test/typing/plain_files/orm/typed_queries.py

diff --git a/doc/build/changelog/unreleased_20/12813.rst b/doc/build/changelog/unreleased_20/12813.rst
new file mode 100644 (file)
index 0000000..e478372
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 12813
+
+    Fixed typing bug where the :meth:`.Session.execute` method advertised that
+    it would return a :class:`.CursorResult` if given an insert/update/delete
+    statement.  This is not the general case as several flavors of ORM
+    insert/update do not actually yield a :class:`.CursorResult` which cannot
+    be differentiated at the typing overload level, so the method now yields
+    :class:`.Result` in all cases.  For those cases where
+    :class:`.CursorResult` is known to be returned and the ``.rowcount``
+    attribute is required, please use ``typing.cast()``.
index 6fbda514206186fe8bb9444e74ef8b6177d72b07..7730b7a52dd35e5cf2e729cc313b61d17f5d5ee5 100644 (file)
@@ -41,7 +41,6 @@ if TYPE_CHECKING:
     from .result import AsyncScalarResult
     from .session import AsyncSessionTransaction
     from ...engine import Connection
-    from ...engine import CursorResult
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
@@ -58,7 +57,6 @@ if TYPE_CHECKING:
     from ...orm.session import _PKIdentityArgument
     from ...orm.session import _SessionBind
     from ...sql.base import Executable
-    from ...sql.dml import UpdateBase
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
@@ -561,18 +559,6 @@ class async_scoped_session(Generic[_AS]):
         _add_event: Optional[Any] = None,
     ) -> Result[Unpack[_Ts]]: ...
 
-    @overload
-    async def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Unpack[TupleAny]]: ...
-
     @overload
     async def execute(
         self,
index 62ccb7c930f85d25a382d99fdb9fa47c7124f86f..77c7b6edae647cbe98786866e91e964c47123733 100644 (file)
@@ -49,7 +49,6 @@ if TYPE_CHECKING:
     from .engine import AsyncConnection
     from .engine import AsyncEngine
     from ...engine import Connection
-    from ...engine import CursorResult
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
@@ -70,7 +69,6 @@ if TYPE_CHECKING:
     from ...orm.session import _SessionBindKey
     from ...sql._typing import _InfoType
     from ...sql.base import Executable
-    from ...sql.dml import UpdateBase
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
@@ -414,18 +412,6 @@ class AsyncSession(ReversibleProxy[Session]):
         _add_event: Optional[Any] = None,
     ) -> Result[Unpack[_Ts]]: ...
 
-    @overload
-    async def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Unpack[TupleAny]]: ...
-
     @overload
     async def execute(
         self,
index 63065eca63267bc0b24857542b922573c4acad25..c5a9fe9ddc4d6827c7381e2f0ef784831b524934 100644 (file)
@@ -3255,11 +3255,14 @@ class Query(
         for ext in self._syntax_extensions:
             delete_._apply_syntax_extension_to_self(ext)
 
-        result: CursorResult[Any] = self.session.execute(
-            delete_,
-            self._params,
-            execution_options=self._execution_options.union(
-                {"synchronize_session": synchronize_session}
+        result = cast(
+            "CursorResult[Any]",
+            self.session.execute(
+                delete_,
+                self._params,
+                execution_options=self._execution_options.union(
+                    {"synchronize_session": synchronize_session}
+                ),
             ),
         )
         bulk_del.result = result  # type: ignore
@@ -3350,11 +3353,14 @@ class Query(
         for ext in self._syntax_extensions:
             upd._apply_syntax_extension_to_self(ext)
 
-        result: CursorResult[Any] = self.session.execute(
-            upd,
-            self._params,
-            execution_options=self._execution_options.union(
-                {"synchronize_session": synchronize_session}
+        result = cast(
+            "CursorResult[Any]",
+            self.session.execute(
+                upd,
+                self._params,
+                execution_options=self._execution_options.union(
+                    {"synchronize_session": synchronize_session}
+                ),
             ),
         )
         bulk_ud.result = result  # type: ignore
index 27cd734ea612574ff45acadc8e69222c00e807d7..7bf77e20c7c05db25d45deaecac3278bc4748904 100644 (file)
@@ -52,7 +52,6 @@ if TYPE_CHECKING:
     from .session import sessionmaker
     from .session import SessionTransaction
     from ..engine import Connection
-    from ..engine import CursorResult
     from ..engine import Engine
     from ..engine import Result
     from ..engine import Row
@@ -72,7 +71,6 @@ if TYPE_CHECKING:
     from ..sql._typing import _T7
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
-    from ..sql.dml import UpdateBase
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import ForUpdateParameter
@@ -713,18 +711,6 @@ class scoped_session(Generic[_S]):
         _add_event: Optional[Any] = None,
     ) -> Result[Unpack[_Ts]]: ...
 
-    @overload
-    def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Unpack[TupleAny]]: ...
-
     @overload
     def execute(
         self,
index 69d0f8aca9f9156b1a4da2bfd249ff4e085a3bfa..e4383c21cfbaea2bb3bf6301d5b38a139d04a77c 100644 (file)
@@ -107,7 +107,6 @@ if typing.TYPE_CHECKING:
     from .mapper import Mapper
     from .path_registry import PathRegistry
     from .query import RowReturningQuery
-    from ..engine import CursorResult
     from ..engine import Result
     from ..engine import Row
     from ..engine import RowMapping
@@ -132,7 +131,6 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
     from ..sql.base import ExecutableOption
-    from ..sql.dml import UpdateBase
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import ForUpdateParameter
@@ -2276,18 +2274,6 @@ class Session(_SessionClassMethods, EventTarget):
         _add_event: Optional[Any] = None,
     ) -> Result[Unpack[_Ts]]: ...
 
-    @overload
-    def execute(
-        self,
-        statement: UpdateBase,
-        params: Optional[_CoreAnyExecuteParams] = None,
-        *,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[_BindArguments] = None,
-        _parent_execute_state: Optional[Any] = None,
-        _add_event: Optional[Any] = None,
-    ) -> CursorResult[Unpack[TupleAny]]: ...
-
     @overload
     def execute(
         self,
index 424a03c8aec16b1ff4a4727074196c6eb21ed0a0..a3c07dd016f77e3603f7a8d90a21e1c34522b5c6 100644 (file)
@@ -442,37 +442,29 @@ def t_dml_insert() -> None:
 def t_dml_bare_insert() -> None:
     s1 = insert(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
+    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_bare_update() -> None:
     s1 = update(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
+    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_update_with_values() -> None:
     s1 = update(User).values({User.id: 123, User.data: "value"})
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
+    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_bare_delete() -> None:
     s1 = delete(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
+    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
     reveal_type(r1)
-    # EXPECTED_TYPE: int
-    reveal_type(r1.rowcount)
 
 
 def t_dml_update() -> None: