]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
rename to aggregate_strings 9892/head
authorJoshua Morris <joshuajohnmorris@gmail.com>
Mon, 10 Jul 2023 09:40:12 +0000 (19:40 +1000)
committerJoshua Morris <joshua.morris@deswik.com>
Mon, 10 Jul 2023 22:48:07 +0000 (08:48 +1000)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/sql/functions.py
test/sql/test_compare.py
test/sql/test_functions.py

index 1f80aaef29d5987b365b565e2663b494b1e4b4cc..a571e855fea4000ed45f1c014862e5e9bfe4a87d 100644 (file)
@@ -2057,6 +2057,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "LEN%s" % self.function_argspec(fn, **kw)
 
+    def visit_aggregate_strings_func(self, fn, **kw):
+        return "string_agg%s" % self.function_argspec(fn, **kw)
+
     def visit_concat_op_expression_clauselist(
         self, clauselist, operator, **kw
     ):
index ac16904279fabdd6ab7194b528661a6ee674b43f..45f45d513dcd975bef098a8f9d062062f2d386d4 100644 (file)
@@ -1208,15 +1208,9 @@ class MySQLCompiler(compiler.SQLCompiler):
         )
         return f"{clause} WITH ROLLUP"
 
-    def visit_string_agg_func(self, fn, **kw):
-        if len(fn.clauses) > 1:
-            clauses = [
-                elem._compiler_dispatch(self, **kw) for elem in fn.clauses
-            ]
-            clause = ", ".join(clauses[:-1])
-            return "group_concat(%s SEPARATOR %s)" % (clause, clauses[-1])
-        else:
-            return "group_concat%s" % self.function_argspec(fn)
+    def visit_aggregate_strings_func(self, fn, **kw):
+        expr, delimeter = fn.clauses
+        return "group_concat(%s SEPARATOR %s)" % (expr, delimeter)
 
     def visit_sequence(self, seq, **kw):
         return "nextval(%s)" % self.preparer.format_sequence(seq)
index d221574888ad528a1986f11181fbd08bb851f2ab..8d2c763468fae697fd7c62caf774cc4c75f0e752 100644 (file)
@@ -1868,6 +1868,9 @@ class PGCompiler(compiler.SQLCompiler):
             value = value.replace("\\", "\\\\")
         return value
 
+    def visit_aggregate_strings_func(self, fn, **kw):
+        return "string_agg%s" % self.function_argspec(fn)
+
     def visit_sequence(self, seq, **kw):
         return "nextval('%s')" % self.preparer.format_sequence(seq)
 
index 69bdc1e56202de29f3c21fc6bdff3e0eaf7ec6ed..469b174586e8a7488d66c58f05114d559776a484 100644 (file)
@@ -1318,7 +1318,7 @@ class SQLiteCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "length%s" % self.function_argspec(fn)
 
-    def visit_string_agg_func(self, fn, **kw):
+    def visit_aggregate_strings_func(self, fn, **kw):
         return "group_concat%s" % self.function_argspec(fn)
 
     def visit_cast(self, cast, **kwargs):
index 1ca0a0468c4f8a75974228f97e197b0e490c6f68..f02476612fb3c8004b055b1ad807201627790092 100644 (file)
@@ -1045,7 +1045,7 @@ class _FunctionGenerator:
             ...
 
         @property
-        def string_agg(self) -> Type[string_agg[Any]]:
+        def aggregate_strings(self) -> Type[aggregate_strings[Any]]:
             ...
 
         @property
@@ -1801,15 +1801,15 @@ class grouping_sets(GenericFunction[_T]):
     inherit_cache = True
 
 
-class string_agg(GenericFunction[_T]):
+class aggregate_strings(GenericFunction[_T]):
     r"""Implement the ``STRING_AGG`` aggregation function
 
     This function will concatenate non-null values into a string and
     separate the values by a delimeter.
 
-    e.g. Example usage with delimeter '.' as the last argument
+    e.g. Example usage with delimeter '.'
 
-    stmt = select(func.string_agg(table.c.str_col, "."))
+    stmt = select(func.aggregate_strings(table.c.str_col, "."))
 
     The return type of this function is :class:`.String`.
 
index 7dd70ad3ac2e38a26ac39f53c1a8e2006ac38e8a..b2be90f60cdf621a2302001a4a899387f75dc43a 100644 (file)
@@ -378,8 +378,8 @@ class CoreFixtures:
         lambda: (tuple_(1, 2), tuple_(3, 4)),
         lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
         lambda: (
-            func.string_agg(table_a.c.b),
-            func.string_agg(table_b_like_a.c.b),
+            func.aggregate_strings(table_a.c.b, ","),
+            func.aggregate_strings(table_b_like_a.c.b, ","),
         ),
         lambda: (
             func.percentile_cont(0.5).within_group(table_a.c.a),
index 6aa1faf49433861cc785a4e282fd275a4c0e3582..0f10daf837d905b69ed3333a339b98ce3a870037 100644 (file)
@@ -216,67 +216,38 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         ]:
             self.assert_compile(func.random(), ret, dialect=dialect)
 
-    def test_return_type_string_agg(self):
+    def test_return_type_aggregate_strings(self):
         t = table("t", column("value", String))
-        expr = func.string_agg(t.c.value, ",")
+        expr = func.aggregate_strings(t.c.value, ",")
         is_(expr.type._type_affinity, String)
 
-    def test_generic_string_agg(self):
+    def test_generic_aggregate_strings(self):
         t = table("t", column("value", String))
-        stmt = select(func.string_agg(t.c.value))
+        stmt = select(func.aggregate_strings(t.c.value, ","))
 
         self.assert_compile(
             stmt,
-            "SELECT group_concat(t.value) AS string_agg_1 FROM t",
-            dialect=sqlite.dialect(),
-        )
-        self.assert_compile(
-            stmt,
-            "SELECT string_agg(t.value) AS string_agg_1 FROM t",
-            dialect=postgresql.dialect(),
-        )
-        self.assert_compile(
-            stmt,
-            "SELECT string_agg(t.value) AS string_agg_1 FROM t",
-            dialect=mssql.dialect(),
-        )
-        self.assert_compile(
-            stmt,
-            "SELECT group_concat(t.value) AS string_agg_1 FROM t",
-            dialect=mysql.dialect(),
-        )
-
-    def test_generic_string_agg_with_delimeter(self):
-        t = table("t", column("value", String))
-        stmt = select(func.string_agg(t.c.value, ","))
-
-        self.assert_compile(
-            stmt,
-            "SELECT group_concat(t.value, ?) AS string_agg_1 FROM t",
+            "SELECT group_concat(t.value, ?) AS aggregate_strings_1 FROM t",
             dialect=sqlite.dialect(),
             checkpositional=(",",),
         )
         self.assert_compile(
             stmt,
-            "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t",
+            "SELECT string_agg(t.value, %(aggregate_strings_2)s) AS "
+            "aggregate_strings_1 FROM t",
             dialect=postgresql.dialect(),
-            literal_binds=True,
-            render_postcompile=True,
         )
         self.assert_compile(
             stmt,
-            "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t",
+            "SELECT string_agg(t.value, :aggregate_strings_2) AS "
+            "aggregate_strings_1 FROM t",
             dialect=mssql.dialect(),
-            literal_binds=True,
-            render_postcompile=True,
         )
         self.assert_compile(
             stmt,
-            "SELECT group_concat(t.value SEPARATOR ',') "
-            "AS string_agg_1 FROM t",
+            "SELECT group_concat(t.value SEPARATOR :aggregate_strings_1) "
+            "AS aggregate_strings_1 FROM t",
             dialect=mysql.dialect(),
-            literal_binds=True,
-            render_postcompile=True,
         )
 
     def test_cube_operators(self):
@@ -1222,27 +1193,7 @@ class ExecuteTest(fixtures.TestBase):
         )
 
     @testing.provide_metadata
-    def test_string_agg_execute(self, connection):
-        meta = self.metadata
-        values_t = Table("values", meta, Column("value", String))
-        meta.create_all(connection)
-        connection.execute(
-            values_t.insert(),
-            [
-                {"value": "a"},
-                {"value": "b"},
-                {"value": "c"},
-                {"value": None},  # ignored
-            ],
-        )
-        rs = connection.execute(select(func.string_agg(values_t.c.value)))
-        row = rs.scalar()
-
-        assert row == "a,b,c"
-        rs.close()
-
-    @testing.provide_metadata
-    def test_string_agg_execute_with_delimeter(self, connection):
+    def test_aggregate_strings_execute(self, connection):
         meta = self.metadata
         values_t = Table("values", meta, Column("value", String))
         meta.create_all(connection)
@@ -1256,7 +1207,7 @@ class ExecuteTest(fixtures.TestBase):
             ],
         )
         rs = connection.execute(
-            select(func.string_agg(values_t.c.value, " and "))
+            select(func.aggregate_strings(values_t.c.value, " and "))
         )
         row = rs.scalar()