From: Mehdi Gmira Date: Mon, 7 Aug 2023 14:50:39 +0000 (-0400) Subject: Fix annotations X-Git-Tag: rel_2_0_20~16^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0661cd99c4e06115d3dd8318a6bda5e2b41d11ae;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix annotations Typing improvements: * :class:`.CursorResult` is returned for some forms of :meth:`_orm.Session.execute` where DML without RETURNING is used * fixed type for :paramref:`_orm.Query.with_for_update.of` parameter within :meth:`_orm.Query.with_for_update` * improvements to ``_DMLColumnArgument`` type used by some DML methods to pass column expressions * Add overload to :func:`_sql.literal` so that it is inferred that the return type is ``BindParameter[NullType]`` where :paramref:`_sql.literal.type_` param is None * Add overloads to :meth:`_sql.ColumnElement.op` so that the inferred type when :paramref:`_sql.ColumnElement.op.return_type` is not provided is ``Callable[[Any], BinaryExpression[Any]]`` * Add missing overload to :meth:`_sql.ColumnElement.__add__` Pull request courtesy Mehdi Gmira. Fixes: #9185 Closes: #10108 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10108 Pull-request-sha: 6017526c9cd1282025885cb002de1f984f64205b Change-Id: I77a2a199b7a8b137b405001bef8813cf2d327bca --- diff --git a/doc/build/changelog/unreleased_20/9185.rst b/doc/build/changelog/unreleased_20/9185.rst new file mode 100644 index 0000000000..a28e8f9c72 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9185.rst @@ -0,0 +1,22 @@ +.. change:: + :tags: bug, typing + :tickets: 9185 + + Typing improvements: + + * :class:`.CursorResult` is returned for some forms of + :meth:`_orm.Session.execute` where DML without RETURNING is used + * fixed type for :paramref:`_orm.Query.with_for_update.of` parameter within + :meth:`_orm.Query.with_for_update` + * improvements to ``_DMLColumnArgument`` type used by some DML methods to + pass column expressions + * Add overload to :func:`_sql.literal` so that it is inferred that the + return type is ``BindParameter[NullType]`` where + :paramref:`_sql.literal.type_` param is None + * Add overloads to :meth:`_sql.ColumnElement.op` so that the inferred + type when :paramref:`_sql.ColumnElement.op.return_type` is not provided + is ``Callable[[Any], BinaryExpression[Any]]`` + * Add missing overload to :meth:`_sql.ColumnElement.__add__` + + Pull request courtesy Mehdi Gmira. + diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 155da8f4c9..b70c3366b1 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: from .result import AsyncScalarResult from .session import AsyncSessionTransaction from ...engine import Connection + from ...engine import CursorResult from ...engine import Engine from ...engine import Result from ...engine import Row @@ -54,6 +55,7 @@ if TYPE_CHECKING: from ...orm.session import _PKIdentityArgument from ...orm.session import _SessionBind from ...sql.base import Executable + from ...sql.dml import UpdateBase from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateParameter from ...sql.selectable import TypedReturnsRows @@ -554,6 +556,19 @@ class async_scoped_session(Generic[_AS]): ) -> Result[_T]: ... + @overload + async def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: + ... + @overload async def execute( self, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 3d176b4e7b..da69c4fb3e 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from .engine import AsyncConnection from .engine import AsyncEngine from ...engine import Connection + from ...engine import CursorResult from ...engine import Engine from ...engine import Result from ...engine import Row @@ -62,6 +63,7 @@ if TYPE_CHECKING: from ...orm.session import _SessionBindKey from ...sql._typing import _InfoType from ...sql.base import Executable + from ...sql.dml import UpdateBase from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateParameter from ...sql.selectable import TypedReturnsRows @@ -398,6 +400,19 @@ class AsyncSession(ReversibleProxy[Session]): ) -> Result[_T]: ... + @overload + async def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: + ... + @overload async def execute( self, diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index e6381bee16..14e75fab94 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -136,6 +136,7 @@ if TYPE_CHECKING: from ..sql.base import ExecutableOption from ..sql.elements import ColumnElement from ..sql.elements import Label + from ..sql.selectable import _ForUpdateOfArgument from ..sql.selectable import _JoinTargetElement from ..sql.selectable import _SetupJoinsElement from ..sql.selectable import Alias @@ -1786,12 +1787,7 @@ class Query( *, nowait: bool = False, read: bool = False, - of: Optional[ - Union[ - _ColumnExpressionArgument[Any], - Sequence[_ColumnExpressionArgument[Any]], - ] - ] = None, + of: Optional[_ForUpdateOfArgument] = None, skip_locked: bool = False, key_share: bool = False, ) -> Self: @@ -3177,14 +3173,11 @@ class Query( delete_ = sql.delete(*self._raw_columns) # type: ignore delete_._where_criteria = self._where_criteria - result: CursorResult[Any] = cast( - "CursorResult[Any]", - self.session.execute( - delete_, - self._params, - execution_options=self._execution_options.union( - {"synchronize_session": synchronize_session} - ), + result: CursorResult[Any] = self.session.execute( + delete_, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} ), ) bulk_del.result = result # type: ignore @@ -3270,14 +3263,11 @@ class Query( upd = upd.with_dialect_options(**update_args) upd._where_criteria = self._where_criteria - result: CursorResult[Any] = cast( - "CursorResult[Any]", - self.session.execute( - upd, - self._params, - execution_options=self._execution_options.union( - {"synchronize_session": synchronize_session} - ), + result: CursorResult[Any] = self.session.execute( + upd, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} ), ) bulk_ud.result = result # type: ignore diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 2fbd4fbe51..fc144d98c4 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -49,6 +49,7 @@ if TYPE_CHECKING: from .session import sessionmaker from .session import SessionTransaction from ..engine import Connection + from ..engine import CursorResult from ..engine import Engine from ..engine import Result from ..engine import Row @@ -68,6 +69,7 @@ if TYPE_CHECKING: from ..sql._typing import _T7 from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable + from ..sql.dml import UpdateBase from ..sql.elements import ClauseElement from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateParameter @@ -639,6 +641,19 @@ class scoped_session(Generic[_S]): ) -> Result[_T]: ... + @overload + def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: + ... + @overload def execute( self, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index fade4c4884..e5eb5036dd 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -101,6 +101,7 @@ if typing.TYPE_CHECKING: from .mapper import Mapper from .path_registry import PathRegistry from .query import RowReturningQuery + from ..engine import CursorResult from ..engine import Result from ..engine import Row from ..engine import RowMapping @@ -125,6 +126,7 @@ if typing.TYPE_CHECKING: from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.base import ExecutableOption + from ..sql.dml import UpdateBase from ..sql.elements import ClauseElement from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateParameter @@ -2170,6 +2172,19 @@ class Session(_SessionClassMethods, EventTarget): ) -> Result[_T]: ... + @overload + def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: + ... + @overload def execute( self, diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b4f6fc71cf..f83c4b4771 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -234,7 +234,7 @@ _DMLColumnArgument = Union[ str, _HasClauseElement, roles.DMLColumnRole, - "SQLCoreOperations", + "SQLCoreOperations[Any]", ] """A DML column expression. This is a "key" inside of insert().values(), update().values(), and related. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 1c86c1669e..75857fcfc4 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -123,11 +123,38 @@ _NT = TypeVar("_NT", bound="_NUMERIC") _NMT = TypeVar("_NMT", bound="_NUMBER") +@overload def literal( value: Any, - type_: Optional[_TypeEngineArgument[_T]] = None, + type_: _TypeEngineArgument[_T], literal_execute: bool = False, ) -> BindParameter[_T]: + ... + + +@overload +def literal( + value: _T, + type_: None = None, + literal_execute: bool = False, +) -> BindParameter[_T]: + ... + + +@overload +def literal( + value: Any, + type_: Optional[_TypeEngineArgument[Any]] = None, + literal_execute: bool = False, +) -> BindParameter[Any]: + ... + + +def literal( + value: Any, + type_: Optional[_TypeEngineArgument[Any]] = None, + literal_execute: bool = False, +) -> BindParameter[Any]: r"""Return a literal clause, bound to a bind parameter. Literal clauses are created automatically when non- @@ -799,14 +826,37 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): ) -> ColumnElement[Any]: ... + @overload + def op( + self, + opstring: str, + precedence: int = ..., + is_comparison: bool = ..., + *, + return_type: _TypeEngineArgument[_OPT], + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[_OPT]]: + ... + + @overload + def op( + self, + opstring: str, + precedence: int = ..., + is_comparison: bool = ..., + return_type: Optional[_TypeEngineArgument[Any]] = ..., + python_impl: Optional[Callable[..., Any]] = ..., + ) -> Callable[[Any], BinaryExpression[Any]]: + ... + def op( self, opstring: str, precedence: int = 0, is_comparison: bool = False, - return_type: Optional[_TypeEngineArgument[_OPT]] = None, + return_type: Optional[_TypeEngineArgument[Any]] = None, python_impl: Optional[Callable[..., Any]] = None, - ) -> Callable[[Any], BinaryExpression[_OPT]]: + ) -> Callable[[Any], BinaryExpression[Any]]: ... def bool_op( @@ -1078,6 +1128,10 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): ) -> ColumnElement[str]: ... + @overload + def __add__(self, other: Any) -> ColumnElement[Any]: + ... + def __add__(self, other: Any) -> ColumnElement[Any]: ... diff --git a/test/typing/plain_files/orm/typed_queries.py b/test/typing/plain_files/orm/typed_queries.py index c3468d9528..530e5f670f 100644 --- a/test/typing/plain_files/orm/typed_queries.py +++ b/test/typing/plain_files/orm/typed_queries.py @@ -440,6 +440,33 @@ def t_dml_insert() -> None: reveal_type(r3) +def t_dml_bare_insert() -> None: + s1 = insert(User) + r1 = session.execute(s1) + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(r1) + # EXPECTED_TYPE: int + reveal_type(r1.rowcount) + + +def t_dml_bare_update() -> None: + s1 = update(User) + r1 = session.execute(s1) + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(r1) + # EXPECTED_TYPE: int + reveal_type(r1.rowcount) + + +def t_dml_bare_delete() -> None: + s1 = delete(User) + r1 = session.execute(s1) + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(r1) + # EXPECTED_TYPE: int + reveal_type(r1.rowcount) + + def t_dml_update() -> None: s1 = update(User).returning(User.id, User.name) diff --git a/test/typing/plain_files/sql/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py index fd90e31a11..1152a04b17 100644 --- a/test/typing/plain_files/sql/common_sql_element.py +++ b/test/typing/plain_files/sql/common_sql_element.py @@ -13,6 +13,7 @@ from sqlalchemy import asc from sqlalchemy import Column from sqlalchemy import desc from sqlalchemy import Integer +from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import SQLColumnExpression @@ -138,3 +139,24 @@ s9174_5 = select(user_table).with_for_update(of=user_table.c.id) s9174_6 = select(user_table).with_for_update( of=[user_table.c.id, user_table.c.email] ) + +# with_for_update but for query +session = Session() +user = session.query(User).with_for_update(of=User) +user = session.query(User).with_for_update(of=User.id) +user = session.query(User).with_for_update(of=[User.id, User.email]) +user = session.query(user_table).with_for_update(of=user_table) +user = session.query(user_table).with_for_update(of=user_table.c.id) +user = session.query(user_table).with_for_update( + of=[user_table.c.id, user_table.c.email] +) + +# literal +# EXPECTED_TYPE: BindParameter[str] +reveal_type(literal("5")) +# EXPECTED_TYPE: BindParameter[str] +reveal_type(literal("5", None)) +# EXPECTED_TYPE: BindParameter[int] +reveal_type(literal("123", Integer)) +# EXPECTED_TYPE: BindParameter[int] +reveal_type(literal("123", Integer)) diff --git a/test/typing/plain_files/sql/operators.py b/test/typing/plain_files/sql/operators.py index 8258ec65b1..2e2f31df9c 100644 --- a/test/typing/plain_files/sql/operators.py +++ b/test/typing/plain_files/sql/operators.py @@ -140,6 +140,11 @@ op_d: "ColumnElement[int]" = col.op("&", return_type=BigInteger)("1") op_e: "ColumnElement[bool]" = col.bool_op("&")("1") +op_a1 = col.op("&")(1) +# EXPECTED_TYPE: BinaryExpression[Any] +reveal_type(op_a1) + + # op functions t1 = operators.eq(A.id, 1) select().where(t1)