From: Mike Bayer Date: Sun, 15 Jan 2023 03:24:36 +0000 (-0500) Subject: apply pep-612 to hybrid_method; accept SQLCoreOperations X-Git-Tag: rel_2_0_0rc3~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=67c1c018f571fbbcf070c4e0637f36d9533c86d7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git apply pep-612 to hybrid_method; accept SQLCoreOperations Fixes to the annotations within the ``sqlalchemy.ext.hybrid`` extension for more effective typing of user-defined methods. The typing now uses :pep:`612` features, now supported by recent versions of Mypy, to maintain argument signatures for :class:`.hybrid_method`. Return values for hybrid methods are accepted as SQL expressions in contexts such as :meth:`_sql.Select.where` while still supporting SQL methods. Fixes: #9096 Change-Id: Id4e3a38ec50e415220dfc5f022281b11bb262469 --- diff --git a/doc/build/changelog/unreleased_20/9096.rst b/doc/build/changelog/unreleased_20/9096.rst new file mode 100644 index 0000000000..70755fc8f9 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9096.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, typing + :tickets: 9096 + + Fixes to the annotations within the ``sqlalchemy.ext.hybrid`` extension for + more effective typing of user-defined methods. The typing now uses + :pep:`612` features, now supported by recent versions of Mypy, to maintain + argument signatures for :class:`.hybrid_method`. Return values for hybrid + methods are accepted as SQL expressions in contexts such as + :meth:`_sql.Select.where` while still supporting SQL methods. diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 657bc8c6e6..115c1cb85b 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -707,7 +707,9 @@ from ..sql import roles from ..sql._typing import is_has_clause_element from ..sql.elements import ColumnElement from ..sql.elements import SQLCoreOperations +from ..util.typing import Concatenate from ..util.typing import Literal +from ..util.typing import ParamSpec from ..util.typing import Protocol if TYPE_CHECKING: @@ -719,6 +721,8 @@ if TYPE_CHECKING: from ..sql._typing import _InfoType from ..sql.operators import OperatorType +_P = ParamSpec("_P") +_R = TypeVar("_R") _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _T_con = TypeVar("_T_con", bound=Any, contravariant=True) @@ -784,7 +788,7 @@ class _HybridExprCallableType(Protocol[_T_co]): ... -class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]): +class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): """A decorator which allows definition of a Python object method with both instance-level and class-level behavior. @@ -795,8 +799,10 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]): def __init__( self, - func: Callable[..., _T], - expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None, + func: Callable[Concatenate[Any, _P], _R], + expr: Optional[ + Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ] = None, ): """Create a new :class:`.hybrid_method`. @@ -815,31 +821,34 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]): """ self.func = func - self.expression(expr or func) + if expr is not None: + self.expression(expr) + else: + self.expression(func) # type: ignore @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> Callable[[Any], SQLCoreOperations[_T]]: + ) -> Callable[_P, SQLCoreOperations[_R]]: ... @overload def __get__( self, instance: object, owner: Type[object] - ) -> Callable[[Any], _T]: + ) -> Callable[_P, _R]: ... def __get__( self, instance: Optional[object], owner: Type[object] - ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]: + ) -> Union[Callable[_P, _R], Callable[_P, SQLCoreOperations[_R]]]: if instance is None: return self.expr.__get__(owner, owner) # type: ignore else: return self.func.__get__(instance, owner) # type: ignore def expression( - self, expr: Callable[..., SQLCoreOperations[_T]] - ) -> hybrid_method[_T]: + self, expr: Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ) -> hybrid_method[_P, _R]: """Provide a modifying decorator that defines a SQL-expression producing method.""" diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index a120629caa..da3a9ad4e7 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from .elements import ColumnElement from .elements import KeyedColumnElement from .elements import quoted_name + from .elements import SQLCoreOperations from .elements import TextClause from .lambdas import LambdaElement from .roles import ColumnsClauseRole @@ -128,6 +129,7 @@ _TextCoercedExpressionArgument = Union[ _ColumnsClauseArgument = Union[ roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, + "SQLCoreOperations[_T]", Literal["*", 1], Type[_T], Inspectable[_HasClauseElement], @@ -144,7 +146,10 @@ sets; select(...), insert().returning(...), etc. """ _TypedColumnClauseArgument = Union[ - roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T] + roles.TypedColumnsClauseRole[_T], + "SQLCoreOperations[_T]", + roles.ExpressionElementRole[_T], + Type[_T], ] _TP = TypeVar("_TP", bound=Tuple[Any, ...]) @@ -164,6 +169,7 @@ _T9 = TypeVar("_T9", bound=Any) _ColumnExpressionArgument = Union[ "ColumnElement[_T]", _HasClauseElement, + "SQLCoreOperations[_T]", roles.ExpressionElementRole[_T], Callable[[], "ColumnElement[_T]"], "LambdaElement", diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/ext/mypy/plain_files/hybrid_one.py index 12c7c204c5..b3ce365acd 100644 --- a/test/ext/mypy/plain_files/hybrid_one.py +++ b/test/ext/mypy/plain_files/hybrid_one.py @@ -2,6 +2,7 @@ from __future__ import annotations import typing +from sqlalchemy import select from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import DeclarativeBase @@ -36,6 +37,10 @@ class Interval(Base): def intersects(self, other: Interval) -> int: return self.contains(other.start) | self.contains(other.end) + @hybrid_method + def fancy_thing(self, point: int, x: int, y: int) -> bool: + return (self.start <= point) & (point <= self.end) + i1 = Interval(5, 10) i2 = Interval(7, 12) @@ -46,6 +51,20 @@ expr2 = Interval.contains(7) expr3 = Interval.intersects(i2) +expr4 = Interval.fancy_thing(10, 12, 15) + +# test that pep-612 actually works + +# EXPECTED_MYPY: Too few arguments +Interval.fancy_thing(1, 2) + +# EXPECTED_MYPY: Argument 2 has incompatible type +Interval.fancy_thing(1, "foo", 3) + +stmt1 = select(Interval).where(expr1).where(expr4) + +stmt2 = select(expr4) + if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: builtins.int\*? reveal_type(i1.length) @@ -61,3 +80,12 @@ if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] reveal_type(expr3) + + # EXPECTED_TYPE: bool + reveal_type(i1.fancy_thing(1, 2, 3)) + + # EXPECTED_TYPE: SQLCoreOperations[bool] + reveal_type(expr4) + + # EXPECTED_TYPE: Select[Tuple[bool]] + reveal_type(stmt2)