]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type postgresql.aggregate_order_by()
authorDenis Laxalde <denis@laxalde.org>
Wed, 9 Apr 2025 07:04:20 +0000 (03:04 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Wed, 9 Apr 2025 07:04:20 +0000 (03:04 -0400)
Overloading of `__init__()` is needed, probably for the same reason as it is in `ReturnTypeFromArgs`.

Related to #6810.

Closes: #12463
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12463
Pull-request-sha: 701d979e20c6ca3e32b79145c20441407007122f

Change-Id: I7e1bb4d2c48dfb3461725c7079aaa72c66f1dc03

lib/sqlalchemy/dialects/postgresql/ext.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py

index 0f110b8e06a18f05fc3c08afc2bb228e66c9b7a5..63337c7aff4a75826c555a86290f0ddb0d49b064 100644 (file)
@@ -8,6 +8,10 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import TYPE_CHECKING
 from typing import TypeVar
@@ -28,12 +32,17 @@ from ...sql.visitors import InternalTraversal
 
 if TYPE_CHECKING:
     from ...sql._typing import _ColumnExpressionArgument
+    from ...sql.elements import ClauseElement
+    from ...sql.elements import ColumnElement
+    from ...sql.operators import OperatorType
+    from ...sql.selectable import FromClause
+    from ...sql.visitors import _CloneCallableType
     from ...sql.visitors import _TraverseInternalsType
 
 _T = TypeVar("_T", bound=Any)
 
 
-class aggregate_order_by(expression.ColumnElement):
+class aggregate_order_by(expression.ColumnElement[_T]):
     """Represent a PostgreSQL aggregate order by expression.
 
     E.g.::
@@ -77,11 +86,32 @@ class aggregate_order_by(expression.ColumnElement):
         ("order_by", InternalTraversal.dp_clauseelement),
     ]
 
-    def __init__(self, target, *order_by):
-        self.target = coercions.expect(roles.ExpressionElementRole, target)
+    @overload
+    def __init__(
+        self,
+        target: ColumnElement[_T],
+        *order_by: _ColumnExpressionArgument[Any],
+    ): ...
+
+    @overload
+    def __init__(
+        self,
+        target: _ColumnExpressionArgument[_T],
+        *order_by: _ColumnExpressionArgument[Any],
+    ): ...
+
+    def __init__(
+        self,
+        target: _ColumnExpressionArgument[_T],
+        *order_by: _ColumnExpressionArgument[Any],
+    ):
+        self.target: ClauseElement = coercions.expect(
+            roles.ExpressionElementRole, target
+        )
         self.type = self.target.type
 
         _lob = len(order_by)
+        self.order_by: ClauseElement
         if _lob == 0:
             raise TypeError("at least one ORDER BY element is required")
         elif _lob == 1:
@@ -93,18 +123,22 @@ class aggregate_order_by(expression.ColumnElement):
                 *order_by, _literal_as_text_role=roles.ExpressionElementRole
             )
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> ClauseElement:
         return self
 
-    def get_children(self, **kwargs):
+    def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]:
         return self.target, self.order_by
 
-    def _copy_internals(self, clone=elements._clone, **kw):
+    def _copy_internals(
+        self, clone: _CloneCallableType = elements._clone, **kw: Any
+    ) -> None:
         self.target = clone(self.target, **kw)
         self.order_by = clone(self.order_by, **kw)
 
     @property
-    def _from_objects(self):
+    def _from_objects(self) -> List[FromClause]:
         return self.target._from_objects + self.order_by._from_objects
 
 
index 6dda180c4f92cd33787ec079c5fa191690c4a9a2..4a50a9e42cc7922f48d77d4ba4c4a2a6da9f13f5 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import Text
 from sqlalchemy import UniqueConstraint
+from sqlalchemy.dialects.postgresql import aggregate_order_by
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import DATERANGE
@@ -131,3 +132,25 @@ reveal_type(stmt_array_agg)
 
 # EXPECTED_TYPE: Select[Sequence[str]]
 reveal_type(select(func.array_agg(Test.ident_str)))
+
+stmt_array_agg_order_by_1 = select(
+    func.array_agg(
+        aggregate_order_by(
+            Column("title", type_=Text),
+            Column("date", type_=DATERANGE).desc(),
+            Column("id", type_=Integer),
+        ),
+    )
+)
+
+# EXPECTED_TYPE: Select[Sequence[str]]
+reveal_type(stmt_array_agg_order_by_1)
+
+stmt_array_agg_order_by_2 = select(
+    func.array_agg(
+        aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident),
+    )
+)
+
+# EXPECTED_TYPE: Select[Sequence[str]]
+reveal_type(stmt_array_agg_order_by_2)