From 037c051997306cf0c5550cda1e7630cdebcfdfec Mon Sep 17 00:00:00 2001 From: seria Date: Thu, 18 Sep 2025 05:37:02 +0800 Subject: [PATCH] =?utf8?q?=E2=9C=A8=20Add=20overload=20for=20`exec`=20meth?= =?utf8?q?od=20to=20support=20`insert`,=20`update`,=20`delete`=20statement?= =?utf8?q?s=20(#1342)?= MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem Co-authored-by: Motov Yurii <109919500+YuriiMotov@users.noreply.github.com> --- sqlmodel/ext/asyncio/session.py | 19 ++++++++++++++++++- sqlmodel/orm/session.py | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 467d0bd8..54488357 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -12,6 +12,7 @@ from typing import ( ) from sqlalchemy import util +from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams from sqlalchemy.engine.result import Result, ScalarResult, TupleResult from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession @@ -19,6 +20,7 @@ from sqlalchemy.ext.asyncio.result import _ensure_sync_result from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS from sqlalchemy.orm._typing import OrmExecuteOptionsParameter from sqlalchemy.sql.base import Executable as _Executable +from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.util.concurrency import greenlet_spawn from typing_extensions import deprecated @@ -57,12 +59,25 @@ class AsyncSession(_AsyncSession): _add_event: Optional[Any] = None, ) -> ScalarResult[_TSelectParam]: ... + @overload + async def exec( + self, + statement: UpdateBase, + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + async def exec( self, statement: Union[ Select[_TSelectParam], SelectOfScalar[_TSelectParam], Executable[_TSelectParam], + UpdateBase, ], *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, @@ -70,7 +85,9 @@ class AsyncSession(_AsyncSession): bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]: + ) -> Union[ + TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any] + ]: if execution_options: execution_options = util.immutabledict(execution_options).union( _EXECUTE_OPTIONS diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index b6087509..dca4733d 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -10,6 +10,7 @@ from typing import ( ) from sqlalchemy import util +from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams from sqlalchemy.engine.result import Result, ScalarResult, TupleResult from sqlalchemy.orm import Query as _Query @@ -17,6 +18,7 @@ from sqlalchemy.orm import Session as _Session from sqlalchemy.orm._typing import OrmExecuteOptionsParameter from sqlalchemy.sql._typing import _ColumnsClauseArgument from sqlalchemy.sql.base import Executable as _Executable +from sqlalchemy.sql.dml import UpdateBase from sqlmodel.sql.base import Executable from sqlmodel.sql.expression import Select, SelectOfScalar from typing_extensions import deprecated @@ -49,12 +51,25 @@ class Session(_Session): _add_event: Optional[Any] = None, ) -> ScalarResult[_TSelectParam]: ... + @overload + def exec( + self, + statement: UpdateBase, + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + def exec( self, statement: Union[ Select[_TSelectParam], SelectOfScalar[_TSelectParam], Executable[_TSelectParam], + UpdateBase, ], *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, @@ -62,7 +77,9 @@ class Session(_Session): bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]: + ) -> Union[ + TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any] + ]: results = super().execute( statement, params=params, -- 2.47.3