From: Denis Laxalde Date: Wed, 9 Apr 2025 07:04:20 +0000 (-0400) Subject: Type postgresql.aggregate_order_by() X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=09c1d3ccaccd93e0b8affa751c40c250aeedbaa5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Type postgresql.aggregate_order_by() Overloading of `__init__()` is needed, probably for the same reason as it is in `ReturnTypeFromArgs`. Related to #6810. Closes: #12463 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12463 Pull-request-sha: 701d979e20c6ca3e32b79145c20441407007122f Change-Id: I7e1bb4d2c48dfb3461725c7079aaa72c66f1dc03 --- diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 0f110b8e06..63337c7aff 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -8,6 +8,10 @@ from __future__ import annotations from typing import Any +from typing import Iterable +from typing import List +from typing import Optional +from typing import overload from typing import Sequence from typing import TYPE_CHECKING from typing import TypeVar @@ -28,12 +32,17 @@ from ...sql.visitors import InternalTraversal if TYPE_CHECKING: from ...sql._typing import _ColumnExpressionArgument + from ...sql.elements import ClauseElement + from ...sql.elements import ColumnElement + from ...sql.operators import OperatorType + from ...sql.selectable import FromClause + from ...sql.visitors import _CloneCallableType from ...sql.visitors import _TraverseInternalsType _T = TypeVar("_T", bound=Any) -class aggregate_order_by(expression.ColumnElement): +class aggregate_order_by(expression.ColumnElement[_T]): """Represent a PostgreSQL aggregate order by expression. E.g.:: @@ -77,11 +86,32 @@ class aggregate_order_by(expression.ColumnElement): ("order_by", InternalTraversal.dp_clauseelement), ] - def __init__(self, target, *order_by): - self.target = coercions.expect(roles.ExpressionElementRole, target) + @overload + def __init__( + self, + target: ColumnElement[_T], + *order_by: _ColumnExpressionArgument[Any], + ): ... + + @overload + def __init__( + self, + target: _ColumnExpressionArgument[_T], + *order_by: _ColumnExpressionArgument[Any], + ): ... + + def __init__( + self, + target: _ColumnExpressionArgument[_T], + *order_by: _ColumnExpressionArgument[Any], + ): + self.target: ClauseElement = coercions.expect( + roles.ExpressionElementRole, target + ) self.type = self.target.type _lob = len(order_by) + self.order_by: ClauseElement if _lob == 0: raise TypeError("at least one ORDER BY element is required") elif _lob == 1: @@ -93,18 +123,22 @@ class aggregate_order_by(expression.ColumnElement): *order_by, _literal_as_text_role=roles.ExpressionElementRole ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> ClauseElement: return self - def get_children(self, **kwargs): + def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]: return self.target, self.order_by - def _copy_internals(self, clone=elements._clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = elements._clone, **kw: Any + ) -> None: self.target = clone(self.target, **kw) self.order_by = clone(self.order_by, **kw) @property - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: return self.target._from_objects + self.order_by._from_objects diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 6dda180c4f..4a50a9e42c 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -10,6 +10,7 @@ from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import Text from sqlalchemy import UniqueConstraint +from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import DATERANGE @@ -131,3 +132,25 @@ reveal_type(stmt_array_agg) # EXPECTED_TYPE: Select[Sequence[str]] reveal_type(select(func.array_agg(Test.ident_str))) + +stmt_array_agg_order_by_1 = select( + func.array_agg( + aggregate_order_by( + Column("title", type_=Text), + Column("date", type_=DATERANGE).desc(), + Column("id", type_=Integer), + ), + ) +) + +# EXPECTED_TYPE: Select[Sequence[str]] +reveal_type(stmt_array_agg_order_by_1) + +stmt_array_agg_order_by_2 = select( + func.array_agg( + aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident), + ) +) + +# EXPECTED_TYPE: Select[Sequence[str]] +reveal_type(stmt_array_agg_order_by_2)