]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type postgresql.aggregate_order_by()
authorDenis Laxalde <denis@laxalde.org>
Wed, 9 Apr 2025 07:04:20 +0000 (03:04 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 10 Apr 2025 21:20:08 +0000 (23:20 +0200)
Overloading of `__init__()` is needed, probably for the same reason as it is in `ReturnTypeFromArgs`.

Related to #6810.

Closes: #12463
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12463
Pull-request-sha: 701d979e20c6ca3e32b79145c20441407007122f

Change-Id: I7e1bb4d2c48dfb3461725c7079aaa72c66f1dc03
(cherry picked from commit 09c1d3ccaccd93e0b8affa751c40c250aeedbaa5)

lib/sqlalchemy/dialects/postgresql/ext.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py

index 94466ae0a1396e3e6f93a7d0c86cc4a59442481f..54bacd9447158a8eb474db84411c6fab9ef16edc 100644 (file)
@@ -8,6 +8,10 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import overload
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
@@ -23,13 +27,19 @@ from ...sql.schema import ColumnCollectionConstraint
 from ...sql.sqltypes import TEXT
 from ...sql.visitors import InternalTraversal
 
-_T = TypeVar("_T", bound=Any)
-
 if TYPE_CHECKING:
+    from ...sql._typing import _ColumnExpressionArgument
+    from ...sql.elements import ClauseElement
+    from ...sql.elements import ColumnElement
+    from ...sql.operators import OperatorType
+    from ...sql.selectable import FromClause
+    from ...sql.visitors import _CloneCallableType
     from ...sql.visitors import _TraverseInternalsType
 
+_T = TypeVar("_T", bound=Any)
 
-class aggregate_order_by(expression.ColumnElement):
+
+class aggregate_order_by(expression.ColumnElement[_T]):
     """Represent a PostgreSQL aggregate order by expression.
 
     E.g.::
@@ -75,11 +85,32 @@ class aggregate_order_by(expression.ColumnElement):
         ("order_by", InternalTraversal.dp_clauseelement),
     ]
 
-    def __init__(self, target, *order_by):
-        self.target = coercions.expect(roles.ExpressionElementRole, target)
+    @overload
+    def __init__(
+        self,
+        target: ColumnElement[_T],
+        *order_by: _ColumnExpressionArgument[Any],
+    ): ...
+
+    @overload
+    def __init__(
+        self,
+        target: _ColumnExpressionArgument[_T],
+        *order_by: _ColumnExpressionArgument[Any],
+    ): ...
+
+    def __init__(
+        self,
+        target: _ColumnExpressionArgument[_T],
+        *order_by: _ColumnExpressionArgument[Any],
+    ):
+        self.target: ClauseElement = coercions.expect(
+            roles.ExpressionElementRole, target
+        )
         self.type = self.target.type
 
         _lob = len(order_by)
+        self.order_by: ClauseElement
         if _lob == 0:
             raise TypeError("at least one ORDER BY element is required")
         elif _lob == 1:
@@ -91,18 +122,22 @@ class aggregate_order_by(expression.ColumnElement):
                 *order_by, _literal_as_text_role=roles.ExpressionElementRole
             )
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> ClauseElement:
         return self
 
-    def get_children(self, **kwargs):
+    def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]:
         return self.target, self.order_by
 
-    def _copy_internals(self, clone=elements._clone, **kw):
+    def _copy_internals(
+        self, clone: _CloneCallableType = elements._clone, **kw: Any
+    ) -> None:
         self.target = clone(self.target, **kw)
         self.order_by = clone(self.order_by, **kw)
 
     @property
-    def _from_objects(self):
+    def _from_objects(self) -> List[FromClause]:
         return self.target._from_objects + self.order_by._from_objects
 
 
index 3dbb94987879f5ae6616675a2420e75d046c0b14..0f1e588bd950c212ef426877a61084f5c2a90553 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import Text
 from sqlalchemy import UniqueConstraint
+from sqlalchemy.dialects.postgresql import aggregate_order_by
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import DATERANGE
@@ -131,3 +132,25 @@ reveal_type(stmt_array_agg)
 
 # EXPECTED_TYPE: Select[Tuple[Sequence[str]]]
 reveal_type(select(func.array_agg(Test.ident_str)))
+
+stmt_array_agg_order_by_1 = select(
+    func.array_agg(
+        aggregate_order_by(
+            Column("title", type_=Text),
+            Column("date", type_=DATERANGE).desc(),
+            Column("id", type_=Integer),
+        ),
+    )
+)
+
+# EXPECTED_TYPE: Select[Tuple[Sequence[str]]]
+reveal_type(stmt_array_agg_order_by_1)
+
+stmt_array_agg_order_by_2 = select(
+    func.array_agg(
+        aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident),
+    )
+)
+
+# EXPECTED_TYPE: Select[Tuple[Sequence[str]]]
+reveal_type(stmt_array_agg_order_by_2)