From: Mike Bayer Date: Wed, 16 Mar 2022 16:07:25 +0000 (-0400) Subject: pep484 for hybrid X-Git-Tag: rel_2_0_0b1~423 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3b520e758a715cf817075e4a90ae1b5813ffadd3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep484 for hybrid Change-Id: I53274b13094d996e11b04acb03f9613edbddf87f References: #6810 --- diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index dc34a2ef58..92b3ce54f7 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -802,17 +802,41 @@ advanced and/or patient developers, there's probably a whole lot of amazing things it can be used for. """ # noqa + +from __future__ import annotations + from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from .. import util from ..orm import attributes from ..orm import InspectionAttrExtensionType from ..orm import interfaces from ..orm import ORMDescriptor +from ..sql._typing import is_has_column_element_clause_element +from ..sql.elements import ColumnElement +from ..sql.elements import SQLCoreOperations +from ..util.typing import Literal +from ..util.typing import Protocol +if TYPE_CHECKING: + from ..orm.util import AliasedInsp + from ..sql.operators import OperatorType _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) class HybridExtensionType(InspectionAttrExtensionType): @@ -844,7 +868,34 @@ class HybridExtensionType(InspectionAttrExtensionType): """ -class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): +class _HybridGetterType(Protocol[_T_co]): + def __call__(s, self: Any) -> _T_co: + ... + + +class _HybridSetterType(Protocol[_T_con]): + def __call__(self, instance: Any, value: _T_con) -> None: + ... + + +class _HybridUpdaterType(Protocol[_T]): + def __call__( + self, cls: Type[Any], value: Union[_T, SQLCoreOperations[_T]] + ) -> List[Tuple[SQLCoreOperations[_T], Any]]: + ... + + +class _HybridDeleterType(Protocol[_T_co]): + def __call__(self, instance: Any) -> None: + ... + + +class _HybridExprCallableType(Protocol[_T]): + def __call__(self, cls: Any) -> SQLCoreOperations[_T]: + ... + + +class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]): """A decorator which allows definition of a Python object method with both instance-level and class-level behavior. @@ -853,7 +904,11 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): is_attribute = True extension_type = HybridExtensionType.HYBRID_METHOD - def __init__(self, func, expr=None): + def __init__( + self, + func: Callable[..., _T], + expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None, + ): """Create a new :class:`.hybrid_method`. Usage is typically via decorator:: @@ -873,13 +928,29 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): self.func = func self.expression(expr or func) - def __get__(self, instance, owner): + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> Callable[[Any], SQLCoreOperations[_T]]: + ... + + @overload + def __get__( + self, instance: object, owner: Type[object] + ) -> Callable[[Any], _T]: + ... + + def __get__( + self, instance: Optional[object], owner: Type[object] + ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]: if instance is None: - return self.expr.__get__(owner, owner.__class__) + return self.expr.__get__(owner, owner) # type: ignore else: - return self.func.__get__(instance, owner) + return self.func.__get__(instance, owner) # type: ignore - def expression(self, expr): + def expression( + self, expr: Callable[..., SQLCoreOperations[_T]] + ) -> hybrid_method[_T]: """Provide a modifying decorator that defines a SQL-expression producing method.""" @@ -889,7 +960,12 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): return self -class hybrid_property(interfaces.InspectionAttrInfo): +Selfhybrid_property = TypeVar( + "Selfhybrid_property", bound="hybrid_property[Any]" +) + + +class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): """A decorator which allows definition of a Python descriptor with both instance-level and class-level behavior. @@ -898,14 +974,16 @@ class hybrid_property(interfaces.InspectionAttrInfo): is_attribute = True extension_type = HybridExtensionType.HYBRID_PROPERTY + __name__: str + def __init__( self, - fget, - fset=None, - fdel=None, - expr=None, - custom_comparator=None, - update_expr=None, + fget: _HybridGetterType[_T], + fset: Optional[_HybridSetterType[_T]] = None, + fdel: Optional[_HybridDeleterType[_T]] = None, + expr: Optional[_HybridExprCallableType[_T]] = None, + custom_comparator: Optional[Comparator[_T]] = None, + update_expr: Optional[_HybridUpdaterType[_T]] = None, ): """Create a new :class:`.hybrid_property`. @@ -931,23 +1009,43 @@ class hybrid_property(interfaces.InspectionAttrInfo): self.update_expr = update_expr util.update_wrapper(self, fget) - def __get__(self, instance, owner): - if instance is None: + @overload + def __get__( + self: Selfhybrid_property, instance: Any, owner: Literal[None] + ) -> Selfhybrid_property: + ... + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> SQLCoreOperations[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Type[object]) -> _T: + ... + + def __get__( + self, instance: Optional[object], owner: Optional[Type[object]] + ) -> Union[hybrid_property[_T], SQLCoreOperations[_T], _T]: + if owner is None: + return self + elif instance is None: return self._expr_comparator(owner) else: return self.fget(instance) - def __set__(self, instance, value): + def __set__(self, instance: object, value: Any) -> None: if self.fset is None: raise AttributeError("can't set attribute") self.fset(instance, value) - def __delete__(self, instance): + def __delete__(self, instance: object) -> None: if self.fdel is None: raise AttributeError("can't delete attribute") self.fdel(instance) - def _copy(self, **kw): + def _copy(self, **kw: Any) -> hybrid_property[_T]: defaults = { key: value for key, value in self.__dict__.items() @@ -957,7 +1055,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): return type(self)(**defaults) @property - def overrides(self): + def overrides(self: Selfhybrid_property) -> Selfhybrid_property: """Prefix for a method that is overriding an existing attribute. The :attr:`.hybrid_property.overrides` accessor just returns @@ -992,7 +1090,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): """ return self - def getter(self, fget): + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a getter method. .. versionadded:: 1.2 @@ -1001,17 +1099,19 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._copy(fget=fget) - def setter(self, fset): + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a setter method.""" return self._copy(fset=fset) - def deleter(self, fdel): + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a deletion method.""" return self._copy(fdel=fdel) - def expression(self, expr): + def expression( + self, expr: _HybridExprCallableType[_T] + ) -> hybrid_property[_T]: """Provide a modifying decorator that defines a SQL-expression producing method. @@ -1043,7 +1143,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._copy(expr=expr) - def comparator(self, comparator): + def comparator(self, comparator: Comparator[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a custom comparator producing method. @@ -1078,7 +1178,9 @@ class hybrid_property(interfaces.InspectionAttrInfo): """ return self._copy(custom_comparator=comparator) - def update_expression(self, meth): + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: """Provide a modifying decorator that defines an UPDATE tuple producing method. @@ -1115,27 +1217,35 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._copy(update_expr=meth) @util.memoized_property - def _expr_comparator(self): + def _expr_comparator( + self, + ) -> Callable[[Any], interfaces.PropComparator[_T]]: if self.custom_comparator is not None: return self._get_comparator(self.custom_comparator) elif self.expr is not None: return self._get_expr(self.expr) else: - return self._get_expr(self.fget) + return self._get_expr(cast(_HybridExprCallableType[_T], self.fget)) - def _get_expr(self, expr): - def _expr(cls): + def _get_expr( + self, expr: _HybridExprCallableType[_T] + ) -> Callable[[Any], interfaces.PropComparator[_T]]: + def _expr(cls: Any) -> ExprComparator[_T]: return ExprComparator(cls, expr(cls), self) util.update_wrapper(_expr, expr) return self._get_comparator(_expr) - def _get_comparator(self, comparator): + def _get_comparator( + self, comparator: Any + ) -> Callable[[Any], interfaces.PropComparator[_T]]: proxy_attr = attributes.create_proxied_attribute(self) - def expr_comparator(owner): + def expr_comparator( + owner: Type[object], + ) -> interfaces.PropComparator[_T]: # because this is the descriptor protocol, we don't really know # what our attribute name is. so search for it through the # MRO. @@ -1163,36 +1273,48 @@ class Comparator(interfaces.PropComparator[_T]): :class:`~.orm.interfaces.PropComparator` classes for usage with hybrids.""" - property = None - - def __init__(self, expression): + def __init__(self, expression: SQLCoreOperations[_T]): self.expression = expression - def __clause_element__(self): + def __clause_element__(self) -> ColumnElement[_T]: expr = self.expression - if hasattr(expr, "__clause_element__"): + if is_has_column_element_clause_element(expr): expr = expr.__clause_element__() + + elif TYPE_CHECKING: + assert isinstance(expr, ColumnElement) return expr - def adapt_to_entity(self, adapt_to_entity): + @util.non_memoized_property + def property(self) -> Any: + return None + + def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]: # interesting.... return self class ExprComparator(Comparator[_T]): - def __init__(self, cls, expression, hybrid): + def __init__( + self, + cls: Type[Any], + expression: SQLCoreOperations[_T], + hybrid: hybrid_property[_T], + ): self.cls = cls self.expression = expression self.hybrid = hybrid - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: return getattr(self.expression, key) - @property - def info(self): + @util.non_memoized_property + def info(self) -> Dict[Any, Any]: return self.hybrid.info - def _bulk_update_tuples(self, value): + def _bulk_update_tuples( + self, value: Any + ) -> List[Tuple[SQLCoreOperations[_T], Any]]: if isinstance(self.expression, attributes.QueryableAttribute): return self.expression._bulk_update_tuples(value) elif self.hybrid.update_expr is not None: @@ -1200,12 +1322,16 @@ class ExprComparator(Comparator[_T]): else: return [(self.expression, value)] - @property - def property(self): - return self.expression.property + @util.non_memoized_property + def property(self) -> Any: + return self.expression.property # type: ignore - def operate(self, op, *other, **kwargs): - return op(self.expression, *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.expression, *other, **kwargs) # type: ignore - def reverse_operate(self, op, other, **kwargs): - return op(other, self.expression, **kwargs) + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.expression, **kwargs) # type: ignore diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index ce3a645adb..2b6ca400e9 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -20,6 +20,7 @@ from collections import namedtuple import operator import typing from typing import Any +from typing import Callable from typing import List from typing import NamedTuple from typing import Tuple @@ -68,6 +69,7 @@ from ..sql import visitors if typing.TYPE_CHECKING: from ..sql.elements import ColumnElement + from ..sql.elements import SQLCoreOperations _T = TypeVar("_T") @@ -277,7 +279,9 @@ class QueryableAttribute( def _from_objects(self): return self.expression._from_objects - def _bulk_update_tuples(self, value): + def _bulk_update_tuples( + self, value: Any + ) -> List[Tuple[SQLCoreOperations[_T], Any]]: """Return setter tuples for a bulk UPDATE.""" return self.comparator._bulk_update_tuples(value) @@ -416,7 +420,9 @@ HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"]) HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False -def create_proxied_attribute(descriptor): +def create_proxied_attribute( + descriptor: Any, +) -> Callable[..., QueryableAttribute[Any]]: """Create an QueryableAttribute / user descriptor hybrid. Returns a new QueryableAttribute type that delegates descriptor diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index c63a89c704..cb30701037 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -638,7 +638,7 @@ class ORMDescriptor(Generic[_T], TypingOnly): @overload def __get__( self, instance: Literal[None], owner: Any - ) -> SQLORMOperations[_T]: + ) -> SQLCoreOperations[_T]: ... @overload @@ -647,7 +647,7 @@ class ORMDescriptor(Generic[_T], TypingOnly): def __get__( self, instance: object, owner: Any - ) -> Union[SQLORMOperations[_T], _T]: + ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]: ... diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 00ddbcca72..d797741873 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -466,7 +466,9 @@ class PropComparator(SQLORMOperations[_T]): def __clause_element__(self): raise NotImplementedError("%r" % self) - def _bulk_update_tuples(self, value): + def _bulk_update_tuples( + self, value: Any + ) -> List[Tuple[SQLCoreOperations[_T], Any]]: """Receive a SQL expression that represents a value in the SET clause of an UPDATE statement. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 389f7e8d00..2be98b88fe 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -56,3 +56,13 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: return t._is_tuple_type + + +def is_has_clause_element(s: object) -> TypeGuard[roles.HasClauseElement]: + return hasattr(s, "__clause_element__") + + +def is_has_column_element_clause_element( + s: object, +) -> TypeGuard[roles.HasColumnElementClauseElement]: + return hasattr(s, "__clause_element__") diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index db88496a09..8f878b66c0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -265,7 +265,7 @@ OPERATORS = { operators.nulls_last_op: " NULLS LAST", } -FUNCTIONS: Dict[Type[Function], str] = { +FUNCTIONS: Dict[Type[Function[Any]], str] = { functions.coalesce: "coalesce", functions.current_date: "CURRENT_DATE", functions.current_time: "CURRENT_TIME", @@ -2043,7 +2043,7 @@ class SQLCompiler(Compiled): def visit_function( self, - func: Function, + func: Function[Any], add_to_result_map: Optional[_ResultMapAppender] = None, **kwargs: Any, ) -> str: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index fdb3fc8bbf..c1a7d84762 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -103,8 +103,8 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import CacheStats from ..engine.result import Result -_NUMERIC = Union[complex, Decimal] -_NUMBER = Union[complex, int, Decimal] +_NUMERIC = Union[float, Decimal] +_NUMBER = Union[float, int, Decimal] _T = TypeVar("_T", bound="Any") _OPT = TypeVar("_OPT", bound="Any") @@ -348,6 +348,7 @@ class ClauseElement( the _copy_internals() method. """ + skip = self._memoized_keys c = self.__class__.__new__(self.__class__) c.__dict__ = {k: v for k, v in self.__dict__.items() if k not in skip} @@ -995,10 +996,14 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): @overload def __truediv__( - self: _SQO[_NMT], other: Any + self: _SQO[int], other: Any ) -> ColumnElement[_NUMERIC]: ... + @overload + def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: + ... + @overload def __truediv__(self, other: Any) -> ColumnElement[Any]: ... diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6e5eec1271..563b58418b 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -12,7 +12,10 @@ from __future__ import annotations from typing import Any +from typing import Optional +from typing import overload from typing import Sequence +from typing import TYPE_CHECKING from typing import TypeVar from . import annotation @@ -47,6 +50,8 @@ from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util +if TYPE_CHECKING: + from ._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) @@ -104,7 +109,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): _with_ordinality = False _table_value_type = None - def __init__(self, *clauses, **kwargs): + def __init__(self, *clauses: Any): r"""Construct a :class:`.FunctionElement`. :param \*clauses: list of column expressions that form the arguments @@ -752,7 +757,7 @@ class _FunctionGenerator: self.__names = [] self.opts = opts - def __getattr__(self, name): + def __getattr__(self, name: str) -> _FunctionGenerator: # passthru __ attributes; fixes pydoc if name.startswith("__"): try: @@ -766,7 +771,17 @@ class _FunctionGenerator: f.__names = list(self.__names) + [name] return f - def __call__(self, *c, **kwargs): + @overload + def __call__( + self, *c: Any, type_: TypeEngine[_T], **kwargs: Any + ) -> Function[_T]: + ... + + @overload + def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: + ... + + def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: o = self.opts.copy() o.update(kwargs) @@ -795,7 +810,7 @@ func.__doc__ = _FunctionGenerator.__doc__ modifier = _FunctionGenerator(group=False) -class Function(FunctionElement): +class Function(FunctionElement[_T]): r"""Describe a named SQL function. The :class:`.Function` object is typically generated from the @@ -842,7 +857,7 @@ class Function(FunctionElement): packagenames: Sequence[str] - type: TypeEngine = sqltypes.NULLTYPE + type: TypeEngine[_T] """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. @@ -859,19 +874,25 @@ class Function(FunctionElement): """ - def __init__(self, name, *clauses, **kw): + def __init__( + self, + name: str, + *clauses: Any, + type_: Optional[_TypeEngineArgument[_T]] = None, + packagenames: Optional[Sequence[str]] = None, + ): """Construct a :class:`.Function`. The :data:`.func` construct is normally used to construct new :class:`.Function` instances. """ - self.packagenames = kw.pop("packagenames", None) or () + self.packagenames = packagenames or () self.name = name - self.type = sqltypes.to_instance(kw.get("type_", None)) + self.type = sqltypes.to_instance(type_) - FunctionElement.__init__(self, *clauses, **kw) + FunctionElement.__init__(self, *clauses) def _bind_param(self, operator, obj, type_=None, **kw): return BindParameter( diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index e64ec0843f..4d0169370c 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -17,6 +17,7 @@ import enum import json import pickle from typing import Any +from typing import overload from typing import Sequence from typing import Tuple from typing import TypeVar @@ -48,6 +49,7 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict +from ..util.typing import Literal _T = TypeVar("_T", bound="Any") @@ -373,9 +375,10 @@ class BigInteger(Integer): __visit_name__ = "big_integer" -class Numeric( - _LookupExpressionAdapter, TypeEngine[Union[decimal.Decimal, float]] -): +_N = TypeVar("_N", bound=Union[decimal.Decimal, float]) + + +class Numeric(_LookupExpressionAdapter, TypeEngine[_N]): """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``. @@ -542,7 +545,7 @@ class Numeric( } -class Float(Numeric): +class Float(Numeric[_N]): """Type representing floating point types, such as ``FLOAT`` or ``REAL``. @@ -567,8 +570,34 @@ class Float(Numeric): scale = None + @overload + def __init__( + self: Float[float], + precision=..., + decimal_return_scale=..., + ): + ... + + @overload + def __init__( + self: Float[decimal.Decimal], + precision=..., + asdecimal: Literal[True] = ..., + decimal_return_scale=..., + ): + ... + + @overload + def __init__( + self: Float[float], + precision=..., + asdecimal: Literal[False] = ..., + decimal_return_scale=..., + ): + ... + def __init__( - self: "Float", + self: Float[_N], precision=None, asdecimal=False, decimal_return_scale=None, diff --git a/pyproject.toml b/pyproject.toml index c4848494df..6cfa8db469 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ module = [ "sqlalchemy.sql._py_util", "sqlalchemy.connectors.*", "sqlalchemy.engine.*", + "sqlalchemy.ext.hybrid", "sqlalchemy.ext.associationproxy", "sqlalchemy.pool.*", "sqlalchemy.event.*", diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/ext/mypy/plain_files/hybrid_one.py new file mode 100644 index 0000000000..7d97024afe --- /dev/null +++ b/test/ext/mypy/plain_files/hybrid_one.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import typing + +from sqlalchemy.ext.hybrid import hybrid_method +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Interval(Base): + __tablename__ = "interval" + + id: Mapped[int] = mapped_column(primary_key=True) + start: Mapped[int] + end: Mapped[int] + + def __init__(self, start: int, end: int): + self.start = start + self.end = end + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @hybrid_method + def contains(self, point: int) -> int: + return (self.start <= point) & (point <= self.end) + + @hybrid_method + def intersects(self, other: Interval) -> int: + return self.contains(other.start) | self.contains(other.end) + + +i1 = Interval(5, 10) +i2 = Interval(7, 12) + +expr1 = Interval.length.in_([5, 10]) + +expr2 = Interval.contains(7) + +expr3 = Interval.intersects(i2) + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: builtins.int\* + reveal_type(i1.length) + + # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + reveal_type(Interval.length) + + # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] + reveal_type(expr1) + + # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + reveal_type(expr2) + + # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + reveal_type(expr3) diff --git a/test/ext/mypy/plain_files/hybrid_two.py b/test/ext/mypy/plain_files/hybrid_two.py new file mode 100644 index 0000000000..6bfabbd30a --- /dev/null +++ b/test/ext/mypy/plain_files/hybrid_two.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import typing + +from sqlalchemy import Float +from sqlalchemy import func +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.expression import ColumnElement + + +class Base(DeclarativeBase): + pass + + +class Interval(Base): + __tablename__ = "interval" + + id: Mapped[int] = mapped_column(primary_key=True) + start: Mapped[int] + end: Mapped[int] + + def __init__(self, start: int, end: int): + self.start = start + self.end = end + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + # im not sure if there's a way to get typing tools to not complain about + # the re-defined name here, it handles it for plain @property + # but im not sure if that's hardcoded + # see https://github.com/python/typing/discussions/1102 + + @hybrid_property + def _inst_radius(self) -> float: + return abs(self.length) / 2 + + @_inst_radius.expression + def radius(cls) -> ColumnElement[float]: + f1 = func.abs(cls.length, type_=Float()) + + expr = f1 / 2 + + # while we are here, check some Float[] / div type stuff + if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*Function\[builtins.float\*?\] + reveal_type(f1) + + # EXPECTED_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\] + reveal_type(expr) + return expr + + +i1 = Interval(5, 10) +i2 = Interval(7, 12) + +l1: int = i1.length +rd: float = i2.radius + +expr1 = Interval.length.in_([5, 10]) + +expr2 = Interval.radius + +expr3 = Interval.radius.in_([0.5, 5.2]) + + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: builtins.int\* + reveal_type(i1.length) + + # EXPECTED_TYPE: builtins.float\* + reveal_type(i2.radius) + + # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + reveal_type(Interval.length) + + # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] + reveal_type(Interval.radius) + + # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] + reveal_type(expr1) + + # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] + reveal_type(expr2) + + # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] + reveal_type(expr3) diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py index a3d57aa1ce..b7bae0185d 100644 --- a/test/ext/mypy/plain_files/sql_operations.py +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -49,7 +49,7 @@ if typing.TYPE_CHECKING: # EXPECTED_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\] reveal_type(expr2) - # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.complex, decimal.Decimal\]\] + # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\] reveal_type(expr3) # EXPECTED_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\]