]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply pep-612 to hybrid_method; accept SQLCoreOperations
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Jan 2023 03:24:36 +0000 (22:24 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Jan 2023 03:36:22 +0000 (22:36 -0500)
Fixes to the annotations within the ``sqlalchemy.ext.hybrid`` extension for
more effective typing of user-defined methods. The typing now uses
:pep:`612` features, now supported by recent versions of Mypy, to maintain
argument signatures for :class:`.hybrid_method`. Return values for hybrid
methods are accepted as SQL expressions in contexts such as
:meth:`_sql.Select.where` while still supporting SQL methods.

Fixes: #9096
Change-Id: Id4e3a38ec50e415220dfc5f022281b11bb262469

doc/build/changelog/unreleased_20/9096.rst [new file with mode: 0644]
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/sql/_typing.py
test/ext/mypy/plain_files/hybrid_one.py

diff --git a/doc/build/changelog/unreleased_20/9096.rst b/doc/build/changelog/unreleased_20/9096.rst
new file mode 100644 (file)
index 0000000..70755fc
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9096
+
+    Fixes to the annotations within the ``sqlalchemy.ext.hybrid`` extension for
+    more effective typing of user-defined methods. The typing now uses
+    :pep:`612` features, now supported by recent versions of Mypy, to maintain
+    argument signatures for :class:`.hybrid_method`. Return values for hybrid
+    methods are accepted as SQL expressions in contexts such as
+    :meth:`_sql.Select.where` while still supporting SQL methods.
index 657bc8c6e68bc7d92aadcc244660df5836c8fd5a..115c1cb85b4fd9aa4ef75874c8f25d5381515579 100644 (file)
@@ -707,7 +707,9 @@ from ..sql import roles
 from ..sql._typing import is_has_clause_element
 from ..sql.elements import ColumnElement
 from ..sql.elements import SQLCoreOperations
+from ..util.typing import Concatenate
 from ..util.typing import Literal
+from ..util.typing import ParamSpec
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
@@ -719,6 +721,8 @@ if TYPE_CHECKING:
     from ..sql._typing import _InfoType
     from ..sql.operators import OperatorType
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 _T = TypeVar("_T", bound=Any)
 _T_co = TypeVar("_T_co", bound=Any, covariant=True)
 _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
@@ -784,7 +788,7 @@ class _HybridExprCallableType(Protocol[_T_co]):
         ...
 
 
-class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
+class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
     """A decorator which allows definition of a Python object method with both
     instance-level and class-level behavior.
 
@@ -795,8 +799,10 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
 
     def __init__(
         self,
-        func: Callable[..., _T],
-        expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None,
+        func: Callable[Concatenate[Any, _P], _R],
+        expr: Optional[
+            Callable[Concatenate[Any, _P], SQLCoreOperations[_R]]
+        ] = None,
     ):
         """Create a new :class:`.hybrid_method`.
 
@@ -815,31 +821,34 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
 
         """
         self.func = func
-        self.expression(expr or func)
+        if expr is not None:
+            self.expression(expr)
+        else:
+            self.expression(func)  # type: ignore
 
     @overload
     def __get__(
         self, instance: Literal[None], owner: Type[object]
-    ) -> Callable[[Any], SQLCoreOperations[_T]]:
+    ) -> Callable[_P, SQLCoreOperations[_R]]:
         ...
 
     @overload
     def __get__(
         self, instance: object, owner: Type[object]
-    ) -> Callable[[Any], _T]:
+    ) -> Callable[_P, _R]:
         ...
 
     def __get__(
         self, instance: Optional[object], owner: Type[object]
-    ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]:
+    ) -> Union[Callable[_P, _R], Callable[_P, SQLCoreOperations[_R]]]:
         if instance is None:
             return self.expr.__get__(owner, owner)  # type: ignore
         else:
             return self.func.__get__(instance, owner)  # type: ignore
 
     def expression(
-        self, expr: Callable[..., SQLCoreOperations[_T]]
-    ) -> hybrid_method[_T]:
+        self, expr: Callable[Concatenate[Any, _P], SQLCoreOperations[_R]]
+    ) -> hybrid_method[_P, _R]:
         """Provide a modifying decorator that defines a
         SQL-expression producing method."""
 
index a120629caa14eb2541d896534948c3127a29af0e..da3a9ad4e7a8fb1e2d7c68278bbb439ae85f9adb 100644 (file)
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
     from .elements import ColumnElement
     from .elements import KeyedColumnElement
     from .elements import quoted_name
+    from .elements import SQLCoreOperations
     from .elements import TextClause
     from .lambdas import LambdaElement
     from .roles import ColumnsClauseRole
@@ -128,6 +129,7 @@ _TextCoercedExpressionArgument = Union[
 _ColumnsClauseArgument = Union[
     roles.TypedColumnsClauseRole[_T],
     roles.ColumnsClauseRole,
+    "SQLCoreOperations[_T]",
     Literal["*", 1],
     Type[_T],
     Inspectable[_HasClauseElement],
@@ -144,7 +146,10 @@ sets; select(...), insert().returning(...), etc.
 """
 
 _TypedColumnClauseArgument = Union[
-    roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T]
+    roles.TypedColumnsClauseRole[_T],
+    "SQLCoreOperations[_T]",
+    roles.ExpressionElementRole[_T],
+    Type[_T],
 ]
 
 _TP = TypeVar("_TP", bound=Tuple[Any, ...])
@@ -164,6 +169,7 @@ _T9 = TypeVar("_T9", bound=Any)
 _ColumnExpressionArgument = Union[
     "ColumnElement[_T]",
     _HasClauseElement,
+    "SQLCoreOperations[_T]",
     roles.ExpressionElementRole[_T],
     Callable[[], "ColumnElement[_T]"],
     "LambdaElement",
index 12c7c204c573cf28c22df46bcdc65b042267ebf1..b3ce365acd83e9b5e1380ef378cd68395a0fd31f 100644 (file)
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import typing
 
+from sqlalchemy import select
 from sqlalchemy.ext.hybrid import hybrid_method
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase
@@ -36,6 +37,10 @@ class Interval(Base):
     def intersects(self, other: Interval) -> int:
         return self.contains(other.start) | self.contains(other.end)
 
+    @hybrid_method
+    def fancy_thing(self, point: int, x: int, y: int) -> bool:
+        return (self.start <= point) & (point <= self.end)
+
 
 i1 = Interval(5, 10)
 i2 = Interval(7, 12)
@@ -46,6 +51,20 @@ expr2 = Interval.contains(7)
 
 expr3 = Interval.intersects(i2)
 
+expr4 = Interval.fancy_thing(10, 12, 15)
+
+# test that pep-612 actually works
+
+# EXPECTED_MYPY: Too few arguments
+Interval.fancy_thing(1, 2)
+
+# EXPECTED_MYPY: Argument 2 has incompatible type
+Interval.fancy_thing(1, "foo", 3)
+
+stmt1 = select(Interval).where(expr1).where(expr4)
+
+stmt2 = select(expr4)
+
 if typing.TYPE_CHECKING:
     # EXPECTED_RE_TYPE: builtins.int\*?
     reveal_type(i1.length)
@@ -61,3 +80,12 @@ if typing.TYPE_CHECKING:
 
     # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
     reveal_type(expr3)
+
+    # EXPECTED_TYPE: bool
+    reveal_type(i1.fancy_thing(1, 2, 3))
+
+    # EXPECTED_TYPE: SQLCoreOperations[bool]
+    reveal_type(expr4)
+
+    # EXPECTED_TYPE: Select[Tuple[bool]]
+    reveal_type(stmt2)