From: Reuven Starodubski Date: Thu, 11 Sep 2025 17:33:59 +0000 (-0400) Subject: Add FunctionElement.aggregate_order_by X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=459ebc668a5512be412c7b73dc6a4468363bf274;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add FunctionElement.aggregate_order_by Added new generalized aggregate function ordering to functions via the :func:`_functions.FunctionElement.aggregate_order_by` method, which receives an expression and generates the appropriate embedded "ORDER BY" or "WITHIN GROUP (ORDER BY)" phrase depending on backend database. This new function supersedes the use of the PostgreSQL :func:`_postgresql.aggregate_order_by` function, which remains present for backward compatibility. To complement the new parameter, the :paramref:`_functions.aggregate_strings.order_by` which adds ORDER BY capability to the :class:`_functions.aggregate_strings` dialect-agnostic function which works for all included backends. Thanks much to Reuven Starodubski with help on this patch. Co-authored-by: Mike Bayer Fixes: #12853 Closes: #12856 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12856 Pull-request-sha: d93fb591751227eb1f96052ea3ad449f511f70b3 Change-Id: I8eb41ff2d57695963a358b5f0017ca9372f15f70 --- diff --git a/doc/build/changelog/unreleased_21/12853.rst b/doc/build/changelog/unreleased_21/12853.rst new file mode 100644 index 0000000000..9c8775cc6f --- /dev/null +++ b/doc/build/changelog/unreleased_21/12853.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: usecase, sql + :tickets: 12853 + + Added new generalized aggregate function ordering to functions via the + :func:`_functions.FunctionElement.aggregate_order_by` method, which + receives an expression and generates the appropriate embedded "ORDER BY" or + "WITHIN GROUP (ORDER BY)" phrase depending on backend database. This new + function supersedes the use of the PostgreSQL + :func:`_postgresql.aggregate_order_by` function, which remains present for + backward compatibility. 
To complement the new method, the + :paramref:`_functions.aggregate_strings.order_by` parameter adds ORDER BY + capability to the :class:`_functions.aggregate_strings` dialect-agnostic + function, which works for all included backends. Thanks much to Reuven + Starodubski with help on this patch. + + diff --git a/doc/build/core/sqlelement.rst b/doc/build/core/sqlelement.rst index 5e8299ab34..79c41f7d23 100644 --- a/doc/build/core/sqlelement.rst +++ b/doc/build/core/sqlelement.rst @@ -22,6 +22,8 @@ Column Element Foundational Constructors Standalone functions imported from the ``sqlalchemy`` namespace which are used when building up SQLAlchemy Expression Language constructs. +.. autofunction:: aggregate_order_by + .. autofunction:: and_ .. autofunction:: bindparam @@ -170,6 +172,8 @@ The classes here are generated using the constructors listed at well as ORM-mapped attributes that will have a ``__clause_element__()`` method. +.. autoclass:: AggregateOrderBy + :members: .. autoclass:: ColumnOperators :members: diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index de651a15b4..8e35a73acd 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -17,8 +17,10 @@ as well as array literals: * :func:`_postgresql.array_agg` - ARRAY_AGG SQL function -* :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate - function syntax. +* :meth:`_functions.FunctionElement.aggregate_order_by` - dialect-agnostic ORDER BY + for aggregate functions + +* :class:`_postgresql.aggregate_order_by` - legacy helper specific to PostgreSQL BIT type -------- diff --git a/doc/build/tutorial/data_select.rst b/doc/build/tutorial/data_select.rst index d880b4a4ae..111ddaac1f 100644 --- a/doc/build/tutorial/data_select.rst +++ b/doc/build/tutorial/data_select.rst @@ -1652,17 +1652,55 @@ Further options for window functions include usage of ranges; see .. 
_tutorial_functions_within_group: -Special Modifiers WITHIN GROUP, FILTER -###################################### +Special Modifiers ORDER BY, WITHIN GROUP, FILTER +################################################ + +Some forms of SQL aggregate functions support ordering of the aggregated elements +within the scope of the function. This typically applies to aggregate +functions that produce a value which continues to enumerate the contents of the +collection, such as the ``array_agg()`` function that generates an array of +elements, or the ``string_agg()`` PostgreSQL function which generates a +delimited string (other backends like MySQL and SQLite use the +``group_concat()`` function in a similar way), or the MySQL ``json_arrayagg()`` +function which produces a JSON array. Ordering of the elements passed +to these functions is supported using the :meth:`_functions.FunctionElement.aggregate_order_by` +method, which will render ORDER BY in the appropriate part of the function:: -The "WITHIN GROUP" SQL syntax is used in conjunction with an "ordered set" -or a "hypothetical set" aggregate -function. Common "ordered set" functions include ``percentile_cont()`` -and ``rank()``. SQLAlchemy includes built in implementations + >>> with engine.connect() as conn: + ... result = conn.execute( + ... select( + ... func.group_concat(user_table.c.name).aggregate_order_by( + ... user_table.c.name.desc() + ... ) + ... ) + ... ) + ... print(result.all()) + {execsql}BEGIN (implicit) + SELECT group_concat(user_account.name ORDER BY user_account.name DESC) AS group_concat_1 + FROM user_account + [...] () + {stop}[('spongebob,sandy,patrick',)] + {printsql}ROLLBACK{stop} + +.. tip:: The above demonstration shows use of the ``group_concat()`` function + on SQLite to concatenate strings. As this type of function varies + highly on all backends, SQLAlchemy also provides a backend-agnostic + version specifically for concatenating strings called + :func:`_functions.aggregate_strings`. 
+ +A more specific form of ORDER BY for aggregate functions is the "WITHIN GROUP" +SQL syntax. In some cases, the :meth:`_functions.FunctionElement.aggregate_order_by` +method will render this syntax directly, when compiling on a backend such as Oracle +Database or Microsoft SQL Server which requires it for all aggregate ordering. +Beyond that, the "WITHIN GROUP" SQL syntax must sometimes be called upon explicitly, +when used in conjunction with an "ordered set" or a "hypothetical set" +aggregate function, supported by PostgreSQL, Oracle Database, and Microsoft SQL +Server. Common "ordered set" functions include ``percentile_cont()`` and +``rank()``. SQLAlchemy includes built-in implementations :class:`_functions.rank`, :class:`_functions.dense_rank`, :class:`_functions.mode`, :class:`_functions.percentile_cont` and -:class:`_functions.percentile_disc` which include a :meth:`_functions.FunctionElement.within_group` -method:: +:class:`_functions.percentile_disc` which include a +:meth:`_functions.FunctionElement.within_group` method:: >>> print( ... 
func.unnest( diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 623acff128..137979dab3 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -80,6 +80,8 @@ from .sql import ColumnExpressionArgument as ColumnExpressionArgument from .sql import NotNullable as NotNullable from .sql import Nullable as Nullable from .sql import SelectLabelStyle as SelectLabelStyle +from .sql.expression import aggregate_order_by as aggregate_order_by +from .sql.expression import AggregateOrderBy as AggregateOrderBy from .sql.expression import Alias as Alias from .sql.expression import alias as alias from .sql.expression import AliasedReturnsRows as AliasedReturnsRows diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 12ac6edde2..ff67ee1ef5 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -994,6 +994,7 @@ from ...sql import sqltypes from ...sql import try_cast as try_cast # noqa: F401 from ...sql import util as sql_util from ...sql._typing import is_sql_compiler +from ...sql.compiler import AggregateOrderByStyle from ...sql.compiler import InsertmanyvaluesSentinelOpts from ...sql.elements import TryCast as TryCast # noqa: F401 from ...types import BIGINT @@ -2035,10 +2036,16 @@ class MSSQLCompiler(compiler.SQLCompiler): return "LEN%s" % self.function_argspec(fn, **kw) def visit_aggregate_strings_func(self, fn, **kw): - expr = fn.clauses.clauses[0]._compiler_dispatch(self, **kw) - kw["literal_execute"] = True - delimeter = fn.clauses.clauses[1]._compiler_dispatch(self, **kw) - return f"string_agg({expr}, {delimeter})" + cl = list(fn.clauses) + expr, delimeter = cl[0:2] + + literal_exec = dict(kw) + literal_exec["literal_execute"] = True + + return ( + f"string_agg({expr._compiler_dispatch(self, **kw)}, " + f"{delimeter._compiler_dispatch(self, **literal_exec)})" + ) def visit_pow_func(self, fn, **kw): return f"POWER{self.function_argspec(fn)}" @@ 
-2985,6 +2992,8 @@ class MSDialect(default.DefaultDialect): """ + aggregate_order_by_style = AggregateOrderByStyle.WITHIN_GROUP + # supports_native_uuid is partial here, so we implement our # own impl type diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 889ab858b2..1c51302ba2 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1359,10 +1359,28 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_aggregate_strings_func( self, fn: aggregate_strings, **kw: Any ) -> str: - expr, delimeter = ( - elem._compiler_dispatch(self, **kw) for elem in fn.clauses - ) - return f"group_concat({expr} SEPARATOR {delimeter})" + + order_by = getattr(fn.clauses, "aggregate_order_by", None) + + cl = list(fn.clauses) + expr, delimeter = cl[0:2] + + literal_exec = dict(kw) + literal_exec["literal_execute"] = True + + if order_by is not None: + return ( + f"group_concat({expr._compiler_dispatch(self, **kw)} " + f"ORDER BY {order_by._compiler_dispatch(self, **kw)} " + f"SEPARATOR " + f"{delimeter._compiler_dispatch(self, **literal_exec)})" + ) + else: + return ( + f"group_concat({expr._compiler_dispatch(self, **kw)} " + f"SEPARATOR " + f"{delimeter._compiler_dispatch(self, **literal_exec)})" + ) def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str: return "nextval(%s)" % self.preparer.format_sequence(sequence) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 83f562eba5..390afdd8f5 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1001,6 +1001,7 @@ from ...sql import selectable as sa_selectable from ...sql import sqltypes from ...sql import util as sql_util from ...sql import visitors +from ...sql.compiler import AggregateOrderByStyle from ...sql.visitors import InternalTraversal from ...types import BLOB from ...types import CHAR @@ -1712,7 +1713,9 @@ class 
OracleCompiler(compiler.SQLCompiler): ) def visit_aggregate_strings_func(self, fn, **kw): - return "LISTAGG%s" % self.function_argspec(fn, **kw) + return super().visit_aggregate_strings_func( + fn, use_function_name="LISTAGG", **kw + ) def _visit_bitwise(self, binary, fn_name, custom_right=None, **kw): left = self.process(binary.left, **kw) @@ -1971,6 +1974,8 @@ class OracleDialect(default.DefaultDialect): supports_empty_insert = False supports_identity_columns = True + aggregate_order_by_style = AggregateOrderByStyle.WITHIN_GROUP + statement_compiler = OracleCompiler ddl_compiler = OracleDDLCompiler type_compiler_cls = OracleTypeCompiler diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index d06b131a62..aa739914ac 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2022,7 +2022,9 @@ class PGCompiler(compiler.SQLCompiler): return value def visit_aggregate_strings_func(self, fn, **kw): - return "string_agg%s" % self.function_argspec(fn) + return super().visit_aggregate_strings_func( + fn, use_function_name="string_agg", **kw + ) def visit_pow_func(self, fn, **kw): return f"power{self.function_argspec(fn)}" diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 63337c7aff..d251c11d6c 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -58,21 +58,15 @@ class aggregate_order_by(expression.ColumnElement[_T]): SELECT array_agg(a ORDER BY b DESC) FROM table; - Similarly:: - - expr = func.string_agg( - table.c.a, aggregate_order_by(literal_column("','"), table.c.a) - ) - stmt = select(expr) - - Would represent: - - .. sourcecode:: sql - - SELECT string_agg(a, ',' ORDER BY a) FROM table; + .. 
legacy:: An improved dialect-agnostic form of this function is now + available in Core by calling the + :meth:`_functions.Function.aggregate_order_by` method on any function + defined by the backend as an aggregate function. .. seealso:: + :func:`_sql.aggregate_order_by` - Core level function + :class:`_functions.array_agg` """ diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index d1abf26c3c..dd1f6c1987 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1452,7 +1452,9 @@ class SQLiteCompiler(compiler.SQLCompiler): return "length%s" % self.function_argspec(fn) def visit_aggregate_strings_func(self, fn, **kw): - return "group_concat%s" % self.function_argspec(fn) + return super().visit_aggregate_strings_func( + fn, use_function_name="group_concat", **kw + ) def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index fcdb68093a..c456b66e29 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -63,6 +63,7 @@ from ..sql import type_api from ..sql import util as sql_util from ..sql._typing import is_tuple_type from ..sql.base import _NoArg +from ..sql.compiler import AggregateOrderByStyle from ..sql.compiler import DDLCompiler from ..sql.compiler import InsertmanyvaluesSentinelOpts from ..sql.compiler import SQLCompiler @@ -70,7 +71,6 @@ from ..sql.elements import quoted_name from ..util.typing import TupleAny from ..util.typing import Unpack - if typing.TYPE_CHECKING: from types import ModuleType @@ -162,6 +162,8 @@ class DefaultDialect(Dialect): delete_returning_multifrom = False insert_returning = False + aggregate_order_by_style = AggregateOrderByStyle.INLINE + cte_follows_insert = False supports_native_enum = False diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 0c998996a1..9f78daa59d 100644 --- 
a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -56,6 +56,7 @@ if TYPE_CHECKING: from ..exc import StatementError from ..sql import Executable from ..sql.compiler import _InsertManyValuesBatch + from ..sql.compiler import AggregateOrderByStyle from ..sql.compiler import DDLCompiler from ..sql.compiler import IdentifierPreparer from ..sql.compiler import InsertmanyvaluesSentinelOpts @@ -864,6 +865,13 @@ class Dialect(EventTarget): """ + aggregate_order_by_style: AggregateOrderByStyle + """Style of ORDER BY supported for arbitrary aggregate functions + + .. versionadded:: 2.1 + + """ + insert_executemany_returning: bool """dialect / driver / database supports some means of providing INSERT...RETURNING support when dialect.do_executemany() is used. diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index a3aa65c2b4..3b91fc8161 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -20,6 +20,7 @@ from .ddl import BaseDDLElement as BaseDDLElement from .ddl import DDL as DDL from .ddl import DDLElement as DDLElement from .ddl import ExecutableDDLElement as ExecutableDDLElement +from .expression import aggregate_order_by as aggregate_order_by from .expression import Alias as Alias from .expression import alias as alias from .expression import all_ as all_ diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index cf9a5b246a..cc2f8201cc 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -25,6 +25,7 @@ from . import operators from . 
import roles from .base import _NoArg from .coercions import _document_text_coercion +from .elements import AggregateOrderBy from .elements import BindParameter from .elements import BooleanClauseList from .elements import Case @@ -1948,20 +1949,24 @@ def within_group( Used against so-called "ordered set aggregate" and "hypothetical set aggregate" functions, including :class:`.percentile_cont`, - :class:`.rank`, :class:`.dense_rank`, etc. + :class:`.rank`, :class:`.dense_rank`, etc. This feature is typically + used by PostgreSQL, Oracle Database, and Microsoft SQL Server. + + For generalized ORDER BY of aggregate functions on all included + backends, including PostgreSQL, MySQL/MariaDB, SQLite as well as Oracle + and SQL Server, the :func:`_sql.aggregate_order_by` provides a more + general approach that compiles to "WITHIN GROUP" only on those backends + which require it. :func:`_expression.within_group` is usually called using the :meth:`.FunctionElement.within_group` method, e.g.:: - from sqlalchemy import within_group - stmt = select( - department.c.id, func.percentile_cont(0.5).within_group(department.c.salary.desc()), ) The above statement would produce SQL similar to - ``SELECT department.id, percentile_cont(0.5) + ``SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY department.salary DESC)``. :param element: a :class:`.FunctionElement` construct, typically @@ -1974,9 +1979,62 @@ def within_group( :ref:`tutorial_functions_within_group` - in the :ref:`unified_tutorial` + :func:`_sql.aggregate_order_by` - helper for PostgreSQL, MySQL, + SQLite aggregate functions + :data:`.expression.func` :func:`_expression.over` """ return WithinGroup(element, *order_by) + + +def aggregate_order_by( + element: FunctionElement[_T], *order_by: _ColumnExpressionArgument[Any] +) -> AggregateOrderBy[_T]: + r"""Produce a :class:`.AggregateOrderBy` object against a function. 
+ + Used for aggregating functions such as :class:`_functions.array_agg`, + ``group_concat``, ``json_agg`` on backends that support ordering via an + embedded ``ORDER BY`` parameter, e.g. PostgreSQL, MySQL/MariaDB, SQLite. + When used on backends like Oracle and SQL Server, SQL compilation uses that + of :class:`.WithinGroup`. On PostgreSQL, compilation is fixed at embedded + ``ORDER BY``; for set aggregation functions where PostgreSQL requires the + use of ``WITHIN GROUP``, :func:`_expression.within_group` should be used + explicitly. + + :func:`_expression.aggregate_order_by` is usually called using + the :meth:`.FunctionElement.aggregate_order_by` method, e.g.:: + + stmt = select( + func.array_agg(department.c.code).aggregate_order_by( + department.c.code.desc() + ), + ) + + which would produce an expression resembling: + + .. sourcecode:: sql + + SELECT array_agg(department.code ORDER BY department.code DESC) + AS array_agg_1 FROM department + + The ORDER BY argument may also be multiple terms. + + When using the backend-agnostic :class:`_functions.aggregate_strings` + string aggregation function, use the + :paramref:`_functions.aggregate_strings.order_by` parameter to indicate a + dialect-agnostic ORDER BY expression. + + .. versionadded:: 2.0.44 Generalized the PostgreSQL-specific + :func:`_postgresql.aggregate_order_by` function to a method on + :class:`.Function` that is backend agnostic. + + .. 
seealso:: + + :class:`_functions.aggregate_strings` - backend-agnostic string + concatenation function which also supports ORDER BY + + """ # noqa: E501 + return AggregateOrderBy(element, *order_by) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6753948a24..e95eaa5918 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -649,6 +649,26 @@ class InsertmanyvaluesSentinelOpts(FastIntFlag): RENDER_SELECT_COL_CASTS = 64 +class AggregateOrderByStyle(IntEnum): + """Describes backend database's capabilities with ORDER BY for aggregate + functions + + .. versionadded:: 2.1 + + """ + + NONE = 0 + """database has no ORDER BY for aggregate functions""" + + INLINE = 1 + """ORDER BY is rendered inside the function's argument list, typically as + the last element""" + + WITHIN_GROUP = 2 + """the WITHIN GROUP (ORDER BY ...) phrase is used for all aggregate + functions (not just the ordered set ones)""" + + class CompilerState(IntEnum): COMPILING = 0 """statement is present, compilation phase in progress""" @@ -1012,6 +1032,39 @@ class _CompileLabel( return self +class aggregate_orderby_inline( + roles.BinaryElementRole[Any], elements.CompilerColumnElement +): + """produce ORDER BY inside of function argument lists""" + + __visit_name__ = "aggregate_orderby_inline" + __slots__ = "element", "aggregate_order_by" + + def __init__(self, element, orderby): + self.element = element + self.aggregate_order_by = orderby + + def __iter__(self): + return iter(self.element) + + @property + def proxy_set(self): + return self.element.proxy_set + + @property + def type(self): + return self.element.type + + def self_group(self, **kw): + return self + + def _with_binary_element_type(self, type_): + return aggregate_orderby_inline( + self.element._with_binary_element_type(type_), + self.aggregate_order_by, + ) + + class ilike_case_insensitive( roles.BinaryElementRole[Any], elements.CompilerColumnElement ): @@ -2914,6 +2967,62 @@ class 
SQLCompiler(Compiled): funcfilter.criterion._compiler_dispatch(self, **kwargs), ) + def visit_aggregateorderby(self, aggregateorderby, **kwargs): + if self.dialect.aggregate_order_by_style is AggregateOrderByStyle.NONE: + raise exc.CompileError( + "this dialect does not support " + "ORDER BY within an aggregate function" + ) + elif ( + self.dialect.aggregate_order_by_style + is AggregateOrderByStyle.INLINE + ): + new_fn = aggregateorderby.element._clone() + new_fn.clause_expr = elements.Grouping( + aggregate_orderby_inline( + new_fn.clause_expr.element, aggregateorderby.order_by + ) + ) + + return new_fn._compiler_dispatch(self, **kwargs) + else: + return self.visit_withingroup(aggregateorderby, **kwargs) + + def visit_aggregate_orderby_inline(self, element, **kw): + return "%s ORDER BY %s" % ( + self.process(element.element, **kw), + self.process(element.aggregate_order_by, **kw), + ) + + def visit_aggregate_strings_func(self, fn, *, use_function_name, **kw): + # aggreagate_order_by attribute is present if visit_function + # gave us a Function with aggregate_orderby_inline() as the inner + # contents + order_by = getattr(fn.clauses, "aggregate_order_by", None) + + literal_exec = dict(kw) + literal_exec["literal_execute"] = True + + # break up the function into its components so we can apply + # literal_execute to the second argument (the delimeter) + cl = list(fn.clauses) + expr, delimeter = cl[0:2] + if ( + order_by is not None + and self.dialect.aggregate_order_by_style + is AggregateOrderByStyle.INLINE + ): + return ( + f"{use_function_name}({expr._compiler_dispatch(self, **kw)}, " + f"{delimeter._compiler_dispatch(self, **literal_exec)} " + f"ORDER BY {order_by._compiler_dispatch(self, **kw)})" + ) + else: + return ( + f"{use_function_name}({expr._compiler_dispatch(self, **kw)}, " + f"{delimeter._compiler_dispatch(self, **literal_exec)})" + ) + def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) return 
"EXTRACT(%s FROM %s)" % ( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8f68e520b8..fbb2f8632b 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -819,6 +819,17 @@ class CompilerColumnElement( _propagate_attrs = util.EMPTY_DICT _is_collection_aggregate = False + _is_implicitly_boolean = False + + def _with_binary_element_type(self, type_): + raise NotImplementedError() + + def _gen_cache_key(self, anon_map, bindparams): + raise NotImplementedError() + + @property + def _from_objects(self) -> List[FromClause]: + raise NotImplementedError() # SQLCoreOperations should be suiting the ExpressionElementRole @@ -4213,10 +4224,15 @@ class Grouping(GroupedElement, ColumnElement[_T]): ("element", InternalTraversal.dp_clauseelement), ] - element: Union[TextClause, ClauseList, ColumnElement[_T]] + element: Union[ + TextClause, ClauseList, ColumnElement[_T], CompilerColumnElement + ] def __init__( - self, element: Union[TextClause, ClauseList, ColumnElement[_T]] + self, + element: Union[ + TextClause, ClauseList, ColumnElement[_T], CompilerColumnElement + ], ): self.element = element @@ -4484,31 +4500,34 @@ class _FrameClause(ClauseElement): ) -class WithinGroup(ColumnElement[_T]): - """Represent a WITHIN GROUP (ORDER BY) clause. +class AggregateOrderBy(WrapsColumnExpression[_T]): + """Represent an aggregate ORDER BY expression. - This is a special operator against so-called - "ordered set aggregate" and "hypothetical - set aggregate" functions, including ``percentile_cont()``, - ``rank()``, ``dense_rank()``, etc. + This is a special operator against aggregate functions such as + ``array_agg()``, ``json_arrayagg()`` ``string_agg()``, etc. that provides + for an ORDER BY expression, using a syntax that's compatible with + the backend. - It's supported only by certain database backends, such as PostgreSQL, - Oracle Database and MS SQL Server. 
+ :class:`.AggregateOrderBy` is a generalized version of the + :class:`.WithinGroup` construct, the latter of which always provides a + "WITHIN GROUP (ORDER BY ...)" expression. :class:`.AggregateOrderBy` will + also compile to "WITHIN GROUP (ORDER BY ...)" on backends such as Oracle + and SQL Server that don't have another style of aggregate function + ordering. + + .. versionadded:: 2.1 - The :class:`.WithinGroup` construct extracts its type from the - method :meth:`.FunctionElement.within_group_type`. If this returns - ``None``, the function's ``.type`` is used. """ - __visit_name__ = "withingroup" + __visit_name__ = "aggregateorderby" _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("order_by", InternalTraversal.dp_clauseelement), ] - order_by: Optional[ClauseList] = None + order_by: ClauseList def __init__( self, @@ -4516,10 +4535,21 @@ class WithinGroup(ColumnElement[_T]): *order_by: _ColumnExpressionArgument[Any], ): self.element = element - if order_by is not None: - self.order_by = ClauseList( - *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole - ) + if not order_by: + raise TypeError("at least one ORDER BY element is required") + self.order_by = ClauseList( + *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole + ) + + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return self.element.type + + @property + def wrapped_column_expression(self) -> ColumnElement[_T]: + return self.element def __reduce__(self): return self.__class__, (self.element,) + ( @@ -4569,16 +4599,6 @@ class WithinGroup(ColumnElement[_T]): return self return FunctionFilter(self, *criterion) - if not TYPE_CHECKING: - - @util.memoized_property - def type(self) -> TypeEngine[_T]: # noqa: A001 - wgt = self.element.within_group_type(self) - if wgt is not None: - return wgt - else: - return self.element.type - @util.ro_non_memoized_property def _from_objects(self) -> 
List[FromClause]: return list( @@ -4592,6 +4612,37 @@ class WithinGroup(ColumnElement[_T]): ) +class WithinGroup(AggregateOrderBy[_T]): + """Represent a WITHIN GROUP (ORDER BY) clause. + + This is a special operator against so-called + "ordered set aggregate" and "hypothetical + set aggregate" functions, including ``percentile_cont()``, + ``rank()``, ``dense_rank()``, etc. + + It's supported only by certain database backends, such as PostgreSQL, + Oracle Database and MS SQL Server. + + The :class:`.WithinGroup` construct extracts its type from the + method :meth:`.FunctionElement.within_group_type`. If this returns + ``None``, the function's ``.type`` is used. + + """ + + __visit_name__ = "withingroup" + inherit_cache = True + + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + wgt = self.element.within_group_type(self) + if wgt is not None: + return wgt + else: + return self.element.type + + class FunctionFilter(Generative, ColumnElement[_T]): """Represent a function FILTER clause. 
@@ -4621,7 +4672,7 @@ class FunctionFilter(Generative, ColumnElement[_T]): def __init__( self, - func: Union[FunctionElement[_T], WithinGroup[_T]], + func: Union[FunctionElement[_T], AggregateOrderBy[_T]], *criterion: _ColumnExpressionArgument[bool], ): self.func = func diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f7847bf7e6..267a572a5b 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -13,6 +13,7 @@ from __future__ import annotations from ._dml_constructors import delete as delete from ._dml_constructors import insert as insert from ._dml_constructors import update as update +from ._elements_constructors import aggregate_order_by as aggregate_order_by from ._elements_constructors import all_ as all_ from ._elements_constructors import and_ as and_ from ._elements_constructors import any_ as any_ @@ -72,6 +73,7 @@ from .dml import Update as Update from .dml import UpdateBase as UpdateBase from .dml import ValuesBase as ValuesBase from .elements import _truncated_label as _truncated_label +from .elements import AggregateOrderBy as AggregateOrderBy from .elements import BinaryExpression as BinaryExpression from .elements import BindParameter as BindParameter from .elements import BooleanClauseList as BooleanClauseList diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 9a28dcfb4f..0230851227 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -40,6 +40,7 @@ from .base import Executable from .base import Generative from .base import HasMemoized from .elements import _type_from_args +from .elements import AggregateOrderBy from .elements import BinaryExpression from .elements import BindParameter from .elements import Cast @@ -469,6 +470,32 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): groups=groups, ) + def aggregate_order_by( + self, *order_by: _ColumnExpressionArgument[Any] + ) -> 
AggregateOrderBy[_T]: + r"""Produce a :class:`.AggregateOrderBy` object against a function. + + Used for aggregating functions such as :class:`_functions.array_agg`, + ``group_concat``, ``json_agg`` on backends that support ordering via an + embedded ORDER BY parameter, e.g. PostgreSQL, MySQL/MariaDB, SQLite. + When used on backends like Oracle and SQL Server, SQL compilation uses + that of :class:`.WithinGroup`. + + See :func:`_expression.aggregate_order_by` for a full description. + + .. versionadded:: 2.1 Generalized the PostgreSQL-specific + :func:`_postgresql.aggregate_order_by` function to a method on + :class:`.Function` that is backend agnostic. + + .. seealso:: + + :class:`_functions.aggregate_strings` - backend-agnostic string + concatenation function which also supports ORDER BY + + """ + + return AggregateOrderBy(self, *order_by) + def within_group( self, *order_by: _ColumnExpressionArgument[Any] ) -> WithinGroup[_T]: @@ -476,7 +503,11 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): Used against so-called "ordered set aggregate" and "hypothetical set aggregate" functions, including :class:`.percentile_cont`, - :class:`.rank`, :class:`.dense_rank`, etc. + :class:`.rank`, :class:`.dense_rank`, etc. This feature is typically + used by PostgreSQL, Oracle Database, and Microsoft SQL Server. + + For simple ORDER BY expressions within aggregate functions on + PostgreSQL, MySQL/MariaDB, SQLite, see :func:`_sql.aggregate_order_by`. See :func:`_expression.within_group` for a full description. @@ -2127,17 +2158,36 @@ class aggregate_strings(GenericFunction[str]): stmt = select(func.aggregate_strings(table.c.str_col, ".")) - The return type of this function is :class:`.String`. - .. 
versionadded:: 2.0.21 - """ + To add ordering to the expression, use the + :meth:`_functions.FunctionElement.aggregate_order_by` modifier method, + which will emit ORDER BY within the appropriate part of the column + expression (varies by backend):: + + stmt = select( + func.aggregate_strings(table.c.str_col, ".").aggregate_order_by( + table.c.str_col + ) + ) + + .. versionadded:: 2.1 added :meth:`_functions.FunctionElement.aggregate_order_by` + for all aggregate functions. + + :param clause: the SQL expression to be concatenated + + :param separator: separator string + + + """ # noqa: E501 type = sqltypes.String() _has_args = True inherit_cache = True def __init__( - self, clause: _ColumnExpressionArgument[Any], separator: str + self, + clause: _ColumnExpressionArgument[Any], + separator: str, ) -> None: super().__init__(clause, separator) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index c45b4bc9b2..a3642003da 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -2110,25 +2110,31 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): is_(expr.type.item_type.__class__, Integer) @testing.combinations( - ("original", False, False), - ("just_enum", True, False), - ("just_order_by", False, True), - ("issue_5989", True, True), - id_="iaa", - argnames="with_enum, using_aggregate_order_by", + ("original", False), + ("just_enum", True), + ("just_order_by", False), + ("issue_5989", True), + id_="ia", + argnames="with_enum", ) - def test_array_agg_specific(self, with_enum, using_aggregate_order_by): + @testing.variation("order_by_type", ["none", "legacy", "core"]) + def test_array_agg_specific(self, with_enum, order_by_type): element = ENUM(name="pgenum") if with_enum else Integer() element_type = type(element) - expr = ( - array_agg( + + if order_by_type.none: + expr = array_agg(column("q", element)) + elif order_by_type.legacy: + expr = array_agg( aggregate_order_by( 
column("q", element), column("idx", Integer) ) ) - if using_aggregate_order_by - else array_agg(column("q", element)) - ) + elif order_by_type.core: + expr = array_agg(column("q", element)).aggregate_order_by( + column("idx", Integer) + ) + is_(expr.type.__class__, postgresql.ARRAY) is_(expr.type.item_type.__class__, element_type) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 1956a8db98..45cea0c46a 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -32,6 +32,7 @@ from sqlalchemy import union from sqlalchemy import union_all from sqlalchemy import values from sqlalchemy.schema import Sequence +from sqlalchemy.sql import aggregate_order_by from sqlalchemy.sql import bindparam from sqlalchemy.sql import ColumnElement from sqlalchemy.sql import dml @@ -425,6 +426,24 @@ class CoreFixtures: func.json_to_recordset("{foo}").column_valued(), func.json_to_recordset("{foo}").scalar_table_valued("foo"), ), + lambda: ( + aggregate_order_by(column("a"), column("a")), + aggregate_order_by(column("a"), column("b")), + aggregate_order_by(column("a"), column("a").desc()), + aggregate_order_by(column("a"), column("a").nulls_first()), + aggregate_order_by(column("a"), column("a").desc().nulls_first()), + aggregate_order_by(column("a", Integer), column("b")), + aggregate_order_by(column("a"), column("b"), column("c")), + aggregate_order_by(column("a"), column("c"), column("b")), + aggregate_order_by(column("a"), column("b").desc(), column("c")), + aggregate_order_by( + column("a"), column("b").nulls_first(), column("c") + ), + aggregate_order_by( + column("a"), column("b").desc().nulls_first(), column("c") + ), + aggregate_order_by(column("a", Integer), column("a"), column("b")), + ), lambda: (table_a.table_valued(), table_b.table_valued()), lambda: (True_(), False_()), lambda: (Null(),), diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 4ca935766f..c1feee694c 100644 --- a/test/sql/test_compiler.py +++ 
b/test/sql/test_compiler.py @@ -6142,8 +6142,9 @@ class StringifySpecialTest(fixtures.TestBase): ) eq_ignore_whitespace( str(stmt), - "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " - "WITHIN GROUP (ORDER BY mytable.name DESC) AS anon_1 FROM mytable", + "SELECT mytable.myid, percentile_cont(:percentile_cont_2) " + "WITHIN GROUP (ORDER BY mytable.name DESC) AS percentile_cont_1 " + "FROM mytable", ) @testing.combinations( diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 28cdb03a96..b569c41ca3 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -10,6 +10,7 @@ from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import Date from sqlalchemy import DateTime +from sqlalchemy import exc from sqlalchemy import extract from sqlalchemy import Float from sqlalchemy import func @@ -17,6 +18,7 @@ from sqlalchemy import Integer from sqlalchemy import JSON from sqlalchemy import literal from sqlalchemy import literal_column +from sqlalchemy import MetaData from sqlalchemy import Numeric from sqlalchemy import select from sqlalchemy import Sequence @@ -33,7 +35,9 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import sqlite from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY from sqlalchemy.dialects.postgresql import array +from sqlalchemy.engine import default from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import aggregate_order_by from sqlalchemy.sql import column from sqlalchemy.sql import functions from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL @@ -41,6 +45,8 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import quoted_name from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table +from sqlalchemy.sql import util +from sqlalchemy.sql.compiler import AggregateOrderByStyle from sqlalchemy.sql.compiler import BIND_TEMPLATES from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql.functions import 
GenericFunction @@ -51,6 +57,7 @@ from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing.engines import all_dialects from sqlalchemy.testing.provision import normalize_sequence @@ -223,28 +230,26 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): @testing.combinations( ( - "SELECT group_concat(t.value, ?) AS aggregate_strings_1 FROM t", + "SELECT group_concat(t.value, ',') AS aggregate_strings_1 FROM t", "sqlite", ), ( - "SELECT string_agg(t.value, %(aggregate_strings_2)s) AS " - "aggregate_strings_1 FROM t", + "SELECT string_agg(t.value, ',') AS " "aggregate_strings_1 FROM t", "postgresql", ), ( "SELECT string_agg(t.value, " - "__[POSTCOMPILE_aggregate_strings_2]) AS " + "',') AS " "aggregate_strings_1 FROM t", "mssql", ), ( - "SELECT group_concat(t.value SEPARATOR %s) " + "SELECT group_concat(t.value SEPARATOR ',') " "AS aggregate_strings_1 FROM t", "mysql", ), ( - "SELECT LISTAGG(t.value, :aggregate_strings_2) AS" - " aggregate_strings_1 FROM t", + "SELECT LISTAGG(t.value, ',') AS" " aggregate_strings_1 FROM t", "oracle", ), ) @@ -252,7 +257,52 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): t = table("t", column("value", String)) stmt = select(func.aggregate_strings(t.c.value, ",")) - self.assert_compile(stmt, expected_sql, dialect=dialect) + self.assert_compile( + stmt, expected_sql, dialect=dialect, render_postcompile=True + ) + + @testing.combinations( + ( + "SELECT group_concat(t.value, ',' ORDER BY t.ordering DESC) " + "AS aggregate_strings_1 FROM t", + "sqlite", + ), + ( + "SELECT string_agg(t.value, ',' " + "ORDER BY t.ordering DESC) AS " + "aggregate_strings_1 FROM t", + "postgresql", + ), + ( + "SELECT string_agg(t.value, ',') " + "WITHIN GROUP (ORDER BY t.ordering DESC) AS " + 
"aggregate_strings_1 FROM t", + "mssql", + ), + ( + "SELECT group_concat(t.value " + "ORDER BY t.ordering DESC SEPARATOR ',') " + "AS aggregate_strings_1 FROM t", + "mysql", + ), + ( + "SELECT LISTAGG(t.value, ',') " + "WITHIN GROUP (ORDER BY t.ordering DESC) AS" + " aggregate_strings_1 FROM t", + "oracle", + ), + ) + def test_aggregate_strings_order_by(self, expected_sql, dialect): + t = table("t", column("value", String), column("ordering", String)) + stmt = select( + func.aggregate_strings(t.c.value, ",").aggregate_order_by( + t.c.ordering.desc() + ) + ) + + self.assert_compile( + stmt, expected_sql, dialect=dialect, render_postcompile=True + ) def test_cube_operators(self): t = table( @@ -603,11 +653,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): # this still relies upon a strategy for table metadata as we have # in serializer. - f1 = func.percentile_cont(literal(1)).within_group() + f1 = func.percentile_cont(literal(1)).within_group(column("q")) self.assert_compile( pickle.loads(pickle.dumps(f1)), - "percentile_cont(:param_1) WITHIN GROUP (ORDER BY )", + "percentile_cont(:param_1) WITHIN GROUP (ORDER BY q)", ) f1 = func.percentile_cont(literal(1)).within_group( @@ -902,11 +952,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( stmt, - "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " + "SELECT mytable.myid, percentile_cont(:percentile_cont_2) " "WITHIN GROUP (ORDER BY mytable.name) " - "AS anon_1 " + "AS percentile_cont_1 " "FROM mytable", - {"percentile_cont_1": 0.5}, + {"percentile_cont_2": 0.5}, ) def test_within_group_multi(self): @@ -918,11 +968,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( stmt, - "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " + "SELECT mytable.myid, percentile_cont(:percentile_cont_2) " "WITHIN GROUP (ORDER BY mytable.name, mytable.description) " - "AS anon_1 " + "AS percentile_cont_1 " "FROM mytable", - 
{"percentile_cont_1": 0.5}, + {"percentile_cont_2": 0.5}, ) def test_within_group_desc(self): @@ -932,11 +982,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( stmt, - "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " + "SELECT mytable.myid, percentile_cont(:percentile_cont_2) " "WITHIN GROUP (ORDER BY mytable.name DESC) " - "AS anon_1 " + "AS percentile_cont_1 " "FROM mytable", - {"percentile_cont_1": 0.5}, + {"percentile_cont_2": 0.5}, ) def test_within_group_w_over(self): @@ -1064,6 +1114,121 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): }, ) + @testing.variation("style", ["none", "inline", "within_group"]) + def test_aggregate_order_by_one(self, style): + table = Table( + "table1", MetaData(), Column("a", Integer), Column("b", Integer) + ) + expr = func.array_agg(table.c.a).aggregate_order_by(table.c.b.desc()) + stmt = select(expr) + + if style.none: + dialect = default.DefaultDialect() + dialect.aggregate_order_by_style = AggregateOrderByStyle.NONE + with expect_raises_message( + exc.CompileError, + "this dialect does not support ORDER BY " + "within an aggregate function", + ): + stmt.compile(dialect=dialect) + elif style.within_group: + dialect = default.DefaultDialect() + dialect.aggregate_order_by_style = ( + AggregateOrderByStyle.WITHIN_GROUP + ) + self.assert_compile( + stmt, + "SELECT array_agg(table1.a) " + "WITHIN GROUP (ORDER BY table1.b DESC) " + "AS array_agg_1 FROM table1", + dialect=dialect, + ) + else: + self.assert_compile( + stmt, + "SELECT array_agg(table1.a ORDER BY table1.b DESC) " + "AS array_agg_1 FROM table1", + ) + + @testing.variation("style", ["inline", "within_group"]) + def test_aggregate_order_by_two(self, style): + table = Table( + "table1", MetaData(), Column("a", Integer), Column("b", Integer) + ) + expr = func.string_agg( + table.c.a, literal_column("','") + ).aggregate_order_by(table.c.a) + stmt = select(expr) + + if style.within_group: + dialect = 
default.DefaultDialect() + dialect.aggregate_order_by_style = ( + AggregateOrderByStyle.WITHIN_GROUP + ) + self.assert_compile( + stmt, + "SELECT string_agg(table1.a, ',') " + "WITHIN GROUP (ORDER BY table1.a) " + "AS string_agg_1 FROM table1", + dialect=dialect, + ) + else: + self.assert_compile( + stmt, + "SELECT string_agg(table1.a, ',' ORDER BY table1.a) " + "AS string_agg_1 FROM table1", + ) + + def test_aggregate_order_by_multi_col(self): + table = Table( + "table1", MetaData(), Column("a", Integer), Column("b", Integer) + ) + expr = func.string_agg( + table.c.a, + literal_column("','"), + ).aggregate_order_by(table.c.a, table.c.b.desc()) + stmt = select(expr) + + self.assert_compile( + stmt, + "SELECT string_agg(table1.a, " + "',' ORDER BY table1.a, table1.b DESC) " + "AS string_agg_1 FROM table1", + ) + + def test_aggregate_order_by_type_propagate(self): + table = Table( + "table1", MetaData(), Column("a", Integer), Column("b", String) + ) + expr = func.foo_agg(table.c.a, type_=Integer).aggregate_order_by( + table.c.b.desc() + ) + + is_(expr.type._type_affinity, Integer) + + def test_aggregate_order_by_no_arg(self): + assert_raises_message( + TypeError, + "at least one ORDER BY element is required", + aggregate_order_by, + literal_column("','"), + ) + + def test_aggregate_order_by_adapt(self): + table = Table( + "table1", MetaData(), Column("a", Integer), Column("b", Integer) + ) + expr = aggregate_order_by(func.array_agg(table.c.a), table.c.b.desc()) + stmt = select(expr) + + a1 = table.alias("foo") + stmt2 = util.ClauseAdapter(a1).traverse(stmt) + self.assert_compile( + stmt2, + "SELECT array_agg(foo.a ORDER BY foo.b DESC) AS array_agg_1 " + "FROM table1 AS foo", + ) + class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase): def test_array_agg(self): @@ -1266,8 +1431,17 @@ class ExecuteTest(fixtures.TestBase): @testing.variation("unicode_value", [True, False]) @testing.variation("unicode_separator", [True, False]) + @testing.variation("use_order_by", 
[True, False]) + @testing.only_on( + ["postgresql", "sqlite", "mysql", "mariadb", "oracle", "mssql"] + ) def test_aggregate_strings_execute( - self, connection, metadata, unicode_value, unicode_separator + self, + connection, + metadata, + unicode_value, + unicode_separator, + use_order_by, ): values_t = Table( "values", @@ -1279,10 +1453,10 @@ class ExecuteTest(fixtures.TestBase): connection.execute( values_t.insert(), [ - {"value": "a", "unicode_value": "測試"}, - {"value": "b", "unicode_value": "téble2"}, + {"value": "a", "unicode_value": "b 測試"}, + {"value": "b", "unicode_value": "c téble2"}, {"value": None, "unicode_value": None}, # ignored - {"value": "c", "unicode_value": "🐍 su"}, + {"value": "c", "unicode_value": "a 🐍 su"}, ], ) @@ -1293,22 +1467,81 @@ class ExecuteTest(fixtures.TestBase): if unicode_value: col = values_t.c.unicode_value - expected = separator.join(["測試", "téble2", "🐍 su"]) + if use_order_by: + expected = separator.join(["c téble2", "b 測試", "a 🐍 su"]) + else: + expected = separator.join(["b 測試", "c téble2", "a 🐍 su"]) else: col = values_t.c.value - expected = separator.join(["a", "b", "c"]) + if use_order_by: + expected = separator.join(["c", "b", "a"]) + else: + expected = separator.join(["a", "b", "c"]) # to join on a unicode separator, source string has to be unicode, # so cast(). 
SQL Server will raise otherwise if unicode_separator: col = cast(col, Unicode(42)) - value = connection.execute( - select(func.aggregate_strings(col, separator)) - ).scalar_one() + if use_order_by: + value = connection.execute( + select( + func.aggregate_strings(col, separator).aggregate_order_by( + col.desc() + ) + ) + ).scalar_one() + else: + value = connection.execute( + select(func.aggregate_strings(col, separator)) + ).scalar_one() eq_(value, expected) + @testing.only_on( + ["postgresql", "sqlite", "mysql", "mariadb", "oracle", "mssql"] + ) + def test_aggregate_order_by( + self, + connection, + metadata, + ): + + values_t = Table( + "values", + metadata, + Column("value", String(2)), + Column("ordering", String(2)), + ) + metadata.create_all(connection) + connection.execute( + values_t.insert(), + [ + {"value": "a", "ordering": "1"}, + {"value": "b", "ordering": "3"}, + {"value": "c", "ordering": "2"}, + ], + ) + + if testing.against("postgresql", "mssql"): + fn = lambda expr: func.string_agg( # noqa: E731 + expr, literal_column("''") + ) + expected = "bca" + elif testing.against(["mysql", "mariadb", "sqlite"]): + fn = func.group_concat + expected = "b,c,a" + elif testing.against("oracle"): + fn = func.listagg + expected = "bca" + else: + assert False + + stmt = select( + fn(values_t.c.value).aggregate_order_by(values_t.c.ordering.desc()) + ) + eq_(connection.scalar(stmt), expected) + @testing.fails_on_everything_except("postgresql") def test_as_from(self, connection): # TODO: shouldn't this work on oracle too ?