]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix annotations
authorMehdi Gmira <mgmira@wiremind.io>
Mon, 7 Aug 2023 14:50:39 +0000 (10:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Aug 2023 22:12:58 +0000 (18:12 -0400)
Typing improvements:

* :class:`.CursorResult` is returned for some forms of
:meth:`_orm.Session.execute` where DML without RETURNING is used
* fixed type for :paramref:`_orm.Query.with_for_update.of` parameter within
:meth:`_orm.Query.with_for_update`
* improvements to ``_DMLColumnArgument`` type used by some DML methods to
pass column expressions
* Add overload to :func:`_sql.literal` so that it is inferred that the
return type is ``BindParameter[NullType]`` where
:paramref:`_sql.literal.type_` param is None
* Add overloads to :meth:`_sql.ColumnElement.op` so that the inferred
type when :paramref:`_sql.ColumnElement.op.return_type` is not provided
is ``Callable[[Any], BinaryExpression[Any]]``
* Add missing overload to :meth:`_sql.ColumnElement.__add__`

Pull request courtesy Mehdi Gmira.

Fixes: #9185
Closes: #10108
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10108
Pull-request-sha: 6017526c9cd1282025885cb002de1f984f64205b

Change-Id: I77a2a199b7a8b137b405001bef8813cf2d327bca

doc/build/changelog/unreleased_20/9185.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/elements.py
test/typing/plain_files/orm/typed_queries.py
test/typing/plain_files/sql/common_sql_element.py
test/typing/plain_files/sql/operators.py

diff --git a/doc/build/changelog/unreleased_20/9185.rst b/doc/build/changelog/unreleased_20/9185.rst
new file mode 100644 (file)
index 0000000..a28e8f9
--- /dev/null
@@ -0,0 +1,22 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9185
+
+    Typing improvements:
+
+    * :class:`.CursorResult` is returned for some forms of
+      :meth:`_orm.Session.execute` where DML without RETURNING is used
+    * fixed type for :paramref:`_orm.Query.with_for_update.of` parameter within
+      :meth:`_orm.Query.with_for_update`
+    * improvements to ``_DMLColumnArgument`` type used by some DML methods to
+      pass column expressions
+    * Add overload to :func:`_sql.literal` so that it is inferred that the
+      return type is ``BindParameter[NullType]`` where
+      :paramref:`_sql.literal.type_` param is None
+    * Add overloads to :meth:`_sql.ColumnElement.op` so that the inferred
+      type when :paramref:`_sql.ColumnElement.op.return_type` is not provided
+      is ``Callable[[Any], BinaryExpression[Any]]``
+    * Add missing overload to :meth:`_sql.ColumnElement.__add__`
+
+    Pull request courtesy Mehdi Gmira.
+
index 155da8f4c9d6b28cf68efb85a073be18e12d9869..b70c3366b16513045d421188ca56a98aa1287095 100644 (file)
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     from .result import AsyncScalarResult
     from .session import AsyncSessionTransaction
     from ...engine import Connection
+    from ...engine import CursorResult
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
     from ...orm.session import _PKIdentityArgument
     from ...orm.session import _SessionBind
     from ...sql.base import Executable
+    from ...sql.dml import UpdateBase
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
@@ -554,6 +556,19 @@ class async_scoped_session(Generic[_AS]):
     ) -> Result[_T]:
         ...
 
+    @overload
+    async def execute(
+        self,
+        statement: UpdateBase,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> CursorResult[Any]:
+        ...
+
     @overload
     async def execute(
         self,
index 3d176b4e7b26d4b794a212fb689981e4f7666378..da69c4fb3efc651d6a6354f211a6d46866008130 100644 (file)
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
     from .engine import AsyncConnection
     from .engine import AsyncEngine
     from ...engine import Connection
+    from ...engine import CursorResult
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
@@ -62,6 +63,7 @@ if TYPE_CHECKING:
     from ...orm.session import _SessionBindKey
     from ...sql._typing import _InfoType
     from ...sql.base import Executable
+    from ...sql.dml import UpdateBase
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
@@ -398,6 +400,19 @@ class AsyncSession(ReversibleProxy[Session]):
     ) -> Result[_T]:
         ...
 
+    @overload
+    async def execute(
+        self,
+        statement: UpdateBase,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> CursorResult[Any]:
+        ...
+
     @overload
     async def execute(
         self,
index e6381bee16b553f8b170bf3442026696d86f6df7..14e75fab94857ba0c5a51b7ae0118dba4f18411d 100644 (file)
@@ -136,6 +136,7 @@ if TYPE_CHECKING:
     from ..sql.base import ExecutableOption
     from ..sql.elements import ColumnElement
     from ..sql.elements import Label
+    from ..sql.selectable import _ForUpdateOfArgument
     from ..sql.selectable import _JoinTargetElement
     from ..sql.selectable import _SetupJoinsElement
     from ..sql.selectable import Alias
@@ -1786,12 +1787,7 @@ class Query(
         *,
         nowait: bool = False,
         read: bool = False,
-        of: Optional[
-            Union[
-                _ColumnExpressionArgument[Any],
-                Sequence[_ColumnExpressionArgument[Any]],
-            ]
-        ] = None,
+        of: Optional[_ForUpdateOfArgument] = None,
         skip_locked: bool = False,
         key_share: bool = False,
     ) -> Self:
@@ -3177,14 +3173,11 @@ class Query(
 
         delete_ = sql.delete(*self._raw_columns)  # type: ignore
         delete_._where_criteria = self._where_criteria
-        result: CursorResult[Any] = cast(
-            "CursorResult[Any]",
-            self.session.execute(
-                delete_,
-                self._params,
-                execution_options=self._execution_options.union(
-                    {"synchronize_session": synchronize_session}
-                ),
+        result: CursorResult[Any] = self.session.execute(
+            delete_,
+            self._params,
+            execution_options=self._execution_options.union(
+                {"synchronize_session": synchronize_session}
             ),
         )
         bulk_del.result = result  # type: ignore
@@ -3270,14 +3263,11 @@ class Query(
             upd = upd.with_dialect_options(**update_args)
 
         upd._where_criteria = self._where_criteria
-        result: CursorResult[Any] = cast(
-            "CursorResult[Any]",
-            self.session.execute(
-                upd,
-                self._params,
-                execution_options=self._execution_options.union(
-                    {"synchronize_session": synchronize_session}
-                ),
+        result: CursorResult[Any] = self.session.execute(
+            upd,
+            self._params,
+            execution_options=self._execution_options.union(
+                {"synchronize_session": synchronize_session}
             ),
         )
         bulk_ud.result = result  # type: ignore
index 2fbd4fbe517e95116199883fec7cbe1f7b9286ac..fc144d98c4e0172537f20424094aa7b7c14a3456 100644 (file)
@@ -49,6 +49,7 @@ if TYPE_CHECKING:
     from .session import sessionmaker
     from .session import SessionTransaction
     from ..engine import Connection
+    from ..engine import CursorResult
     from ..engine import Engine
     from ..engine import Result
     from ..engine import Row
@@ -68,6 +69,7 @@ if TYPE_CHECKING:
     from ..sql._typing import _T7
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
+    from ..sql.dml import UpdateBase
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import ForUpdateParameter
@@ -639,6 +641,19 @@ class scoped_session(Generic[_S]):
     ) -> Result[_T]:
         ...
 
+    @overload
+    def execute(
+        self,
+        statement: UpdateBase,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> CursorResult[Any]:
+        ...
+
     @overload
     def execute(
         self,
index fade4c48847fc174e64262c0155bdf6acebebc7d..e5eb5036dd738b502fa89c7cd3548724c5620196 100644 (file)
@@ -101,6 +101,7 @@ if typing.TYPE_CHECKING:
     from .mapper import Mapper
     from .path_registry import PathRegistry
     from .query import RowReturningQuery
+    from ..engine import CursorResult
     from ..engine import Result
     from ..engine import Row
     from ..engine import RowMapping
@@ -125,6 +126,7 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
     from ..sql.base import ExecutableOption
+    from ..sql.dml import UpdateBase
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import ForUpdateParameter
@@ -2170,6 +2172,19 @@ class Session(_SessionClassMethods, EventTarget):
     ) -> Result[_T]:
         ...
 
+    @overload
+    def execute(
+        self,
+        statement: UpdateBase,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> CursorResult[Any]:
+        ...
+
     @overload
     def execute(
         self,
index b4f6fc71cfdbc956c951ca755e7039a0591d2dbd..f83c4b47714a861b42a6a7ddfeb5f70b7f3f30b6 100644 (file)
@@ -234,7 +234,7 @@ _DMLColumnArgument = Union[
     str,
     _HasClauseElement,
     roles.DMLColumnRole,
-    "SQLCoreOperations",
+    "SQLCoreOperations[Any]",
 ]
 """A DML column expression.  This is a "key" inside of insert().values(),
 update().values(), and related.
index 1c86c1669e85bd220d0bb15459f2cc7ed453eea6..75857fcfc445d4fb529eb5b1ad06d815dd524d81 100644 (file)
@@ -123,11 +123,38 @@ _NT = TypeVar("_NT", bound="_NUMERIC")
 _NMT = TypeVar("_NMT", bound="_NUMBER")
 
 
+@overload
 def literal(
     value: Any,
-    type_: Optional[_TypeEngineArgument[_T]] = None,
+    type_: _TypeEngineArgument[_T],
     literal_execute: bool = False,
 ) -> BindParameter[_T]:
+    ...
+
+
+@overload
+def literal(
+    value: _T,
+    type_: None = None,
+    literal_execute: bool = False,
+) -> BindParameter[_T]:
+    ...
+
+
+@overload
+def literal(
+    value: Any,
+    type_: Optional[_TypeEngineArgument[Any]] = None,
+    literal_execute: bool = False,
+) -> BindParameter[Any]:
+    ...
+
+
+def literal(
+    value: Any,
+    type_: Optional[_TypeEngineArgument[Any]] = None,
+    literal_execute: bool = False,
+) -> BindParameter[Any]:
     r"""Return a literal clause, bound to a bind parameter.
 
     Literal clauses are created automatically when non-
@@ -799,14 +826,37 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         ) -> ColumnElement[Any]:
             ...
 
+        @overload
+        def op(
+            self,
+            opstring: str,
+            precedence: int = ...,
+            is_comparison: bool = ...,
+            *,
+            return_type: _TypeEngineArgument[_OPT],
+            python_impl: Optional[Callable[..., Any]] = None,
+        ) -> Callable[[Any], BinaryExpression[_OPT]]:
+            ...
+
+        @overload
+        def op(
+            self,
+            opstring: str,
+            precedence: int = ...,
+            is_comparison: bool = ...,
+            return_type: Optional[_TypeEngineArgument[Any]] = ...,
+            python_impl: Optional[Callable[..., Any]] = ...,
+        ) -> Callable[[Any], BinaryExpression[Any]]:
+            ...
+
         def op(
             self,
             opstring: str,
             precedence: int = 0,
             is_comparison: bool = False,
-            return_type: Optional[_TypeEngineArgument[_OPT]] = None,
+            return_type: Optional[_TypeEngineArgument[Any]] = None,
             python_impl: Optional[Callable[..., Any]] = None,
-        ) -> Callable[[Any], BinaryExpression[_OPT]]:
+        ) -> Callable[[Any], BinaryExpression[Any]]:
             ...
 
         def bool_op(
@@ -1078,6 +1128,10 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         ) -> ColumnElement[str]:
             ...
 
+        @overload
+        def __add__(self, other: Any) -> ColumnElement[Any]:
+            ...
+
         def __add__(self, other: Any) -> ColumnElement[Any]:
             ...
 
index c3468d95280bc45cc4ecbd9729923213b20ded28..530e5f670fda18f54dcbec54385fe99e0596c40a 100644 (file)
@@ -440,6 +440,33 @@ def t_dml_insert() -> None:
     reveal_type(r3)
 
 
+def t_dml_bare_insert() -> None:
+    s1 = insert(User)
+    r1 = session.execute(s1)
+    # EXPECTED_TYPE: CursorResult[Any]
+    reveal_type(r1)
+    # EXPECTED_TYPE: int
+    reveal_type(r1.rowcount)
+
+
+def t_dml_bare_update() -> None:
+    s1 = update(User)
+    r1 = session.execute(s1)
+    # EXPECTED_TYPE: CursorResult[Any]
+    reveal_type(r1)
+    # EXPECTED_TYPE: int
+    reveal_type(r1.rowcount)
+
+
+def t_dml_bare_delete() -> None:
+    s1 = delete(User)
+    r1 = session.execute(s1)
+    # EXPECTED_TYPE: CursorResult[Any]
+    reveal_type(r1)
+    # EXPECTED_TYPE: int
+    reveal_type(r1.rowcount)
+
+
 def t_dml_update() -> None:
     s1 = update(User).returning(User.id, User.name)
 
index fd90e31a11c1a01da6e0529b2d332898cc342dfc..1152a04b1731f8e67aba165ced139420a19b8c41 100644 (file)
@@ -13,6 +13,7 @@ from sqlalchemy import asc
 from sqlalchemy import Column
 from sqlalchemy import desc
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import SQLColumnExpression
@@ -138,3 +139,24 @@ s9174_5 = select(user_table).with_for_update(of=user_table.c.id)
 s9174_6 = select(user_table).with_for_update(
     of=[user_table.c.id, user_table.c.email]
 )
+
+# with_for_update but for query
+session = Session()
+user = session.query(User).with_for_update(of=User)
+user = session.query(User).with_for_update(of=User.id)
+user = session.query(User).with_for_update(of=[User.id, User.email])
+user = session.query(user_table).with_for_update(of=user_table)
+user = session.query(user_table).with_for_update(of=user_table.c.id)
+user = session.query(user_table).with_for_update(
+    of=[user_table.c.id, user_table.c.email]
+)
+
+# literal
+# EXPECTED_TYPE: BindParameter[str]
+reveal_type(literal("5"))
+# EXPECTED_TYPE: BindParameter[str]
+reveal_type(literal("5", None))
+# EXPECTED_TYPE: BindParameter[int]
+reveal_type(literal("123", Integer))
+# EXPECTED_TYPE: BindParameter[int]
+reveal_type(literal("123", Integer))
index 8258ec65b1fe10cb4cedfe9282ff6845f251f894..2e2f31df9cf5fd3942b5725743c27508c9ea972c 100644 (file)
@@ -140,6 +140,11 @@ op_d: "ColumnElement[int]" = col.op("&", return_type=BigInteger)("1")
 op_e: "ColumnElement[bool]" = col.bool_op("&")("1")
 
 
+op_a1 = col.op("&")(1)
+# EXPECTED_TYPE: BinaryExpression[Any]
+reveal_type(op_a1)
+
+
 # op functions
 t1 = operators.eq(A.id, 1)
 select().where(t1)