]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Overload TypeEngine.operate() when return_type is specified
authorDenis Laxalde <denis@laxalde.org>
Tue, 4 Mar 2025 09:57:13 +0000 (10:57 +0100)
committerDenis Laxalde <denis@laxalde.org>
Tue, 4 Mar 2025 10:01:31 +0000 (11:01 +0100)
lib/sqlalchemy/sql/type_api.py

index 19b315928afd3a025b9c53f27d2b36d4e4deab1b..bdc56b46ac479e1f3a2a1df5e875a26077370632 100644 (file)
@@ -67,6 +67,7 @@ _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
 _O = TypeVar("_O", bound=object)
 _TE = TypeVar("_TE", bound="TypeEngine[Any]")
 _CT = TypeVar("_CT", bound=Any)
+_RT = TypeVar("_RT", bound=Any)
 
 _MatchedOnType = Union[
     "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any]
@@ -186,10 +187,24 @@ class TypeEngine(Visitable, Generic[_T]):
         def __reduce__(self) -> Any:
             return self.__class__, (self.expr,)
 
+        @overload
+        def operate(
+            self,
+            op: OperatorType,
+            *other: Any,
+            result_type: Type[TypeEngine[_RT]],
+            **kwargs: Any,
+        ) -> ColumnElement[_RT]: ...
+
+        @overload
+        def operate(
+            self, op: OperatorType, *other: Any, **kwargs: Any
+        ) -> ColumnElement[_CT]: ...
+
         @util.preload_module("sqlalchemy.sql.default_comparator")
         def operate(
             self, op: OperatorType, *other: Any, **kwargs: Any
-        ) -> ColumnElement[_CT]:
+        ) -> ColumnElement[Any]:
             default_comparator = util.preloaded.sql_default_comparator
             op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
             if kwargs: