From 7bc1fe54a9f89057d3b7ba689020cd0f237dd3ee Mon Sep 17 00:00:00 2001 From: Joshua Morris Date: Sun, 4 Jun 2023 17:07:20 +1000 Subject: [PATCH] Add support for SQL string aggregation function :class:`.string_agg` Returns a :class:`.String` with support for PostgreSQL, SQLite, and MSSQL. fixes #9873 --- lib/sqlalchemy/dialects/sqlite/base.py | 3 ++ lib/sqlalchemy/sql/functions.py | 21 ++++++++ test/sql/test_compare.py | 4 ++ test/sql/test_functions.py | 68 ++++++++++++++++++++++++++ 4 files changed, 96 insertions(+) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 835ec27731..69bdc1e562 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1318,6 +1318,9 @@ class SQLiteCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "length%s" % self.function_argspec(fn) + def visit_string_agg_func(self, fn, **kw): + return "group_concat%s" % self.function_argspec(fn) + def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: return super().visit_cast(cast, **kwargs) diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 30e280c61f..5978ec5ade 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1044,6 +1044,10 @@ class _FunctionGenerator: def session_user(self) -> Type[session_user]: ... + @property + def string_agg(self) -> Type[string_agg[Any]]: + ... + @property def sum(self) -> Type[sum[Any]]: # noqa: A001 ... @@ -1795,3 +1799,20 @@ class grouping_sets(GenericFunction[_T]): """ _has_args = True inherit_cache = True + + +class string_agg(GenericFunction[_T]): + r"""Implement the ``STRING_AGG`` aggregation function + + This function will concatenate non-null values into a string and + separate the values by a delimeter. + + The return type of this function is :class:`.String`. + + """ + type = sqltypes.String() + _has_args = True + inherit_cache = True + + def __init__(self, sep=",", *args, **kwargs): + super().__init__(sep, *args, **kwargs) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 353537ad3e..7dd70ad3ac 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -377,6 +377,10 @@ class CoreFixtures: lambda: (table_a.c.a, table_b.c.a), lambda: (tuple_(1, 2), tuple_(3, 4)), lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])), + lambda: ( + func.string_agg(table_a.c.b), + func.string_agg(table_b_like_a.c.b), + ), lambda: ( func.percentile_cont(0.5).within_group(table_a.c.a), func.percentile_cont(0.5).within_group(table_a.c.b), diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index e961a4e465..51ed638940 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -26,6 +26,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import Text from sqlalchemy import true +from sqlalchemy.dialects import mssql from sqlalchemy.dialects import mysql from sqlalchemy.dialects import oracle from sqlalchemy.dialects import postgresql @@ -215,6 +216,33 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ]: self.assert_compile(func.random(), ret, dialect=dialect) + def test_generic_string_agg(self): + t = table("t", column("value", String)) + expr = func.string_agg(t.c.value, ",") + is_(expr.type._type_affinity, String) + stmt = select(expr) + + self.assert_compile( + stmt, + "SELECT group_concat(t.value, ?) AS string_agg_1 FROM t", + dialect=sqlite.dialect(), + checkpositional=(",",), + ) + self.assert_compile( + stmt, + "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t", + dialect=postgresql.dialect(), + literal_binds=True, + render_postcompile=True, + ) + self.assert_compile( + stmt, + "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t", + dialect=mssql.dialect(), + literal_binds=True, + render_postcompile=True, + ) + def test_cube_operators(self): t = table( "t", @@ -1157,6 +1185,46 @@ class ExecuteTest(fixtures.TestBase): (9, "foo"), ) + @testing.provide_metadata + def test_string_agg_execute(self, connection): + meta = self.metadata + values_t = Table("values", meta, Column("value", String)) + meta.create_all(connection) + connection.execute( + values_t.insert(), + [ + {"value": "a"}, + {"value": "b"}, + {"value": "c"}, + ], + ) + rs = connection.execute(select(func.string_agg(values_t.c.value))) + row = rs.scalar() + + assert row == "a,b,c" + rs.close() + + @testing.provide_metadata + def test_string_agg_custom_sep(self, connection): + meta = self.metadata + values_t = Table("values", meta, Column("value", String)) + meta.create_all(connection) + connection.execute( + values_t.insert(), + [ + {"value": "a"}, + {"value": "b"}, + {"value": "c"}, + ], + ) + rs = connection.execute( + select(func.string_agg(values_t.c.value, " and ")) + ) + row = rs.scalar() + + assert row == "a and b and c" + rs.close() + @testing.fails_on_everything_except("postgresql") def test_as_from(self, connection): # TODO: shouldn't this work on oracle too ? -- 2.47.3