]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for SQL string aggregation function :class:`.string_agg`
authorJoshua Morris <joshuajohnmorris@gmail.com>
Sun, 4 Jun 2023 07:07:20 +0000 (17:07 +1000)
committerJoshua Morris <joshua.morris@deswik.com>
Mon, 10 Jul 2023 22:48:07 +0000 (08:48 +1000)
Returns a :class:`.String` with support for PostgreSQL, SQLite, and MSSQL. fixes #9873

lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/sql/functions.py
test/sql/test_compare.py
test/sql/test_functions.py

index 835ec27731bacc7e898893ef72dd91cad5b566af..69bdc1e56202de29f3c21fc6bdff3e0eaf7ec6ed 100644 (file)
@@ -1318,6 +1318,9 @@ class SQLiteCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "length%s" % self.function_argspec(fn)
 
+    def visit_string_agg_func(self, fn, **kw):
+        return "group_concat%s" % self.function_argspec(fn)
+
     def visit_cast(self, cast, **kwargs):
         if self.dialect.supports_cast:
             return super().visit_cast(cast, **kwargs)
index 30e280c61f6d35adce0424c947e750e33d3c4c68..5978ec5ade18d95e31494f0a6fb17f970bf2f1dc 100644 (file)
@@ -1044,6 +1044,10 @@ class _FunctionGenerator:
         def session_user(self) -> Type[session_user]:
             ...
 
+        @property
+        def string_agg(self) -> Type[string_agg[Any]]:
+            ...
+
         @property
         def sum(self) -> Type[sum[Any]]:  # noqa: A001
             ...
@@ -1795,3 +1799,20 @@ class grouping_sets(GenericFunction[_T]):
     """
     _has_args = True
     inherit_cache = True
+
+
+class string_agg(GenericFunction[_T]):
+    r"""Implement the ``STRING_AGG`` aggregation function
+
+    This function will concatenate non-null values into a string and
+    separate the values by a delimeter.
+
+    The return type of this function is :class:`.String`.
+
+    """
+    type = sqltypes.String()
+    _has_args = True
+    inherit_cache = True
+
+    def __init__(self, sep=",", *args, **kwargs):
+        super().__init__(sep, *args, **kwargs)
index 353537ad3ea60bf46b6d0125d43304c0c0d29d10..7dd70ad3ac2e38a26ac39f53c1a8e2006ac38e8a 100644 (file)
@@ -377,6 +377,10 @@ class CoreFixtures:
         lambda: (table_a.c.a, table_b.c.a),
         lambda: (tuple_(1, 2), tuple_(3, 4)),
         lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
+        lambda: (
+            func.string_agg(table_a.c.b),
+            func.string_agg(table_b_like_a.c.b),
+        ),
         lambda: (
             func.percentile_cont(0.5).within_group(table_a.c.a),
             func.percentile_cont(0.5).within_group(table_a.c.b),
index e961a4e4657273a247dfa5b79b4632e083502a03..51ed638940c5bbfa0f2609ef9d617b135522b64a 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import Text
 from sqlalchemy import true
+from sqlalchemy.dialects import mssql
 from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects import oracle
 from sqlalchemy.dialects import postgresql
@@ -215,6 +216,33 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         ]:
             self.assert_compile(func.random(), ret, dialect=dialect)
 
+    def test_generic_string_agg(self):
+        t = table("t", column("value", String))
+        expr = func.string_agg(t.c.value, ",")
+        is_(expr.type._type_affinity, String)
+        stmt = select(expr)
+
+        self.assert_compile(
+            stmt,
+            "SELECT group_concat(t.value, ?) AS string_agg_1 FROM t",
+            dialect=sqlite.dialect(),
+            checkpositional=(",",),
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t",
+            dialect=postgresql.dialect(),
+            literal_binds=True,
+            render_postcompile=True,
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT string_agg(t.value, ',') AS string_agg_1 FROM t",
+            dialect=mssql.dialect(),
+            literal_binds=True,
+            render_postcompile=True,
+        )
+
     def test_cube_operators(self):
         t = table(
             "t",
@@ -1157,6 +1185,46 @@ class ExecuteTest(fixtures.TestBase):
             (9, "foo"),
         )
 
+    @testing.provide_metadata
+    def test_string_agg_execute(self, connection):
+        meta = self.metadata
+        values_t = Table("values", meta, Column("value", String))
+        meta.create_all(connection)
+        connection.execute(
+            values_t.insert(),
+            [
+                {"value": "a"},
+                {"value": "b"},
+                {"value": "c"},
+            ],
+        )
+        rs = connection.execute(select(func.string_agg(values_t.c.value)))
+        row = rs.scalar()
+
+        assert row == "a,b,c"
+        rs.close()
+
+    @testing.provide_metadata
+    def test_string_agg_custom_sep(self, connection):
+        meta = self.metadata
+        values_t = Table("values", meta, Column("value", String))
+        meta.create_all(connection)
+        connection.execute(
+            values_t.insert(),
+            [
+                {"value": "a"},
+                {"value": "b"},
+                {"value": "c"},
+            ],
+        )
+        rs = connection.execute(
+            select(func.string_agg(values_t.c.value, " and "))
+        )
+        row = rs.scalar()
+
+        assert row == "a and b and c"
+        rs.close()
+
     @testing.fails_on_everything_except("postgresql")
     def test_as_from(self, connection):
         # TODO: shouldn't this work on oracle too ?