]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add a generic argument to _HasClauseElement
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jan 2024 17:49:10 +0000 (12:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jan 2024 17:51:02 +0000 (12:51 -0500)
Further enhancements to pep-484 typing to allow SQL functions from
:attr:`_sql.func` derived elements to work more effectively with ORM-mapped
attributes.

Fixes: #10801
Change-Id: Ib8222d888a2d8c3fbeab0d1bf5edb535916d4721

doc/build/changelog/unreleased_20/10801.rst [new file with mode: 0644]
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/_typing.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/_typing.py
test/typing/plain_files/sql/functions_again.py

diff --git a/doc/build/changelog/unreleased_20/10801.rst b/doc/build/changelog/unreleased_20/10801.rst
new file mode 100644 (file)
index 0000000..a35a548
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 10801
+
+    Further enhancements to pep-484 typing to allow SQL functions from
+    :attr:`_sql.func` derived elements to work more effectively with ORM-mapped
+    attributes.
index 6252e33d571e342160dae437e10b2f0474472a27..9208d107af6765b7b56976391a54fd1ba486a72d 100644 (file)
@@ -930,7 +930,7 @@ class _HybridDeleterType(Protocol[_T_co]):
 class _HybridExprCallableType(Protocol[_T_co]):
     def __call__(
         s, cls: Any
-    ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]:
+    ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]:
         ...
 
 
@@ -1447,7 +1447,7 @@ class Comparator(interfaces.PropComparator[_T]):
     classes for usage with hybrids."""
 
     def __init__(
-        self, expression: Union[_HasClauseElement, SQLColumnExpression[_T]]
+        self, expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]]
     ):
         self.expression = expression
 
@@ -1482,7 +1482,7 @@ class ExprComparator(Comparator[_T]):
     def __init__(
         self,
         cls: Type[Any],
-        expression: Union[_HasClauseElement, SQLColumnExpression[_T]],
+        expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]],
         hybrid: hybrid_property[_T],
     ):
         self.cls = cls
index 532d0e0b361af8c2cb21ac5ae156e59320bd570d..d9abe28c012be5ee7b4dcb9da94c2ae3550bacd8 100644 (file)
@@ -78,7 +78,7 @@ _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
 
 _ORMColumnExprArgument = Union[
     ColumnElement[_T],
-    _HasClauseElement,
+    _HasClauseElement[_T],
     roles.ExpressionElementRole[_T],
 ]
 
index 0a431d2cfb8f49bd8cc97c33ae7a2e366f5ccc68..3ab1cc64c707fbd31126134b7997c08618b958b7 100644 (file)
@@ -179,7 +179,10 @@ _ORMOrderByArgument = Union[
 ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
 
 _ORMColCollectionElement = Union[
-    ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole, "Mapped[Any]"
+    ColumnClause[Any],
+    _HasClauseElement[Any],
+    roles.DMLColumnRole,
+    "Mapped[Any]",
 ]
 _ORMColCollectionArgument = Union[
     str,
index 23e275ed5d742a2f7905fa808307ee0d5216b086..a51e4a2cf4cb0a71b4120e3bae58cfb90decdd38 100644 (file)
@@ -436,10 +436,8 @@ def outparam(
     return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
 
 
-# mypy insists that BinaryExpression and _HasClauseElement protocol overlap.
-# they do not.  at all.  bug in mypy?
 @overload
-def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]:  # type: ignore
+def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]:
     ...
 
 
index 944b29176a177403ed7f9b131dce47fd1cc5cba4..93e4d92c00c05c7847882b93e4920c76d3b5458d 100644 (file)
@@ -11,6 +11,7 @@ import operator
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Generic
 from typing import Iterable
 from typing import Mapping
 from typing import NoReturn
@@ -52,7 +53,6 @@ if TYPE_CHECKING:
     from .elements import SQLCoreOperations
     from .elements import TextClause
     from .lambdas import LambdaElement
-    from .roles import ColumnsClauseRole
     from .roles import FromClauseRole
     from .schema import Column
     from .selectable import Alias
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
     from ..util.typing import TypeGuard
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 
 _CE = TypeVar("_CE", bound="ColumnElement[Any]")
@@ -79,10 +80,10 @@ _CE = TypeVar("_CE", bound="ColumnElement[Any]")
 _CLE = TypeVar("_CLE", bound="ClauseElement")
 
 
-class _HasClauseElement(Protocol):
+class _HasClauseElement(Protocol, Generic[_T_co]):
     """indicates a class that has a __clause_element__() method"""
 
-    def __clause_element__(self) -> ColumnsClauseRole:
+    def __clause_element__(self) -> roles.ExpressionElementRole[_T_co]:
         ...
 
 
@@ -112,8 +113,8 @@ _MAYBE_ENTITY = TypeVar(
     roles.ColumnsClauseRole,
     Literal["*", 1],
     Type[Any],
-    Inspectable[_HasClauseElement],
-    _HasClauseElement,
+    Inspectable[_HasClauseElement[Any]],
+    _HasClauseElement[Any],
 )
 
 
@@ -127,7 +128,7 @@ _TextCoercedExpressionArgument = Union[
     str,
     "TextClause",
     "ColumnElement[_T]",
-    _HasClauseElement,
+    _HasClauseElement[_T],
     roles.ExpressionElementRole[_T],
 ]
 
@@ -137,8 +138,8 @@ _ColumnsClauseArgument = Union[
     "SQLCoreOperations[_T]",
     Literal["*", 1],
     Type[_T],
-    Inspectable[_HasClauseElement],
-    _HasClauseElement,
+    Inspectable[_HasClauseElement[_T]],
+    _HasClauseElement[_T],
 ]
 """open-ended SELECT columns clause argument.
 
@@ -172,7 +173,7 @@ _T9 = TypeVar("_T9", bound=Any)
 
 _ColumnExpressionArgument = Union[
     "ColumnElement[_T]",
-    _HasClauseElement,
+    _HasClauseElement[_T],
     "SQLCoreOperations[_T]",
     roles.ExpressionElementRole[_T],
     Callable[[], "ColumnElement[_T]"],
@@ -212,8 +213,8 @@ _InfoType = Dict[Any, Any]
 _FromClauseArgument = Union[
     roles.FromClauseRole,
     Type[Any],
-    Inspectable[_HasClauseElement],
-    _HasClauseElement,
+    Inspectable[_HasClauseElement[Any]],
+    _HasClauseElement[Any],
 ]
 """A FROM clause, like we would send to select().select_from().
 
@@ -240,7 +241,7 @@ _SelectStatementForCompoundArgument = Union[
 
 _DMLColumnArgument = Union[
     str,
-    _HasClauseElement,
+    _HasClauseElement[Any],
     roles.DMLColumnRole,
     "SQLCoreOperations[Any]",
 ]
@@ -271,8 +272,8 @@ _DMLTableArgument = Union[
     "Alias",
     "CTE",
     Type[Any],
-    Inspectable[_HasClauseElement],
-    _HasClauseElement,
+    Inspectable[_HasClauseElement[Any]],
+    _HasClauseElement[Any],
 ]
 
 _PropagateAttrsType = util.immutabledict[str, Any]
@@ -364,7 +365,7 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
     return hasattr(s, "quote")
 
 
-def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
+def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement[Any]]:
     return hasattr(s, "__clause_element__")
 
 
index 5173d1fe0822347ec88b421b7bce2cc56c6708d7..87ade922468259c1df632c7490ace4502eba951e 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy import func
+from sqlalchemy import select
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -27,3 +28,16 @@ func.row_number().over(partition_by="a", order_by=("a", "b"))
 reveal_type(func.row_number().filter())
 # EXPECTED_TYPE: FunctionFilter[Any]
 reveal_type(func.row_number().filter(Foo.a > 0))
+
+
+# test #10801
+# EXPECTED_TYPE: max[int]
+reveal_type(func.max(Foo.b))
+
+
+stmt1 = select(
+    Foo.a,
+    func.min(Foo.b),
+).group_by(Foo.a)
+# EXPECTED_TYPE: Select[Tuple[int, int]]
+reveal_type(stmt1)