From: Denis Laxalde Date: Wed, 9 Apr 2025 07:04:20 +0000 (-0400) Subject: Type postgresql.aggregate_order_by() X-Git-Tag: rel_2_0_41~28 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5278af0c0fcbcfd05cdb81d79df6e82e8f07088b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Type postgresql.aggregate_order_by() Overloading of `__init__()` is needed, probably for the same reason as it is in `ReturnTypeFromArgs`. Related to #6810. Closes: #12463 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12463 Pull-request-sha: 701d979e20c6ca3e32b79145c20441407007122f Change-Id: I7e1bb4d2c48dfb3461725c7079aaa72c66f1dc03 (cherry picked from commit 09c1d3ccaccd93e0b8affa751c40c250aeedbaa5) --- diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 94466ae0a1..54bacd9447 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -8,6 +8,10 @@ from __future__ import annotations from typing import Any +from typing import Iterable +from typing import List +from typing import Optional +from typing import overload from typing import TYPE_CHECKING from typing import TypeVar @@ -23,13 +27,19 @@ from ...sql.schema import ColumnCollectionConstraint from ...sql.sqltypes import TEXT from ...sql.visitors import InternalTraversal -_T = TypeVar("_T", bound=Any) - if TYPE_CHECKING: + from ...sql._typing import _ColumnExpressionArgument + from ...sql.elements import ClauseElement + from ...sql.elements import ColumnElement + from ...sql.operators import OperatorType + from ...sql.selectable import FromClause + from ...sql.visitors import _CloneCallableType from ...sql.visitors import _TraverseInternalsType +_T = TypeVar("_T", bound=Any) -class aggregate_order_by(expression.ColumnElement): + +class aggregate_order_by(expression.ColumnElement[_T]): """Represent a PostgreSQL aggregate order by expression. E.g.:: @@ -75,11 +85,32 @@ class aggregate_order_by(expression.ColumnElement): ("order_by", InternalTraversal.dp_clauseelement), ] - def __init__(self, target, *order_by): - self.target = coercions.expect(roles.ExpressionElementRole, target) + @overload + def __init__( + self, + target: ColumnElement[_T], + *order_by: _ColumnExpressionArgument[Any], + ): ... + + @overload + def __init__( + self, + target: _ColumnExpressionArgument[_T], + *order_by: _ColumnExpressionArgument[Any], + ): ... + + def __init__( + self, + target: _ColumnExpressionArgument[_T], + *order_by: _ColumnExpressionArgument[Any], + ): + self.target: ClauseElement = coercions.expect( + roles.ExpressionElementRole, target + ) self.type = self.target.type _lob = len(order_by) + self.order_by: ClauseElement if _lob == 0: raise TypeError("at least one ORDER BY element is required") elif _lob == 1: @@ -91,18 +122,22 @@ class aggregate_order_by(expression.ColumnElement): *order_by, _literal_as_text_role=roles.ExpressionElementRole ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> ClauseElement: return self - def get_children(self, **kwargs): + def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]: return self.target, self.order_by - def _copy_internals(self, clone=elements._clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = elements._clone, **kw: Any + ) -> None: self.target = clone(self.target, **kw) self.order_by = clone(self.order_by, **kw) @property - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: return self.target._from_objects + self.order_by._from_objects diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 3dbb949878..0f1e588bd9 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -10,6 +10,7 @@ from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import Text from sqlalchemy import UniqueConstraint +from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import DATERANGE @@ -131,3 +132,25 @@ reveal_type(stmt_array_agg) # EXPECTED_TYPE: Select[Tuple[Sequence[str]]] reveal_type(select(func.array_agg(Test.ident_str))) + +stmt_array_agg_order_by_1 = select( + func.array_agg( + aggregate_order_by( + Column("title", type_=Text), + Column("date", type_=DATERANGE).desc(), + Column("id", type_=Integer), + ), + ) +) + +# EXPECTED_TYPE: Select[Tuple[Sequence[str]]] +reveal_type(stmt_array_agg_order_by_1) + +stmt_array_agg_order_by_2 = select( + func.array_agg( + aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident), + ) +) + +# EXPECTED_TYPE: Select[Tuple[Sequence[str]]] +reveal_type(stmt_array_agg_order_by_2)