From: Joshua Morris Date: Tue, 11 Jul 2023 06:21:37 +0000 (-0400) Subject: Add support for SQL string aggregation function aggregate_strings. X-Git-Tag: rel_2_0_21~10^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d24048a8fb8ad8590454d73e3d8edf4352280f13;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add support for SQL string aggregation function aggregate_strings. Add support for SQL string aggregation function :class:`.aggregate_strings`. Pull request courtesy Joshua Morris. Fixes #9873 Closes: #9892 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9892 Pull-request-sha: 0f4f83bade675f2ff734579fc411d0b354e5a4e6 Change-Id: I6e4afc83664a142e3b1e245978b08a200b6d03d9 --- diff --git a/doc/build/changelog/unreleased_20/9873.rst b/doc/build/changelog/unreleased_20/9873.rst new file mode 100644 index 0000000000..f1071fdf84 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9873.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: usecase, sql + :tickets: 9873 + + Added new generic SQL function :class:`_functions.aggregate_strings`, which + accepts a SQL expression and a delimiter, concatenating strings on multiple + rows into a single aggregate value. The function is compiled on a + per-backend basis, into functions such as ``group_concat()``, + ``string_agg()``, or ``LISTAGG()``. + Pull request courtesy Joshua Morris. \ No newline at end of file diff --git a/doc/build/core/functions.rst b/doc/build/core/functions.rst index 6fcee6edaa..9771ffeedd 100644 --- a/doc/build/core/functions.rst +++ b/doc/build/core/functions.rst @@ -52,6 +52,9 @@ unknown to SQLAlchemy, built-in or user defined. The section here only describes those functions where SQLAlchemy already knows what argument and return types are in use. +.. autoclass:: aggregate_strings + :no-members: + .. 
autoclass:: array_agg :no-members: diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 1f80aaef29..6d46687e43 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2057,6 +2057,12 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "LEN%s" % self.function_argspec(fn, **kw) + def visit_aggregate_strings_func(self, fn, **kw): + expr = fn.clauses.clauses[0]._compiler_dispatch(self, **kw) + kw["literal_execute"] = True + delimeter = fn.clauses.clauses[1]._compiler_dispatch(self, **kw) + return f"string_agg({expr}, {delimeter})" + def visit_concat_op_expression_clauselist( self, clauselist, operator, **kw ): diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 8b0e1295d1..d3f2a3ff87 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1208,6 +1208,12 @@ class MySQLCompiler(compiler.SQLCompiler): ) return f"{clause} WITH ROLLUP" + def visit_aggregate_strings_func(self, fn, **kw): + expr, delimeter = ( + elem._compiler_dispatch(self, **kw) for elem in fn.clauses + ) + return f"group_concat({expr} SEPARATOR {delimeter})" + def visit_sequence(self, seq, **kw): return "nextval(%s)" % self.preparer.format_sequence(seq) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 4a3ac3ac07..d993ef2692 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1241,6 +1241,9 @@ class OracleCompiler(compiler.SQLCompiler): self.render_literal_value(flags, sqltypes.STRINGTYPE), ) + def visit_aggregate_strings_func(self, fn, **kw): + return "LISTAGG%s" % self.function_argspec(fn, **kw) + class OracleDDLCompiler(compiler.DDLCompiler): def define_constraint_cascades(self, constraint): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py 
index 08fd5d3a04..5fe619d6f5 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1868,6 +1868,9 @@ class PGCompiler(compiler.SQLCompiler): value = value.replace("\\", "\\\\") return value + def visit_aggregate_strings_func(self, fn, **kw): + return "string_agg%s" % self.function_argspec(fn) + def visit_sequence(self, seq, **kw): return "nextval('%s')" % self.preparer.format_sequence(seq) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index bd6c1dff87..a8e5368407 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1318,6 +1318,9 @@ class SQLiteCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "length%s" % self.function_argspec(fn) + def visit_aggregate_strings_func(self, fn, **kw): + return "group_concat%s" % self.function_argspec(fn) + def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: return super().visit_cast(cast, **kwargs) diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 30e280c61f..fc23e9d215 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -916,6 +916,10 @@ class _FunctionGenerator: # code within this block is **programmatically, # statically generated** by tools/generate_sql_functions.py + @property + def aggregate_strings(self) -> Type[aggregate_strings]: + ... + @property def ansifunction(self) -> Type[AnsiFunction[Any]]: ... @@ -1795,3 +1799,30 @@ class grouping_sets(GenericFunction[_T]): """ _has_args = True inherit_cache = True + + +class aggregate_strings(GenericFunction[str]): + """Implement a generic string aggregation function. + + This function will concatenate non-null values into a string and + separate the values by a delimiter. + + This function is compiled on a per-backend basis, into functions + such as ``group_concat()``, ``string_agg()``, or ``LISTAGG()``. + + e.g. 
Example usage with delimiter '.':: + + stmt = select(func.aggregate_strings(table.c.str_col, ".")) + + The return type of this function is :class:`.String`. + + .. versionadded:: 2.0.21 + + """ + + type = sqltypes.String() + _has_args = True + inherit_cache = True + + def __init__(self, clause, separator): + super().__init__(clause, separator) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 353537ad3e..b2be90f60c 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -377,6 +377,10 @@ class CoreFixtures: lambda: (table_a.c.a, table_b.c.a), lambda: (tuple_(1, 2), tuple_(3, 4)), lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])), + lambda: ( + func.aggregate_strings(table_a.c.b, ","), + func.aggregate_strings(table_b_like_a.c.b, ","), + ), lambda: ( func.percentile_cont(0.5).within_group(table_a.c.a), func.percentile_cont(0.5).within_group(table_a.c.b), diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index e961a4e465..c47601b761 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -26,6 +26,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import Text from sqlalchemy import true +from sqlalchemy import Unicode from sqlalchemy.dialects import mysql from sqlalchemy.dialects import oracle from sqlalchemy.dialects import postgresql @@ -215,6 +216,44 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ]: self.assert_compile(func.random(), ret, dialect=dialect) + def test_return_type_aggregate_strings(self): + t = table("t", column("value", String)) + expr = func.aggregate_strings(t.c.value, ",") + is_(expr.type._type_affinity, String) + + @testing.combinations( + ( + "SELECT group_concat(t.value, ?) 
AS aggregate_strings_1 FROM t", + "sqlite", + ), + ( + "SELECT string_agg(t.value, %(aggregate_strings_2)s) AS " + "aggregate_strings_1 FROM t", + "postgresql", + ), + ( + "SELECT string_agg(t.value, " + "__[POSTCOMPILE_aggregate_strings_2]) AS " + "aggregate_strings_1 FROM t", + "mssql", + ), + ( + "SELECT group_concat(t.value SEPARATOR %s) " + "AS aggregate_strings_1 FROM t", + "mysql", + ), + ( + "SELECT LISTAGG(t.value, :aggregate_strings_2) AS" + " aggregate_strings_1 FROM t", + "oracle", + ), + ) + def test_aggregate_strings(self, expected_sql, dialect): + t = table("t", column("value", String)) + stmt = select(func.aggregate_strings(t.c.value, ",")) + + self.assert_compile(stmt, expected_sql, dialect=dialect) + def test_cube_operators(self): t = table( "t", @@ -1157,6 +1196,51 @@ class ExecuteTest(fixtures.TestBase): (9, "foo"), ) + @testing.variation("unicode_value", [True, False]) + @testing.variation("unicode_separator", [True, False]) + def test_aggregate_strings_execute( + self, connection, metadata, unicode_value, unicode_separator + ): + values_t = Table( + "values", + metadata, + Column("value", String(42)), + Column("unicode_value", Unicode(42)), + ) + metadata.create_all(connection) + connection.execute( + values_t.insert(), + [ + {"value": "a", "unicode_value": "測試"}, + {"value": "b", "unicode_value": "téble2"}, + {"value": None, "unicode_value": None}, # ignored + {"value": "c", "unicode_value": "🐍 su"}, + ], + ) + + if unicode_separator: + separator = " 🐍試 " + else: + separator = " and " + + if unicode_value: + col = values_t.c.unicode_value + expected = separator.join(["測試", "téble2", "🐍 su"]) + else: + col = values_t.c.value + expected = separator.join(["a", "b", "c"]) + + # to join on a unicode separator, source string has to be unicode, + # so cast(). 
SQL Server will raise otherwise + if unicode_separator: + col = cast(col, Unicode(42)) + + value = connection.execute( + select(func.aggregate_strings(col, separator)) + ).scalar_one() + + eq_(value, expected) + @testing.fails_on_everything_except("postgresql") def test_as_from(self, connection): # TODO: shouldn't this work on oracle too ? diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index 09c2acf057..e66e554cff 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -9,111 +9,117 @@ from sqlalchemy import select # code within this block is **programmatically, # statically generated** by tools/generate_sql_functions.py -stmt1 = select(func.char_length(column("x"))) +stmt1 = select(func.aggregate_strings(column("x"), column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt1) -stmt2 = select(func.concat()) +stmt2 = select(func.char_length(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt2) -stmt3 = select(func.count(column("x"))) +stmt3 = select(func.concat()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt3) -stmt4 = select(func.cume_dist()) +stmt4 = select(func.count(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt4) -stmt5 = select(func.current_date()) +stmt5 = select(func.cume_dist()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt5) -stmt6 = select(func.current_time()) +stmt6 = select(func.current_date()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] reveal_type(stmt6) -stmt7 = select(func.current_timestamp()) +stmt7 = select(func.current_time()) -# 
EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] reveal_type(stmt7) -stmt8 = select(func.current_user()) +stmt8 = select(func.current_timestamp()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt8) -stmt9 = select(func.dense_rank()) +stmt9 = select(func.current_user()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt9) -stmt10 = select(func.localtime()) +stmt10 = select(func.dense_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt10) -stmt11 = select(func.localtimestamp()) +stmt11 = select(func.localtime()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt11) -stmt12 = select(func.next_value(column("x"))) +stmt12 = select(func.localtimestamp()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt12) -stmt13 = select(func.now()) +stmt13 = select(func.next_value(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt13) -stmt14 = select(func.percent_rank()) +stmt14 = select(func.now()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt14) -stmt15 = select(func.rank()) +stmt15 = select(func.percent_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt15) -stmt16 = select(func.session_user()) +stmt16 = select(func.rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt16) -stmt17 = select(func.sysdate()) +stmt17 = select(func.session_user()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] 
reveal_type(stmt17) -stmt18 = select(func.user()) +stmt18 = select(func.sysdate()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt18) + +stmt19 = select(func.user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt19) + # END GENERATED FUNCTION TYPING TESTS