]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Order_by and group_by accept labels
authorFederico Caselli <cfederico87@gmail.com>
Sat, 3 Dec 2022 16:39:55 +0000 (17:39 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Dec 2022 18:04:58 +0000 (13:04 -0500)
Improve typing to accept labels in ordey_by mand group_by.

Change-Id: I33e5d6f64633d39a220108d412ef84d6478b25e6

lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/ext/mypy/plain_files/common_sql_element.py

index 0d8d21df099d6b17f07a45cb2bb8d39fa74f6520..d51c8bf9ac55ded5ad318e04d45ca02ff138520f 100644 (file)
@@ -111,6 +111,7 @@ if TYPE_CHECKING:
     from ..engine.result import FrozenResult
     from ..engine.result import ScalarResult
     from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _ColumnExpressionOrStrLabelArgument
     from ..sql._typing import _ColumnsClauseArgument
     from ..sql._typing import _DMLColumnArgument
     from ..sql._typing import _JoinTargetArgument
@@ -1952,9 +1953,10 @@ class Query(
     def order_by(
         self: SelfQuery,
         __first: Union[
-            Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+            Literal[None, False, _NoArg.NO_ARG],
+            _ColumnExpressionOrStrLabelArgument[Any],
         ] = _NoArg.NO_ARG,
-        *clauses: _ColumnExpressionArgument[Any],
+        *clauses: _ColumnExpressionOrStrLabelArgument[Any],
     ) -> SelfQuery:
         """Apply one or more ORDER BY criteria to the query and return
         the newly resulting :class:`_query.Query`.
@@ -2000,9 +2002,10 @@ class Query(
     def group_by(
         self: SelfQuery,
         __first: Union[
-            Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+            Literal[None, False, _NoArg.NO_ARG],
+            _ColumnExpressionOrStrLabelArgument[Any],
         ] = _NoArg.NO_ARG,
-        *clauses: _ColumnExpressionArgument[Any],
+        *clauses: _ColumnExpressionOrStrLabelArgument[Any],
     ) -> SelfQuery:
         """Apply one or more GROUP BY criterion to the query and return
         the newly resulting :class:`_query.Query`.
index 7c5281beeb7348175adb96991de2b658270d033c..2e5e399f9f2a8c5f34915e8aa0dab7d9b2bec634 100644 (file)
@@ -49,6 +49,7 @@ from ..util.typing import Literal
 if typing.TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrLiteralArgument
+    from ._typing import _ColumnExpressionOrStrLabelArgument
     from ._typing import _TypeEngineArgument
     from .elements import BinaryExpression
     from .selectable import FromClause
@@ -226,7 +227,9 @@ def any_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]:
     return CollectionAggregate._create_any(expr)
 
 
-def asc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
+def asc(
+    column: _ColumnExpressionOrStrLabelArgument[_T],
+) -> UnaryExpression[_T]:
     """Produce an ascending ``ORDER BY`` clause element.
 
     e.g.::
@@ -935,7 +938,9 @@ def column(
     return ColumnClause(text, type_, is_literal, _selectable)
 
 
-def desc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
+def desc(
+    column: _ColumnExpressionOrStrLabelArgument[_T],
+) -> UnaryExpression[_T]:
     """Produce a descending ``ORDER BY`` clause element.
 
     e.g.::
index 8a758e7c7e2fa08503574bb4a17a8caad9cede8d..78e196efc21b970160470b63137d9aadc23c9d58 100644 (file)
@@ -181,6 +181,8 @@ overall which brings in the TextClause object also.
 
 _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]]
 
+_ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]]
+
 
 _InfoType = Dict[Any, Any]
 """the .info dictionary accepted and used throughout Core /ORM"""
index eff8c9bc11ba62f19cb254eac3afc46e97afb32c..3896d4cbd98862f81d0f1d81245d1f4bd6ad754e 100644 (file)
@@ -80,6 +80,7 @@ from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
+    from ._typing import _ColumnExpressionOrStrLabelArgument
     from ._typing import _InfoType
     from ._typing import _PropagateAttrsType
     from ._typing import _TypeEngineArgument
@@ -3494,7 +3495,7 @@ class UnaryExpression(ColumnElement[_T]):
 
     @classmethod
     def _create_desc(
-        cls, column: _ColumnExpressionArgument[_T]
+        cls, column: _ColumnExpressionOrStrLabelArgument[_T]
     ) -> UnaryExpression[_T]:
         return UnaryExpression(
             coercions.expect(roles.ByOfRole, column),
@@ -3505,7 +3506,7 @@ class UnaryExpression(ColumnElement[_T]):
     @classmethod
     def _create_asc(
         cls,
-        column: _ColumnExpressionArgument[_T],
+        column: _ColumnExpressionOrStrLabelArgument[_T],
     ) -> UnaryExpression[_T]:
         return UnaryExpression(
             coercions.expect(roles.ByOfRole, column),
index 2dcc611fa52717d278f9e113af3bcdbb374c7f31..fd4157afdd30d44371fc59106e5bb3f8f88f6b80 100644 (file)
@@ -106,6 +106,7 @@ _T = TypeVar("_T", bound=Any)
 
 if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
+    from ._typing import _ColumnExpressionOrStrLabelArgument
     from ._typing import _FromClauseArgument
     from ._typing import _JoinTargetArgument
     from ._typing import _MAYBE_ENTITY
@@ -4146,9 +4147,10 @@ class GenerativeSelect(SelectBase, Generative):
     def order_by(
         self: SelfGenerativeSelect,
         __first: Union[
-            Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+            Literal[None, _NoArg.NO_ARG],
+            _ColumnExpressionOrStrLabelArgument[Any],
         ] = _NoArg.NO_ARG,
-        *clauses: _ColumnExpressionArgument[Any],
+        *clauses: _ColumnExpressionOrStrLabelArgument[Any],
     ) -> SelfGenerativeSelect:
         r"""Return a new selectable with the given list of ORDER BY
         criteria applied.
@@ -4190,9 +4192,10 @@ class GenerativeSelect(SelectBase, Generative):
     def group_by(
         self: SelfGenerativeSelect,
         __first: Union[
-            Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+            Literal[None, _NoArg.NO_ARG],
+            _ColumnExpressionOrStrLabelArgument[Any],
         ] = _NoArg.NO_ARG,
-        *clauses: _ColumnExpressionArgument[Any],
+        *clauses: _ColumnExpressionOrStrLabelArgument[Any],
     ) -> SelfGenerativeSelect:
         r"""Return a new selectable with the given list of GROUP BY
         criterion applied.
index af36c85ee9638c1df2c9b2e16dfc0dba43db6898..586a130d2596a00bc5e10f653c80d2f10ceff867 100644 (file)
@@ -9,7 +9,9 @@ unions.
 
 from __future__ import annotations
 
+from sqlalchemy import asc
 from sqlalchemy import Column
+from sqlalchemy import desc
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import select
@@ -19,6 +21,7 @@ from sqlalchemy import Table
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import Session
 
 
 class Base(DeclarativeBase):
@@ -79,6 +82,21 @@ reveal_type(stmt)
 
 stmt = stmt.where(e2)
 
+stmt2 = select(User.id).order_by("email").group_by("email")
+stmt2 = select(User.id).order_by("id", "email").group_by("email", "id")
+stmt2 = (
+    select(User.id).order_by(asc("id"), desc("email")).group_by("email", "id")
+)
+# EXPECTED_TYPE: Select[Tuple[int]]
+reveal_type(stmt2)
+
+stmt2 = select(User.id).order_by(User.id).group_by(User.email)
+stmt2 = (
+    select(User.id).order_by(User.id, User.email).group_by(User.email, User.id)
+)
+# EXPECTED_TYPE: Select[Tuple[int]]
+reveal_type(stmt2)
+
 
 receives_str_col_expr(User.email)
 receives_str_col_expr(User.email + "some expr")
@@ -92,3 +110,21 @@ receives_bool_col_expr(User.email == "x")
 receives_bool_col_expr(e2)
 receives_bool_col_expr(e2.label("x"))
 receives_bool_col_expr(user_table.c.email == "x")
+
+
+# query
+
+q1 = Session().query(User.id).order_by("email").group_by("email")
+q1 = Session().query(User.id).order_by("id", "email").group_by("email", "id")
+# EXPECTED_TYPE: RowReturningQuery[Tuple[int]]
+reveal_type(q1)
+
+q1 = Session().query(User.id).order_by(User.id).group_by(User.email)
+q1 = (
+    Session()
+    .query(User.id)
+    .order_by(User.id, User.email)
+    .group_by(User.email, User.id)
+)
+# EXPECTED_TYPE: RowReturningQuery[Tuple[int]]
+reveal_type(q1)