From: Mike Bayer Date: Mon, 18 Aug 2025 15:01:47 +0000 (-0400) Subject: We can't promise CursorResult from session.execute() X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5b7fc0b2c71a7012130c7850cfba551693bacc98;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git We can't promise CursorResult from session.execute() Fixed typing bug where the :meth:`.Session.execute` method advertised that it would return a :class:`.CursorResult` if given an insert/update/delete statement. This is not the general case as several flavors of ORM insert/update do not actually yield a :class:`.CursorResult` which cannot be differentiated at the typing overload level, so the method now yields :class:`.Result` in all cases. For those cases where :class:`.CursorResult` is known to be returned and the ``.rowcount`` attribute is required, please use ``typing.cast()``. Fixes: #12813 Change-Id: I8a7197100db312b3898c66ceddd6638e68c6bb44 (cherry picked from commit a7275f8e06575dd6edcf84f8083361961e499f92) --- diff --git a/doc/build/changelog/unreleased_20/12813.rst b/doc/build/changelog/unreleased_20/12813.rst new file mode 100644 index 0000000000..e478372a11 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12813.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, typing + :tickets: 12813 + + Fixed typing bug where the :meth:`.Session.execute` method advertised that + it would return a :class:`.CursorResult` if given an insert/update/delete + statement. This is not the general case as several flavors of ORM + insert/update do not actually yield a :class:`.CursorResult` which cannot + be differentiated at the typing overload level, so the method now yields + :class:`.Result` in all cases. For those cases where + :class:`.CursorResult` is known to be returned and the ``.rowcount`` + attribute is required, please use ``typing.cast()``. diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index d2a9a51b23..107dabbe5e 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -38,7 +38,6 @@ if TYPE_CHECKING: from .result import AsyncScalarResult from .session import AsyncSessionTransaction from ...engine import Connection - from ...engine import CursorResult from ...engine import Engine from ...engine import Result from ...engine import Row @@ -55,7 +54,6 @@ if TYPE_CHECKING: from ...orm.session import _PKIdentityArgument from ...orm.session import _SessionBind from ...sql.base import Executable - from ...sql.dml import UpdateBase from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateParameter from ...sql.selectable import TypedReturnsRows @@ -538,18 +536,6 @@ class async_scoped_session(Generic[_AS]): _add_event: Optional[Any] = None, ) -> Result[_T]: ... - @overload - async def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... - @overload async def execute( self, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 68cbb59bfd..9e07e19e8d 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -46,7 +46,6 @@ if TYPE_CHECKING: from .engine import AsyncConnection from .engine import AsyncEngine from ...engine import Connection - from ...engine import CursorResult from ...engine import Engine from ...engine import Result from ...engine import Row @@ -67,7 +66,6 @@ if TYPE_CHECKING: from ...orm.session import _SessionBindKey from ...sql._typing import _InfoType from ...sql.base import Executable - from ...sql.dml import UpdateBase from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateParameter from ...sql.selectable import TypedReturnsRows @@ -411,18 +409,6 @@ class AsyncSession(ReversibleProxy[Session]): _add_event: Optional[Any] = None, ) -> Result[_T]: ... - @overload - async def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... - @overload async def execute( self, diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 3489c15fd6..620cd2116c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3205,11 +3205,14 @@ class Query( delete_ = delete_.with_dialect_options(**delete_args) delete_._where_criteria = self._where_criteria - result: CursorResult[Any] = self.session.execute( - delete_, - self._params, - execution_options=self._execution_options.union( - {"synchronize_session": synchronize_session} + result = cast( + "CursorResult[Any]", + self.session.execute( + delete_, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), ), ) bulk_del.result = result # type: ignore @@ -3296,11 +3299,14 @@ class Query( upd = upd.with_dialect_options(**update_args) upd._where_criteria = self._where_criteria - result: CursorResult[Any] = self.session.execute( - upd, - self._params, - execution_options=self._execution_options.union( - {"synchronize_session": synchronize_session} + result = cast( + "CursorResult[Any]", + self.session.execute( + upd, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), ), ) bulk_ud.result = result # type: ignore diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index df5a6534dc..cffbb4561f 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -49,7 +49,6 @@ if TYPE_CHECKING: from .session import sessionmaker from .session import SessionTransaction from ..engine import Connection - from ..engine import CursorResult from ..engine import Engine from ..engine import Result from ..engine import Row @@ -69,7 +68,6 @@ if TYPE_CHECKING: from ..sql._typing import _T7 from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable - from ..sql.dml import UpdateBase from ..sql.elements import ClauseElement from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateParameter @@ -685,18 +683,6 @@ class scoped_session(Generic[_S]): _add_event: Optional[Any] = None, ) -> Result[_T]: ... - @overload - def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... - @overload def execute( self, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 6a589f3a33..f1456753b0 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -102,7 +102,6 @@ if typing.TYPE_CHECKING: from .mapper import Mapper from .path_registry import PathRegistry from .query import RowReturningQuery - from ..engine import CursorResult from ..engine import Result from ..engine import Row from ..engine import RowMapping @@ -127,7 +126,6 @@ if typing.TYPE_CHECKING: from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.base import ExecutableOption - from ..sql.dml import UpdateBase from ..sql.elements import ClauseElement from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateParameter @@ -2278,18 +2276,6 @@ class Session(_SessionClassMethods, EventTarget): _add_event: Optional[Any] = None, ) -> Result[_T]: ... - @overload - def execute( - self, - statement: UpdateBase, - params: Optional[_CoreAnyExecuteParams] = None, - *, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[_BindArguments] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... - @overload def execute( self, diff --git a/test/typing/plain_files/orm/typed_queries.py b/test/typing/plain_files/orm/typed_queries.py index b1226da30f..1e305b7b20 100644 --- a/test/typing/plain_files/orm/typed_queries.py +++ b/test/typing/plain_files/orm/typed_queries.py @@ -444,37 +444,29 @@ def t_dml_insert() -> None: def t_dml_bare_insert() -> None: s1 = insert(User) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Any] + # EXPECTED_TYPE: Result[Any] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_bare_update() -> None: s1 = update(User) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Any] + # EXPECTED_TYPE: Result[Any] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_update_with_values() -> None: s1 = update(User).values({User.id: 123, User.data: "value"}) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Any] + # EXPECTED_TYPE: Result[Any] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_bare_delete() -> None: s1 = delete(User) r1 = session.execute(s1) - # EXPECTED_TYPE: CursorResult[Any] + # EXPECTED_TYPE: Result[Any] reveal_type(r1) - # EXPECTED_TYPE: int - reveal_type(r1.rowcount) def t_dml_update() -> None: