From: Joshua Morris Date: Mon, 10 Jul 2023 09:40:12 +0000 (+1000) Subject: rename to aggregate_strings X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0f4f83bade675f2ff734579fc411d0b354e5a4e6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git rename to aggregate_strings --- diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 1f80aaef29..a571e855fe 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2057,6 +2057,9 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "LEN%s" % self.function_argspec(fn, **kw) + def visit_aggregate_strings_func(self, fn, **kw): + return "string_agg%s" % self.function_argspec(fn, **kw) + def visit_concat_op_expression_clauselist( self, clauselist, operator, **kw ): diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index ac16904279..45f45d513d 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1208,15 +1208,9 @@ class MySQLCompiler(compiler.SQLCompiler): ) return f"{clause} WITH ROLLUP" - def visit_string_agg_func(self, fn, **kw): - if len(fn.clauses) > 1: - clauses = [ - elem._compiler_dispatch(self, **kw) for elem in fn.clauses - ] - clause = ", ".join(clauses[:-1]) - return "group_concat(%s SEPARATOR %s)" % (clause, clauses[-1]) - else: - return "group_concat%s" % self.function_argspec(fn) + def visit_aggregate_strings_func(self, fn, **kw): + expr, delimeter = fn.clauses + return "group_concat(%s SEPARATOR %s)" % (expr, delimeter) def visit_sequence(self, seq, **kw): return "nextval(%s)" % self.preparer.format_sequence(seq) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index d221574888..8d2c763468 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1868,6 +1868,9 @@ class PGCompiler(compiler.SQLCompiler): value = value.replace("\\", "\\\\") return value + def visit_aggregate_strings_func(self, fn, **kw): + return "string_agg%s" % self.function_argspec(fn) + def visit_sequence(self, seq, **kw): return "nextval('%s')" % self.preparer.format_sequence(seq) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 69bdc1e562..469b174586 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1318,7 +1318,7 @@ class SQLiteCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "length%s" % self.function_argspec(fn) - def visit_string_agg_func(self, fn, **kw): + def visit_aggregate_strings_func(self, fn, **kw): return "group_concat%s" % self.function_argspec(fn) def visit_cast(self, cast, **kwargs): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 1ca0a0468c..f02476612f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1045,7 +1045,7 @@ class _FunctionGenerator: ... @property - def string_agg(self) -> Type[string_agg[Any]]: + def aggregate_strings(self) -> Type[aggregate_strings[Any]]: ... @property @@ -1801,15 +1801,15 @@ class grouping_sets(GenericFunction[_T]): inherit_cache = True -class string_agg(GenericFunction[_T]): +class aggregate_strings(GenericFunction[_T]): r"""Implement the ``STRING_AGG`` aggregation function This function will concatenate non-null values into a string and separate the values by a delimeter. - e.g. Example usage with delimeter '.' as the last argument + e.g. Example usage with delimeter '.' - stmt = select(func.string_agg(table.c.str_col, ".")) + stmt = select(func.aggregate_strings(table.c.str_col, ".")) The return type of this function is :class:`.String`. diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 7dd70ad3ac..b2be90f60c 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -378,8 +378,8 @@ class CoreFixtures: lambda: (tuple_(1, 2), tuple_(3, 4)), lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])), lambda: ( - func.string_agg(table_a.c.b), - func.string_agg(table_b_like_a.c.b), + func.aggregate_strings(table_a.c.b, ","), + func.aggregate_strings(table_b_like_a.c.b, ","), ), lambda: ( func.percentile_cont(0.5).within_group(table_a.c.a), diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 6aa1faf494..0f10daf837 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -216,67 +216,38 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ]: self.assert_compile(func.random(), ret, dialect=dialect) - def test_return_type_string_agg(self): + def test_return_type_aggregate_strings(self): t = table("t", column("value", String)) - expr = func.string_agg(t.c.value, ",") + expr = func.aggregate_strings(t.c.value, ",") is_(expr.type._type_affinity, String) - def test_generic_string_agg(self): + def test_generic_aggregate_strings(self): t = table("t", column("value", String)) - stmt = select(func.string_agg(t.c.value)) + stmt = select(func.aggregate_strings(t.c.value, ",")) self.assert_compile( stmt, - "SELECT group_concat(t.value) AS string_agg_1 FROM t", - dialect=sqlite.dialect(), - ) - self.assert_compile( - stmt, - "SELECT string_agg(t.value) AS string_agg_1 FROM t", - dialect=postgresql.dialect(), - ) - self.assert_compile( - stmt, - "SELECT string_agg(t.value) AS string_agg_1 FROM t", - dialect=mssql.dialect(), - ) - self.assert_compile( - stmt, - "SELECT group_concat(t.value) AS string_agg_1 FROM t", - dialect=mysql.dialect(), - ) - - def test_generic_string_agg_with_delimeter(self): - t = table("t", column("value", String)) - stmt = select(func.string_agg(t.c.value, ",")) - - self.assert_compile( - stmt, - "SELECT group_concat(t.value, ?) AS string_agg_1 FROM t", + "SELECT group_concat(t.value, ?) AS aggregate_strings_1 FROM t", dialect=sqlite.dialect(), checkpositional=(",",), ) self.assert_compile( stmt, - "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t", + "SELECT string_agg(t.value, %(aggregate_strings_2)s) AS " + "aggregate_strings_1 FROM t", dialect=postgresql.dialect(), - literal_binds=True, - render_postcompile=True, ) self.assert_compile( stmt, - "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t", + "SELECT string_agg(t.value, :aggregate_strings_2) AS " + "aggregate_strings_1 FROM t", dialect=mssql.dialect(), - literal_binds=True, - render_postcompile=True, ) self.assert_compile( stmt, - "SELECT group_concat(t.value SEPARATOR ',') " - "AS string_agg_1 FROM t", + "SELECT group_concat(t.value SEPARATOR :aggregate_strings_1) " + "AS aggregate_strings_1 FROM t", dialect=mysql.dialect(), - literal_binds=True, - render_postcompile=True, ) def test_cube_operators(self): @@ -1222,27 +1193,7 @@ class ExecuteTest(fixtures.TestBase): ) @testing.provide_metadata - def test_string_agg_execute(self, connection): - meta = self.metadata - values_t = Table("values", meta, Column("value", String)) - meta.create_all(connection) - connection.execute( - values_t.insert(), - [ - {"value": "a"}, - {"value": "b"}, - {"value": "c"}, - {"value": None}, # ignored - ], - ) - rs = connection.execute(select(func.string_agg(values_t.c.value))) - row = rs.scalar() - - assert row == "a,b,c" - rs.close() - - @testing.provide_metadata - def test_string_agg_execute_with_delimeter(self, connection): + def test_aggregate_strings_execute(self, connection): meta = self.metadata values_t = Table("values", meta, Column("value", String)) meta.create_all(connection) @@ -1256,7 +1207,7 @@ class ExecuteTest(fixtures.TestBase): ], ) rs = connection.execute( - select(func.string_agg(values_t.c.value, " and ")) + select(func.aggregate_strings(values_t.c.value, " and ")) ) row = rs.scalar()