From 67f62aac5b49b6d048ca39019e5bd123d3c9cfb2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 18 Aug 2025 11:01:47 -0400 Subject: [PATCH] We can't promise CursorResult from session.execute() Fixed typing bug where the :meth:`.Session.execute` method advertised that it would return a :class:`.CursorResult` if given an insert/update/delete statement. This is not the general case as several flavors of ORM insert/update do not actually yield a :class:`.CursorResult` which cannot be differentiated at the typing overload level, so the method now yields :class:`.Result` in all cases. For those cases where :class:`.CursorResult` is known to be returned and the ``.rowcount`` attribute is required, please use ``typing.cast()``. Fixes: #12813 Change-Id: I8a7197100db312b3898c66ceddd6638e68c6bb44 --- doc/build/changelog/unreleased_20/12813.rst | 12 +++++++++ lib/sqlalchemy/ext/asyncio/scoping.py | 14 ----------- lib/sqlalchemy/ext/asyncio/session.py | 14 ----------- lib/sqlalchemy/orm/query.py | 26 ++++++++++++-------- lib/sqlalchemy/orm/scoping.py | 14 ----------- lib/sqlalchemy/orm/session.py | 14 ----------- test/typing/plain_files/orm/typed_queries.py | 16 +++--------- 7 files changed, 32 insertions(+), 78 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/12813.rst diff --git a/doc/build/changelog/unreleased_20/12813.rst b/doc/build/changelog/unreleased_20/12813.rst new file mode 100644 index 0000000000..e478372a11 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12813.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, typing + :tickets: 12813 + + Fixed typing bug where the :meth:`.Session.execute` method advertised that + it would return a :class:`.CursorResult` if given an insert/update/delete + statement. This is not the general case as several flavors of ORM + insert/update do not actually yield a :class:`.CursorResult` which cannot + be differentiated at the typing overload level, so the method now yields + :class:`.Result` in all cases. For those cases where + :class:`.CursorResult` is known to be returned and the ``.rowcount`` + attribute is required, please use ``typing.cast()``. diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 6fbda51420..7730b7a52d 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -41,7 +41,6 @@ if TYPE_CHECKING: from .result import AsyncScalarResult from .session import AsyncSessionTransaction from ...engine import Connection - from ...engine import CursorResult from ...engine import Engine from ...engine import Result from ...engine import Row @@ -58,7 +57,6 @@ if TYPE_CHECKING: from ...orm.session import _PKIdentityArgument from ...orm.session import _SessionBind from ...sql.base import Executable - from ...sql.dml import UpdateBase from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateParameter from ...sql.selectable import TypedReturnsRows @@ -561,18 +559,6 @@ class async_scoped_session(Generic[_AS]): _add_event: Optional[Any] = None, ) -> Result[Unpack[_Ts]]: ... - @overload - async def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: ... - @overload async def execute( self, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 62ccb7c930..77c7b6edae 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -49,7 +49,6 @@ if TYPE_CHECKING: from .engine import AsyncConnection from .engine import AsyncEngine from ...engine import Connection - from ...engine import CursorResult from ...engine import Engine from ...engine import Result from ...engine import Row @@ -70,7 +69,6 @@ if TYPE_CHECKING: from ...orm.session import _SessionBindKey from ...sql._typing import _InfoType from ...sql.base import Executable - from ...sql.dml import UpdateBase from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateParameter from ...sql.selectable import TypedReturnsRows @@ -414,18 +412,6 @@ class AsyncSession(ReversibleProxy[Session]): _add_event: Optional[Any] = None, ) -> Result[Unpack[_Ts]]: ... - @overload - async def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: ... - @overload async def execute( self, diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 63065eca63..c5a9fe9ddc 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3255,11 +3255,14 @@ class Query( for ext in self._syntax_extensions: delete_._apply_syntax_extension_to_self(ext) - result: CursorResult[Any] = self.session.execute( - delete_, - self._params, - execution_options=self._execution_options.union( - {"synchronize_session": synchronize_session} + result = cast( + "CursorResult[Any]", + self.session.execute( + delete_, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), ), ) bulk_del.result = result # type: ignore @@ -3350,11 +3353,14 @@ class Query( for ext in self._syntax_extensions: upd._apply_syntax_extension_to_self(ext) - result: CursorResult[Any] = self.session.execute( - upd, - self._params, - execution_options=self._execution_options.union( - {"synchronize_session": synchronize_session} + result = cast( + "CursorResult[Any]", + self.session.execute( + upd, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), ), ) bulk_ud.result = result # type: ignore diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 27cd734ea6..7bf77e20c7 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -52,7 +52,6 @@ if TYPE_CHECKING: from .session import sessionmaker from .session import SessionTransaction from ..engine import Connection - from ..engine import CursorResult from ..engine import Engine from ..engine import Result from ..engine import Row @@ -72,7 +71,6 @@ if TYPE_CHECKING: from ..sql._typing import _T7 from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable - from ..sql.dml import UpdateBase from ..sql.elements import ClauseElement from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateParameter @@ -713,18 +711,6 @@ class scoped_session(Generic[_S]): _add_event: Optional[Any] = None, ) -> Result[Unpack[_Ts]]: ... - @overload - def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: ... - @overload def execute( self, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 69d0f8aca9..e4383c21cf 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -107,7 +107,6 @@ if typing.TYPE_CHECKING: from .mapper import Mapper from .path_registry import PathRegistry from .query import RowReturningQuery - from ..engine import CursorResult from ..engine import Result from ..engine import Row from ..engine import RowMapping @@ -132,7 +131,6 @@ if typing.TYPE_CHECKING: from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.base import ExecutableOption - from ..sql.dml import UpdateBase from ..sql.elements import ClauseElement from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateParameter @@ -2276,18 +2274,6 @@ class Session(_SessionClassMethods, EventTarget): _add_event: Optional[Any] = None, ) -> Result[Unpack[_Ts]]: ... - @overload - def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: ... - @overload def execute( self, diff --git a/test/typing/plain_files/orm/typed_queries.py b/test/typing/plain_files/orm/typed_queries.py index 424a03c8ae..a3c07dd016 100644 --- a/test/typing/plain_files/orm/typed_queries.py +++ b/test/typing/plain_files/orm/typed_queries.py @@ -442,37 +442,29 @@ def t_dml_insert() -> None: def t_dml_bare_insert() -> None: s1 = insert(User) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] + # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_bare_update() -> None: s1 = update(User) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] + # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_update_with_values() -> None: s1 = update(User).values({User.id: 123, User.data: "value"}) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] + # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_bare_delete() -> None: s1 = delete(User) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] + # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_update() -> None: -- 2.47.3