From c1b7600d9ec6cb29eb48455726799a6779704240 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sat, 3 Dec 2022 17:39:55 +0100 Subject: [PATCH] Order_by and group_by accept labels Improve typing to accept labels in ordey_by mand group_by. Change-Id: I33e5d6f64633d39a220108d412ef84d6478b25e6 --- lib/sqlalchemy/orm/query.py | 11 +++--- lib/sqlalchemy/sql/_elements_constructors.py | 9 +++-- lib/sqlalchemy/sql/_typing.py | 2 ++ lib/sqlalchemy/sql/elements.py | 5 +-- lib/sqlalchemy/sql/selectable.py | 11 +++--- .../mypy/plain_files/common_sql_element.py | 36 +++++++++++++++++++ 6 files changed, 62 insertions(+), 12 deletions(-) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 0d8d21df09..d51c8bf9ac 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -111,6 +111,7 @@ if TYPE_CHECKING: from ..engine.result import FrozenResult from ..engine.result import ScalarResult from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _ColumnExpressionOrStrLabelArgument from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _DMLColumnArgument from ..sql._typing import _JoinTargetArgument @@ -1952,9 +1953,10 @@ class Query( def order_by( self: SelfQuery, __first: Union[ - Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + Literal[None, False, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], ] = _NoArg.NO_ARG, - *clauses: _ColumnExpressionArgument[Any], + *clauses: _ColumnExpressionOrStrLabelArgument[Any], ) -> SelfQuery: """Apply one or more ORDER BY criteria to the query and return the newly resulting :class:`_query.Query`. @@ -2000,9 +2002,10 @@ class Query( def group_by( self: SelfQuery, __first: Union[ - Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + Literal[None, False, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], ] = _NoArg.NO_ARG, - *clauses: _ColumnExpressionArgument[Any], + *clauses: _ColumnExpressionOrStrLabelArgument[Any], ) -> SelfQuery: """Apply one or more GROUP BY criterion to the query and return the newly resulting :class:`_query.Query`. diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 7c5281beeb..2e5e399f9f 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -49,6 +49,7 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument + from ._typing import _ColumnExpressionOrStrLabelArgument from ._typing import _TypeEngineArgument from .elements import BinaryExpression from .selectable import FromClause @@ -226,7 +227,9 @@ def any_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]: return CollectionAggregate._create_any(expr) -def asc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: +def asc( + column: _ColumnExpressionOrStrLabelArgument[_T], +) -> UnaryExpression[_T]: """Produce an ascending ``ORDER BY`` clause element. e.g.:: @@ -935,7 +938,9 @@ def column( return ColumnClause(text, type_, is_literal, _selectable) -def desc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: +def desc( + column: _ColumnExpressionOrStrLabelArgument[_T], +) -> UnaryExpression[_T]: """Produce a descending ``ORDER BY`` clause element. e.g.:: diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 8a758e7c7e..78e196efc2 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -181,6 +181,8 @@ overall which brings in the TextClause object also. _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] +_ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] + _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index eff8c9bc11..3896d4cbd9 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -80,6 +80,7 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrStrLabelArgument from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument @@ -3494,7 +3495,7 @@ class UnaryExpression(ColumnElement[_T]): @classmethod def _create_desc( - cls, column: _ColumnExpressionArgument[_T] + cls, column: _ColumnExpressionOrStrLabelArgument[_T] ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), @@ -3505,7 +3506,7 @@ class UnaryExpression(ColumnElement[_T]): @classmethod def _create_asc( cls, - column: _ColumnExpressionArgument[_T], + column: _ColumnExpressionOrStrLabelArgument[_T], ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 2dcc611fa5..fd4157afdd 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -106,6 +106,7 @@ _T = TypeVar("_T", bound=Any) if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrStrLabelArgument from ._typing import _FromClauseArgument from ._typing import _JoinTargetArgument from ._typing import _MAYBE_ENTITY @@ -4146,9 +4147,10 @@ class GenerativeSelect(SelectBase, Generative): def order_by( self: SelfGenerativeSelect, __first: Union[ - Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + Literal[None, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], ] = _NoArg.NO_ARG, - *clauses: _ColumnExpressionArgument[Any], + *clauses: _ColumnExpressionOrStrLabelArgument[Any], ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of ORDER BY criteria applied. @@ -4190,9 +4192,10 @@ class GenerativeSelect(SelectBase, Generative): def group_by( self: SelfGenerativeSelect, __first: Union[ - Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + Literal[None, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], ] = _NoArg.NO_ARG, - *clauses: _ColumnExpressionArgument[Any], + *clauses: _ColumnExpressionOrStrLabelArgument[Any], ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of GROUP BY criterion applied. diff --git a/test/ext/mypy/plain_files/common_sql_element.py b/test/ext/mypy/plain_files/common_sql_element.py index af36c85ee9..586a130d25 100644 --- a/test/ext/mypy/plain_files/common_sql_element.py +++ b/test/ext/mypy/plain_files/common_sql_element.py @@ -9,7 +9,9 @@ unions. from __future__ import annotations +from sqlalchemy import asc from sqlalchemy import Column +from sqlalchemy import desc from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import select @@ -19,6 +21,7 @@ from sqlalchemy import Table from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Session class Base(DeclarativeBase): @@ -79,6 +82,21 @@ reveal_type(stmt) stmt = stmt.where(e2) +stmt2 = select(User.id).order_by("email").group_by("email") +stmt2 = select(User.id).order_by("id", "email").group_by("email", "id") +stmt2 = ( + select(User.id).order_by(asc("id"), desc("email")).group_by("email", "id") +) +# EXPECTED_TYPE: Select[Tuple[int]] +reveal_type(stmt2) + +stmt2 = select(User.id).order_by(User.id).group_by(User.email) +stmt2 = ( + select(User.id).order_by(User.id, User.email).group_by(User.email, User.id) +) +# EXPECTED_TYPE: Select[Tuple[int]] +reveal_type(stmt2) + receives_str_col_expr(User.email) receives_str_col_expr(User.email + "some expr") @@ -92,3 +110,21 @@ receives_bool_col_expr(User.email == "x") receives_bool_col_expr(e2) receives_bool_col_expr(e2.label("x")) receives_bool_col_expr(user_table.c.email == "x") + + +# query + +q1 = Session().query(User.id).order_by("email").group_by("email") +q1 = Session().query(User.id).order_by("id", "email").group_by("email", "id") +# EXPECTED_TYPE: RowReturningQuery[Tuple[int]] +reveal_type(q1) + +q1 = Session().query(User.id).order_by(User.id).group_by(User.email) +q1 = ( + Session() + .query(User.id) + .order_by(User.id, User.email) + .group_by(User.email, User.id) +) +# EXPECTED_TYPE: RowReturningQuery[Tuple[int]] +reveal_type(q1) -- 2.47.2