]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep484 for hybrid
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Mar 2022 16:07:25 +0000 (12:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 Mar 2022 13:42:29 +0000 (09:42 -0400)
Change-Id: I53274b13094d996e11b04acb03f9613edbddf87f
References: #6810

13 files changed:
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/sqltypes.py
pyproject.toml
test/ext/mypy/plain_files/hybrid_one.py [new file with mode: 0644]
test/ext/mypy/plain_files/hybrid_two.py [new file with mode: 0644]
test/ext/mypy/plain_files/sql_operations.py

index dc34a2ef58b333cad01b65523753258a97ec3f46..92b3ce54f7a8a975057991fab61c9963aa990a7e 100644 (file)
@@ -802,17 +802,41 @@ advanced and/or patient developers, there's probably a whole lot of amazing
 things it can be used for.
 
 """  # noqa
+
+from __future__ import annotations
+
 from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from .. import util
 from ..orm import attributes
 from ..orm import InspectionAttrExtensionType
 from ..orm import interfaces
 from ..orm import ORMDescriptor
+from ..sql._typing import is_has_column_element_clause_element
+from ..sql.elements import ColumnElement
+from ..sql.elements import SQLCoreOperations
+from ..util.typing import Literal
+from ..util.typing import Protocol
 
+if TYPE_CHECKING:
+    from ..orm.util import AliasedInsp
+    from ..sql.operators import OperatorType
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
 
 
 class HybridExtensionType(InspectionAttrExtensionType):
@@ -844,7 +868,34 @@ class HybridExtensionType(InspectionAttrExtensionType):
     """
 
 
-class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
+class _HybridGetterType(Protocol[_T_co]):
+    def __call__(s, self: Any) -> _T_co:
+        ...
+
+
+class _HybridSetterType(Protocol[_T_con]):
+    def __call__(self, instance: Any, value: _T_con) -> None:
+        ...
+
+
+class _HybridUpdaterType(Protocol[_T]):
+    def __call__(
+        self, cls: Type[Any], value: Union[_T, SQLCoreOperations[_T]]
+    ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
+        ...
+
+
+class _HybridDeleterType(Protocol[_T_co]):
+    def __call__(self, instance: Any) -> None:
+        ...
+
+
+class _HybridExprCallableType(Protocol[_T]):
+    def __call__(self, cls: Any) -> SQLCoreOperations[_T]:
+        ...
+
+
+class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
     """A decorator which allows definition of a Python object method with both
     instance-level and class-level behavior.
 
@@ -853,7 +904,11 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
     is_attribute = True
     extension_type = HybridExtensionType.HYBRID_METHOD
 
-    def __init__(self, func, expr=None):
+    def __init__(
+        self,
+        func: Callable[..., _T],
+        expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None,
+    ):
         """Create a new :class:`.hybrid_method`.
 
         Usage is typically via decorator::
@@ -873,13 +928,29 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
         self.func = func
         self.expression(expr or func)
 
-    def __get__(self, instance, owner):
+    @overload
+    def __get__(
+        self, instance: Literal[None], owner: Type[object]
+    ) -> Callable[[Any], SQLCoreOperations[_T]]:
+        ...
+
+    @overload
+    def __get__(
+        self, instance: object, owner: Type[object]
+    ) -> Callable[[Any], _T]:
+        ...
+
+    def __get__(
+        self, instance: Optional[object], owner: Type[object]
+    ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]:
         if instance is None:
-            return self.expr.__get__(owner, owner.__class__)
+            return self.expr.__get__(owner, owner)  # type: ignore
         else:
-            return self.func.__get__(instance, owner)
+            return self.func.__get__(instance, owner)  # type: ignore
 
-    def expression(self, expr):
+    def expression(
+        self, expr: Callable[..., SQLCoreOperations[_T]]
+    ) -> hybrid_method[_T]:
         """Provide a modifying decorator that defines a
         SQL-expression producing method."""
 
@@ -889,7 +960,12 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
         return self
 
 
-class hybrid_property(interfaces.InspectionAttrInfo):
+Selfhybrid_property = TypeVar(
+    "Selfhybrid_property", bound="hybrid_property[Any]"
+)
+
+
+class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
     """A decorator which allows definition of a Python descriptor with both
     instance-level and class-level behavior.
 
@@ -898,14 +974,16 @@ class hybrid_property(interfaces.InspectionAttrInfo):
     is_attribute = True
     extension_type = HybridExtensionType.HYBRID_PROPERTY
 
+    __name__: str
+
     def __init__(
         self,
-        fget,
-        fset=None,
-        fdel=None,
-        expr=None,
-        custom_comparator=None,
-        update_expr=None,
+        fget: _HybridGetterType[_T],
+        fset: Optional[_HybridSetterType[_T]] = None,
+        fdel: Optional[_HybridDeleterType[_T]] = None,
+        expr: Optional[_HybridExprCallableType[_T]] = None,
+        custom_comparator: Optional[Comparator[_T]] = None,
+        update_expr: Optional[_HybridUpdaterType[_T]] = None,
     ):
         """Create a new :class:`.hybrid_property`.
 
@@ -931,23 +1009,43 @@ class hybrid_property(interfaces.InspectionAttrInfo):
         self.update_expr = update_expr
         util.update_wrapper(self, fget)
 
-    def __get__(self, instance, owner):
-        if instance is None:
+    @overload
+    def __get__(
+        self: Selfhybrid_property, instance: Any, owner: Literal[None]
+    ) -> Selfhybrid_property:
+        ...
+
+    @overload
+    def __get__(
+        self, instance: Literal[None], owner: Type[object]
+    ) -> SQLCoreOperations[_T]:
+        ...
+
+    @overload
+    def __get__(self, instance: object, owner: Type[object]) -> _T:
+        ...
+
+    def __get__(
+        self, instance: Optional[object], owner: Optional[Type[object]]
+    ) -> Union[hybrid_property[_T], SQLCoreOperations[_T], _T]:
+        if owner is None:
+            return self
+        elif instance is None:
             return self._expr_comparator(owner)
         else:
             return self.fget(instance)
 
-    def __set__(self, instance, value):
+    def __set__(self, instance: object, value: Any) -> None:
         if self.fset is None:
             raise AttributeError("can't set attribute")
         self.fset(instance, value)
 
-    def __delete__(self, instance):
+    def __delete__(self, instance: object) -> None:
         if self.fdel is None:
             raise AttributeError("can't delete attribute")
         self.fdel(instance)
 
-    def _copy(self, **kw):
+    def _copy(self, **kw: Any) -> hybrid_property[_T]:
         defaults = {
             key: value
             for key, value in self.__dict__.items()
@@ -957,7 +1055,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
         return type(self)(**defaults)
 
     @property
-    def overrides(self):
+    def overrides(self: Selfhybrid_property) -> Selfhybrid_property:
         """Prefix for a method that is overriding an existing attribute.
 
         The :attr:`.hybrid_property.overrides` accessor just returns
@@ -992,7 +1090,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
         """
         return self
 
-    def getter(self, fget):
+    def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines a getter method.
 
         .. versionadded:: 1.2
@@ -1001,17 +1099,19 @@ class hybrid_property(interfaces.InspectionAttrInfo):
 
         return self._copy(fget=fget)
 
-    def setter(self, fset):
+    def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines a setter method."""
 
         return self._copy(fset=fset)
 
-    def deleter(self, fdel):
+    def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines a deletion method."""
 
         return self._copy(fdel=fdel)
 
-    def expression(self, expr):
+    def expression(
+        self, expr: _HybridExprCallableType[_T]
+    ) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines a SQL-expression
         producing method.
 
@@ -1043,7 +1143,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
 
         return self._copy(expr=expr)
 
-    def comparator(self, comparator):
+    def comparator(self, comparator: Comparator[_T]) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines a custom
         comparator producing method.
 
@@ -1078,7 +1178,9 @@ class hybrid_property(interfaces.InspectionAttrInfo):
         """
         return self._copy(custom_comparator=comparator)
 
-    def update_expression(self, meth):
+    def update_expression(
+        self, meth: _HybridUpdaterType[_T]
+    ) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines an UPDATE tuple
         producing method.
 
@@ -1115,27 +1217,35 @@ class hybrid_property(interfaces.InspectionAttrInfo):
         return self._copy(update_expr=meth)
 
     @util.memoized_property
-    def _expr_comparator(self):
+    def _expr_comparator(
+        self,
+    ) -> Callable[[Any], interfaces.PropComparator[_T]]:
         if self.custom_comparator is not None:
             return self._get_comparator(self.custom_comparator)
         elif self.expr is not None:
             return self._get_expr(self.expr)
         else:
-            return self._get_expr(self.fget)
+            return self._get_expr(cast(_HybridExprCallableType[_T], self.fget))
 
-    def _get_expr(self, expr):
-        def _expr(cls):
+    def _get_expr(
+        self, expr: _HybridExprCallableType[_T]
+    ) -> Callable[[Any], interfaces.PropComparator[_T]]:
+        def _expr(cls: Any) -> ExprComparator[_T]:
             return ExprComparator(cls, expr(cls), self)
 
         util.update_wrapper(_expr, expr)
 
         return self._get_comparator(_expr)
 
-    def _get_comparator(self, comparator):
+    def _get_comparator(
+        self, comparator: Any
+    ) -> Callable[[Any], interfaces.PropComparator[_T]]:
 
         proxy_attr = attributes.create_proxied_attribute(self)
 
-        def expr_comparator(owner):
+        def expr_comparator(
+            owner: Type[object],
+        ) -> interfaces.PropComparator[_T]:
             # because this is the descriptor protocol, we don't really know
             # what our attribute name is.  so search for it through the
             # MRO.
@@ -1163,36 +1273,48 @@ class Comparator(interfaces.PropComparator[_T]):
     :class:`~.orm.interfaces.PropComparator`
     classes for usage with hybrids."""
 
-    property = None
-
-    def __init__(self, expression):
+    def __init__(self, expression: SQLCoreOperations[_T]):
         self.expression = expression
 
-    def __clause_element__(self):
+    def __clause_element__(self) -> ColumnElement[_T]:
         expr = self.expression
-        if hasattr(expr, "__clause_element__"):
+        if is_has_column_element_clause_element(expr):
             expr = expr.__clause_element__()
+
+        elif TYPE_CHECKING:
+            assert isinstance(expr, ColumnElement)
         return expr
 
-    def adapt_to_entity(self, adapt_to_entity):
+    @util.non_memoized_property
+    def property(self) -> Any:
+        return None
+
+    def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]:
         # interesting....
         return self
 
 
 class ExprComparator(Comparator[_T]):
-    def __init__(self, cls, expression, hybrid):
+    def __init__(
+        self,
+        cls: Type[Any],
+        expression: SQLCoreOperations[_T],
+        hybrid: hybrid_property[_T],
+    ):
         self.cls = cls
         self.expression = expression
         self.hybrid = hybrid
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         return getattr(self.expression, key)
 
-    @property
-    def info(self):
+    @util.non_memoized_property
+    def info(self) -> Dict[Any, Any]:
         return self.hybrid.info
 
-    def _bulk_update_tuples(self, value):
+    def _bulk_update_tuples(
+        self, value: Any
+    ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
         if isinstance(self.expression, attributes.QueryableAttribute):
             return self.expression._bulk_update_tuples(value)
         elif self.hybrid.update_expr is not None:
@@ -1200,12 +1322,16 @@ class ExprComparator(Comparator[_T]):
         else:
             return [(self.expression, value)]
 
-    @property
-    def property(self):
-        return self.expression.property
+    @util.non_memoized_property
+    def property(self) -> Any:
+        return self.expression.property  # type: ignore
 
-    def operate(self, op, *other, **kwargs):
-        return op(self.expression, *other, **kwargs)
+    def operate(
+        self, op: OperatorType, *other: Any, **kwargs: Any
+    ) -> ColumnElement[Any]:
+        return op(self.expression, *other, **kwargs)  # type: ignore
 
-    def reverse_operate(self, op, other, **kwargs):
-        return op(other, self.expression, **kwargs)
+    def reverse_operate(
+        self, op: OperatorType, other: Any, **kwargs: Any
+    ) -> ColumnElement[Any]:
+        return op(other, self.expression, **kwargs)  # type: ignore
index ce3a645adbe55dd0a4c0f2eb56a34cdd4953133d..2b6ca400e94b830aed4eb4a64382940c6a072802 100644 (file)
@@ -20,6 +20,7 @@ from collections import namedtuple
 import operator
 import typing
 from typing import Any
+from typing import Callable
 from typing import List
 from typing import NamedTuple
 from typing import Tuple
@@ -68,6 +69,7 @@ from ..sql import visitors
 
 if typing.TYPE_CHECKING:
     from ..sql.elements import ColumnElement
+    from ..sql.elements import SQLCoreOperations
 
 _T = TypeVar("_T")
 
@@ -277,7 +279,9 @@ class QueryableAttribute(
     def _from_objects(self):
         return self.expression._from_objects
 
-    def _bulk_update_tuples(self, value):
+    def _bulk_update_tuples(
+        self, value: Any
+    ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
         """Return setter tuples for a bulk UPDATE."""
 
         return self.comparator._bulk_update_tuples(value)
@@ -416,7 +420,9 @@ HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"])
 HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
 
 
-def create_proxied_attribute(descriptor):
+def create_proxied_attribute(
+    descriptor: Any,
+) -> Callable[..., QueryableAttribute[Any]]:
     """Create an QueryableAttribute / user descriptor hybrid.
 
     Returns a new QueryableAttribute type that delegates descriptor
index c63a89c70427b6973bc2932a3c51be27cc64a22e..cb30701037418555f20dbb1e4e0bdef23ee876dc 100644 (file)
@@ -638,7 +638,7 @@ class ORMDescriptor(Generic[_T], TypingOnly):
         @overload
         def __get__(
             self, instance: Literal[None], owner: Any
-        ) -> SQLORMOperations[_T]:
+        ) -> SQLCoreOperations[_T]:
             ...
 
         @overload
@@ -647,7 +647,7 @@ class ORMDescriptor(Generic[_T], TypingOnly):
 
         def __get__(
             self, instance: object, owner: Any
-        ) -> Union[SQLORMOperations[_T], _T]:
+        ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]:
             ...
 
 
index 00ddbcca72f463d776e7d900531ecc973c192197..d7977418732a86fbbc97e50effef3a6fe954257a 100644 (file)
@@ -466,7 +466,9 @@ class PropComparator(SQLORMOperations[_T]):
     def __clause_element__(self):
         raise NotImplementedError("%r" % self)
 
-    def _bulk_update_tuples(self, value):
+    def _bulk_update_tuples(
+        self, value: Any
+    ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
         """Receive a SQL expression that represents a value in the SET
         clause of an UPDATE statement.
 
index 389f7e8d0048496629383267dd38e543ae6704ca..2be98b88fef98d3e2f2841618f93ab0a5457f584 100644 (file)
@@ -56,3 +56,13 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
 
 def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
     return t._is_tuple_type
+
+
+def is_has_clause_element(s: object) -> TypeGuard[roles.HasClauseElement]:
+    return hasattr(s, "__clause_element__")
+
+
+def is_has_column_element_clause_element(
+    s: object,
+) -> TypeGuard[roles.HasColumnElementClauseElement]:
+    return hasattr(s, "__clause_element__")
index db88496a09c6b53e34419d13eca0ce2ebbf72958..8f878b66c08a5614e502920d94e5dc04909d4f26 100644 (file)
@@ -265,7 +265,7 @@ OPERATORS = {
     operators.nulls_last_op: " NULLS LAST",
 }
 
-FUNCTIONS: Dict[Type[Function], str] = {
+FUNCTIONS: Dict[Type[Function[Any]], str] = {
     functions.coalesce: "coalesce",
     functions.current_date: "CURRENT_DATE",
     functions.current_time: "CURRENT_TIME",
@@ -2043,7 +2043,7 @@ class SQLCompiler(Compiled):
 
     def visit_function(
         self,
-        func: Function,
+        func: Function[Any],
         add_to_result_map: Optional[_ResultMapAppender] = None,
         **kwargs: Any,
     ) -> str:
index fdb3fc8bbf892542ae388a887a86ea1745b7457b..c1a7d847628c2995cc6206703f2dba6ff6b8142f 100644 (file)
@@ -103,8 +103,8 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import CacheStats
     from ..engine.result import Result
 
-_NUMERIC = Union[complex, Decimal]
-_NUMBER = Union[complex, int, Decimal]
+_NUMERIC = Union[float, Decimal]
+_NUMBER = Union[float, int, Decimal]
 
 _T = TypeVar("_T", bound="Any")
 _OPT = TypeVar("_OPT", bound="Any")
@@ -348,6 +348,7 @@ class ClauseElement(
         the _copy_internals() method.
 
         """
+
         skip = self._memoized_keys
         c = self.__class__.__new__(self.__class__)
         c.__dict__ = {k: v for k, v in self.__dict__.items() if k not in skip}
@@ -995,10 +996,14 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
 
         @overload
         def __truediv__(
-            self: _SQO[_NMT], other: Any
+            self: _SQO[int], other: Any
         ) -> ColumnElement[_NUMERIC]:
             ...
 
+        @overload
+        def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
+            ...
+
         @overload
         def __truediv__(self, other: Any) -> ColumnElement[Any]:
             ...
index 6e5eec12712934021ff847608dbbca8866ab621c..563b58418b7d59858d9bc7a30347eab508992ca7 100644 (file)
 from __future__ import annotations
 
 from typing import Any
+from typing import Optional
+from typing import overload
 from typing import Sequence
+from typing import TYPE_CHECKING
 from typing import TypeVar
 
 from . import annotation
@@ -47,6 +50,8 @@ from .type_api import TypeEngine
 from .visitors import InternalTraversal
 from .. import util
 
+if TYPE_CHECKING:
+    from ._typing import _TypeEngineArgument
 
 _T = TypeVar("_T", bound=Any)
 
@@ -104,7 +109,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
     _with_ordinality = False
     _table_value_type = None
 
-    def __init__(self, *clauses, **kwargs):
+    def __init__(self, *clauses: Any):
         r"""Construct a :class:`.FunctionElement`.
 
         :param \*clauses: list of column expressions that form the arguments
@@ -752,7 +757,7 @@ class _FunctionGenerator:
         self.__names = []
         self.opts = opts
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> _FunctionGenerator:
         # passthru __ attributes; fixes pydoc
         if name.startswith("__"):
             try:
@@ -766,7 +771,17 @@ class _FunctionGenerator:
         f.__names = list(self.__names) + [name]
         return f
 
-    def __call__(self, *c, **kwargs):
+    @overload
+    def __call__(
+        self, *c: Any, type_: TypeEngine[_T], **kwargs: Any
+    ) -> Function[_T]:
+        ...
+
+    @overload
+    def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]:
+        ...
+
+    def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]:
         o = self.opts.copy()
         o.update(kwargs)
 
@@ -795,7 +810,7 @@ func.__doc__ = _FunctionGenerator.__doc__
 modifier = _FunctionGenerator(group=False)
 
 
-class Function(FunctionElement):
+class Function(FunctionElement[_T]):
     r"""Describe a named SQL function.
 
     The :class:`.Function` object is typically generated from the
@@ -842,7 +857,7 @@ class Function(FunctionElement):
 
     packagenames: Sequence[str]
 
-    type: TypeEngine = sqltypes.NULLTYPE
+    type: TypeEngine[_T]
     """A :class:`_types.TypeEngine` object which refers to the SQL return
     type represented by this SQL function.
 
@@ -859,19 +874,25 @@ class Function(FunctionElement):
 
     """
 
-    def __init__(self, name, *clauses, **kw):
+    def __init__(
+        self,
+        name: str,
+        *clauses: Any,
+        type_: Optional[_TypeEngineArgument[_T]] = None,
+        packagenames: Optional[Sequence[str]] = None,
+    ):
         """Construct a :class:`.Function`.
 
         The :data:`.func` construct is normally used to construct
         new :class:`.Function` instances.
 
         """
-        self.packagenames = kw.pop("packagenames", None) or ()
+        self.packagenames = packagenames or ()
         self.name = name
 
-        self.type = sqltypes.to_instance(kw.get("type_", None))
+        self.type = sqltypes.to_instance(type_)
 
-        FunctionElement.__init__(self, *clauses, **kw)
+        FunctionElement.__init__(self, *clauses)
 
     def _bind_param(self, operator, obj, type_=None, **kw):
         return BindParameter(
index e64ec0843f41c85ba168cca6ce445d72aab20f64..4d0169370c4924db2ab0ca1f6b1cc0f9c1817bbd 100644 (file)
@@ -17,6 +17,7 @@ import enum
 import json
 import pickle
 from typing import Any
+from typing import overload
 from typing import Sequence
 from typing import Tuple
 from typing import TypeVar
@@ -48,6 +49,7 @@ from .. import util
 from ..engine import processors
 from ..util import langhelpers
 from ..util import OrderedDict
+from ..util.typing import Literal
 
 
 _T = TypeVar("_T", bound="Any")
@@ -373,9 +375,10 @@ class BigInteger(Integer):
     __visit_name__ = "big_integer"
 
 
-class Numeric(
-    _LookupExpressionAdapter, TypeEngine[Union[decimal.Decimal, float]]
-):
+_N = TypeVar("_N", bound=Union[decimal.Decimal, float])
+
+
+class Numeric(_LookupExpressionAdapter, TypeEngine[_N]):
 
     """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``.
 
@@ -542,7 +545,7 @@ class Numeric(
         }
 
 
-class Float(Numeric):
+class Float(Numeric[_N]):
 
     """Type representing floating point types, such as ``FLOAT`` or ``REAL``.
 
@@ -567,8 +570,34 @@ class Float(Numeric):
 
     scale = None
 
+    @overload
+    def __init__(
+        self: Float[float],
+        precision=...,
+        decimal_return_scale=...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: Float[decimal.Decimal],
+        precision=...,
+        asdecimal: Literal[True] = ...,
+        decimal_return_scale=...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: Float[float],
+        precision=...,
+        asdecimal: Literal[False] = ...,
+        decimal_return_scale=...,
+    ):
+        ...
+
     def __init__(
-        self: "Float",
+        self: Float[_N],
         precision=None,
         asdecimal=False,
         decimal_return_scale=None,
index c4848494df0fb77a3c5eb837df4d2911d7006b85..6cfa8db4695557b38bf12cce9227dca5bae64cf4 100644 (file)
@@ -91,6 +91,7 @@ module = [
     "sqlalchemy.sql._py_util",
     "sqlalchemy.connectors.*",
     "sqlalchemy.engine.*",
+    "sqlalchemy.ext.hybrid",
     "sqlalchemy.ext.associationproxy",
     "sqlalchemy.pool.*",
     "sqlalchemy.event.*",
diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/ext/mypy/plain_files/hybrid_one.py
new file mode 100644 (file)
index 0000000..7d97024
--- /dev/null
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+import typing
+
+from sqlalchemy.ext.hybrid import hybrid_method
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Interval(Base):
+    __tablename__ = "interval"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    start: Mapped[int]
+    end: Mapped[int]
+
+    def __init__(self, start: int, end: int):
+        self.start = start
+        self.end = end
+
+    @hybrid_property
+    def length(self) -> int:
+        return self.end - self.start
+
+    @hybrid_method
+    def contains(self, point: int) -> int:
+        return (self.start <= point) & (point <= self.end)
+
+    @hybrid_method
+    def intersects(self, other: Interval) -> int:
+        return self.contains(other.start) | self.contains(other.end)
+
+
+i1 = Interval(5, 10)
+i2 = Interval(7, 12)
+
+expr1 = Interval.length.in_([5, 10])
+
+expr2 = Interval.contains(7)
+
+expr3 = Interval.intersects(i2)
+
+if typing.TYPE_CHECKING:
+    # EXPECTED_TYPE: builtins.int\*
+    reveal_type(i1.length)
+
+    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    reveal_type(Interval.length)
+
+    # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
+    reveal_type(expr1)
+
+    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    reveal_type(expr2)
+
+    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    reveal_type(expr3)
diff --git a/test/ext/mypy/plain_files/hybrid_two.py b/test/ext/mypy/plain_files/hybrid_two.py
new file mode 100644 (file)
index 0000000..6bfabbd
--- /dev/null
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import typing
+
+from sqlalchemy import Float
+from sqlalchemy import func
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.expression import ColumnElement
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Interval(Base):
+    __tablename__ = "interval"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    start: Mapped[int]
+    end: Mapped[int]
+
+    def __init__(self, start: int, end: int):
+        self.start = start
+        self.end = end
+
+    @hybrid_property
+    def length(self) -> int:
+        return self.end - self.start
+
+    # im not sure if there's a way to get typing tools to not complain about
+    # the re-defined name here, it handles it for plain @property
+    # but im not sure if that's hardcoded
+    # see https://github.com/python/typing/discussions/1102
+
+    @hybrid_property
+    def _inst_radius(self) -> float:
+        return abs(self.length) / 2
+
+    @_inst_radius.expression
+    def radius(cls) -> ColumnElement[float]:
+        f1 = func.abs(cls.length, type_=Float())
+
+        expr = f1 / 2
+
+        # while we are here, check some Float[] / div type stuff
+        if typing.TYPE_CHECKING:
+            # EXPECTED_TYPE: sqlalchemy.*Function\[builtins.float\*?\]
+            reveal_type(f1)
+
+            # EXPECTED_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\]
+            reveal_type(expr)
+        return expr
+
+
+i1 = Interval(5, 10)
+i2 = Interval(7, 12)
+
+l1: int = i1.length
+rd: float = i2.radius
+
+expr1 = Interval.length.in_([5, 10])
+
+expr2 = Interval.radius
+
+expr3 = Interval.radius.in_([0.5, 5.2])
+
+
+if typing.TYPE_CHECKING:
+    # EXPECTED_TYPE: builtins.int\*
+    reveal_type(i1.length)
+
+    # EXPECTED_TYPE: builtins.float\*
+    reveal_type(i2.radius)
+
+    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    reveal_type(Interval.length)
+
+    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
+    reveal_type(Interval.radius)
+
+    # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
+    reveal_type(expr1)
+
+    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
+    reveal_type(expr2)
+
+    # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
+    reveal_type(expr3)
index a3d57aa1ceea788f843fc5caeaf9a2f1037709f1..b7bae0185d22192dad50d946b2b2d15c1e04d982 100644 (file)
@@ -49,7 +49,7 @@ if typing.TYPE_CHECKING:
     # EXPECTED_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\]
     reveal_type(expr2)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.complex, decimal.Decimal\]\]
+    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\]
     reveal_type(expr3)
 
     # EXPECTED_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\]