]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for SQL string aggregation function aggregate_strings.
authorJoshua Morris <joshuajohnmorris@gmail.com>
Tue, 11 Jul 2023 06:21:37 +0000 (02:21 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Sep 2023 19:49:29 +0000 (15:49 -0400)
Add support for SQL string aggregation function :class:`.aggregate_strings`.
Pull request curtesy Joshua Morris.

Fixes #9873
Closes: #9892
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9892
Pull-request-sha: 0f4f83bade675f2ff734579fc411d0b354e5a4e6

Change-Id: I6e4afc83664a142e3b1e245978b08a200b6d03d9

doc/build/changelog/unreleased_20/9873.rst [new file with mode: 0644]
doc/build/core/functions.rst
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/sql/functions.py
test/sql/test_compare.py
test/sql/test_functions.py
test/typing/plain_files/sql/functions.py

diff --git a/doc/build/changelog/unreleased_20/9873.rst b/doc/build/changelog/unreleased_20/9873.rst
new file mode 100644 (file)
index 0000000..f1071fd
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: usecase, sql
+    :tickets: 9873
+
+    Added new generic SQL function :class:`_functions.aggregate_strings`, which
+    accepts a SQL expression and a decimeter, concatenating strings on multiple
+    rows into a single aggregate value. The function is compiled on a
+    per-backend basis, into functions such as ``group_concat(),``
+    ``string_agg()``, or ``LISTAGG()``.
+    Pull request courtesy Joshua Morris.
\ No newline at end of file
index 6fcee6edaa26e87d15d991e9d2131fe68bdec7f8..9771ffeedd9be20c76f037d32f7aadeb5a7f5d06 100644 (file)
@@ -52,6 +52,9 @@ unknown to SQLAlchemy, built-in or user defined. The section here only
 describes those functions where SQLAlchemy already knows what argument and
 return types are in use.
 
+.. autoclass:: aggregate_strings
+    :no-members:
+
 .. autoclass:: array_agg
     :no-members:
 
index 1f80aaef29d5987b365b565e2663b494b1e4b4cc..6d46687e43b797bc012673480b9755fb20f0d3a3 100644 (file)
@@ -2057,6 +2057,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "LEN%s" % self.function_argspec(fn, **kw)
 
+    def visit_aggregate_strings_func(self, fn, **kw):
+        expr = fn.clauses.clauses[0]._compiler_dispatch(self, **kw)
+        kw["literal_execute"] = True
+        delimeter = fn.clauses.clauses[1]._compiler_dispatch(self, **kw)
+        return f"string_agg({expr}, {delimeter})"
+
     def visit_concat_op_expression_clauselist(
         self, clauselist, operator, **kw
     ):
index 8b0e1295d1ea4c39ca36da6636ed8e38400dc75d..d3f2a3ff87077d3efd79473b2eeb477ac51b8542 100644 (file)
@@ -1208,6 +1208,12 @@ class MySQLCompiler(compiler.SQLCompiler):
         )
         return f"{clause} WITH ROLLUP"
 
+    def visit_aggregate_strings_func(self, fn, **kw):
+        expr, delimeter = (
+            elem._compiler_dispatch(self, **kw) for elem in fn.clauses
+        )
+        return f"group_concat({expr} SEPARATOR {delimeter})"
+
     def visit_sequence(self, seq, **kw):
         return "nextval(%s)" % self.preparer.format_sequence(seq)
 
index 4a3ac3ac0781d91a1b8ddb3082cf90cd69abb2dc..d993ef2692790f08e686efe92181913e7b477f99 100644 (file)
@@ -1241,6 +1241,9 @@ class OracleCompiler(compiler.SQLCompiler):
                 self.render_literal_value(flags, sqltypes.STRINGTYPE),
             )
 
+    def visit_aggregate_strings_func(self, fn, **kw):
+        return "LISTAGG%s" % self.function_argspec(fn, **kw)
+
 
 class OracleDDLCompiler(compiler.DDLCompiler):
     def define_constraint_cascades(self, constraint):
index 08fd5d3a04d92913a264ac1a1d3565ef6921a3ce..5fe619d6f57c4bd489a55b23f5c593ade14bbbd9 100644 (file)
@@ -1868,6 +1868,9 @@ class PGCompiler(compiler.SQLCompiler):
             value = value.replace("\\", "\\\\")
         return value
 
+    def visit_aggregate_strings_func(self, fn, **kw):
+        return "string_agg%s" % self.function_argspec(fn)
+
     def visit_sequence(self, seq, **kw):
         return "nextval('%s')" % self.preparer.format_sequence(seq)
 
index bd6c1dff87b384b91192ac069fb9a3586c2dc962..a8e53684072040f9154cca169fab6a906ab53a5d 100644 (file)
@@ -1318,6 +1318,9 @@ class SQLiteCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "length%s" % self.function_argspec(fn)
 
+    def visit_aggregate_strings_func(self, fn, **kw):
+        return "group_concat%s" % self.function_argspec(fn)
+
     def visit_cast(self, cast, **kwargs):
         if self.dialect.supports_cast:
             return super().visit_cast(cast, **kwargs)
index 30e280c61f6d35adce0424c947e750e33d3c4c68..fc23e9d2156d9d9d8868abc12072e43e9acba620 100644 (file)
@@ -916,6 +916,10 @@ class _FunctionGenerator:
         # code within this block is **programmatically,
         # statically generated** by tools/generate_sql_functions.py
 
+        @property
+        def aggregate_strings(self) -> Type[aggregate_strings]:
+            ...
+
         @property
         def ansifunction(self) -> Type[AnsiFunction[Any]]:
             ...
@@ -1795,3 +1799,30 @@ class grouping_sets(GenericFunction[_T]):
     """
     _has_args = True
     inherit_cache = True
+
+
+class aggregate_strings(GenericFunction[str]):
+    """Implement a generic string aggregation function.
+
+    This function will concatenate non-null values into a string and
+    separate the values by a delimiter.
+
+    This function is compiled on a per-backend basis, into functions
+    such as ``group_concat()``, ``string_agg()``, or ``LISTAGG()``.
+
+    e.g. Example usage with delimiter '.'::
+
+        stmt = select(func.aggregate_strings(table.c.str_col, "."))
+
+    The return type of this function is :class:`.String`.
+
+    .. versionadded: 2.0.21
+
+    """
+
+    type = sqltypes.String()
+    _has_args = True
+    inherit_cache = True
+
+    def __init__(self, clause, separator):
+        super().__init__(clause, separator)
index 353537ad3ea60bf46b6d0125d43304c0c0d29d10..b2be90f60cdf621a2302001a4a899387f75dc43a 100644 (file)
@@ -377,6 +377,10 @@ class CoreFixtures:
         lambda: (table_a.c.a, table_b.c.a),
         lambda: (tuple_(1, 2), tuple_(3, 4)),
         lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
+        lambda: (
+            func.aggregate_strings(table_a.c.b, ","),
+            func.aggregate_strings(table_b_like_a.c.b, ","),
+        ),
         lambda: (
             func.percentile_cont(0.5).within_group(table_a.c.a),
             func.percentile_cont(0.5).within_group(table_a.c.b),
index e961a4e4657273a247dfa5b79b4632e083502a03..c47601b7616622c9c66a78eed982fe93063ef7eb 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import Text
 from sqlalchemy import true
+from sqlalchemy import Unicode
 from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects import oracle
 from sqlalchemy.dialects import postgresql
@@ -215,6 +216,44 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         ]:
             self.assert_compile(func.random(), ret, dialect=dialect)
 
+    def test_return_type_aggregate_strings(self):
+        t = table("t", column("value", String))
+        expr = func.aggregate_strings(t.c.value, ",")
+        is_(expr.type._type_affinity, String)
+
+    @testing.combinations(
+        (
+            "SELECT group_concat(t.value, ?) AS aggregate_strings_1 FROM t",
+            "sqlite",
+        ),
+        (
+            "SELECT string_agg(t.value, %(aggregate_strings_2)s) AS "
+            "aggregate_strings_1 FROM t",
+            "postgresql",
+        ),
+        (
+            "SELECT string_agg(t.value, "
+            "__[POSTCOMPILE_aggregate_strings_2]) AS "
+            "aggregate_strings_1 FROM t",
+            "mssql",
+        ),
+        (
+            "SELECT group_concat(t.value SEPARATOR %s) "
+            "AS aggregate_strings_1 FROM t",
+            "mysql",
+        ),
+        (
+            "SELECT LISTAGG(t.value, :aggregate_strings_2) AS"
+            " aggregate_strings_1 FROM t",
+            "oracle",
+        ),
+    )
+    def test_aggregate_strings(self, expected_sql, dialect):
+        t = table("t", column("value", String))
+        stmt = select(func.aggregate_strings(t.c.value, ","))
+
+        self.assert_compile(stmt, expected_sql, dialect=dialect)
+
     def test_cube_operators(self):
         t = table(
             "t",
@@ -1157,6 +1196,51 @@ class ExecuteTest(fixtures.TestBase):
             (9, "foo"),
         )
 
+    @testing.variation("unicode_value", [True, False])
+    @testing.variation("unicode_separator", [True, False])
+    def test_aggregate_strings_execute(
+        self, connection, metadata, unicode_value, unicode_separator
+    ):
+        values_t = Table(
+            "values",
+            metadata,
+            Column("value", String(42)),
+            Column("unicode_value", Unicode(42)),
+        )
+        metadata.create_all(connection)
+        connection.execute(
+            values_t.insert(),
+            [
+                {"value": "a", "unicode_value": "測試"},
+                {"value": "b", "unicode_value": "téble2"},
+                {"value": None, "unicode_value": None},  # ignored
+                {"value": "c", "unicode_value": "🐍 su"},
+            ],
+        )
+
+        if unicode_separator:
+            separator = " 🐍試 "
+        else:
+            separator = " and "
+
+        if unicode_value:
+            col = values_t.c.unicode_value
+            expected = separator.join(["測試", "téble2", "🐍 su"])
+        else:
+            col = values_t.c.value
+            expected = separator.join(["a", "b", "c"])
+
+            # to join on a unicode separator, source string has to be unicode,
+            # so cast().  SQL Server will raise otherwise
+            if unicode_separator:
+                col = cast(col, Unicode(42))
+
+        value = connection.execute(
+            select(func.aggregate_strings(col, separator))
+        ).scalar_one()
+
+        eq_(value, expected)
+
     @testing.fails_on_everything_except("postgresql")
     def test_as_from(self, connection):
         # TODO: shouldn't this work on oracle too ?
index 09c2acf057f83062933e095db18d1ae61e0c3c0e..e66e554cff78cb2a3833ffddbdc59fa765983009 100644 (file)
@@ -9,111 +9,117 @@ from sqlalchemy import select
 # code within this block is **programmatically,
 # statically generated** by tools/generate_sql_functions.py
 
-stmt1 = select(func.char_length(column("x")))
+stmt1 = select(func.aggregate_strings(column("x"), column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt1)
 
 
-stmt2 = select(func.concat())
+stmt2 = select(func.char_length(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt2)
 
 
-stmt3 = select(func.count(column("x")))
+stmt3 = select(func.concat())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt3)
 
 
-stmt4 = select(func.cume_dist())
+stmt4 = select(func.count(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt4)
 
 
-stmt5 = select(func.current_date())
+stmt5 = select(func.cume_dist())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt5)
 
 
-stmt6 = select(func.current_time())
+stmt6 = select(func.current_date())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
 reveal_type(stmt6)
 
 
-stmt7 = select(func.current_timestamp())
+stmt7 = select(func.current_time())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
 reveal_type(stmt7)
 
 
-stmt8 = select(func.current_user())
+stmt8 = select(func.current_timestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt8)
 
 
-stmt9 = select(func.dense_rank())
+stmt9 = select(func.current_user())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt9)
 
 
-stmt10 = select(func.localtime())
+stmt10 = select(func.dense_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt10)
 
 
-stmt11 = select(func.localtimestamp())
+stmt11 = select(func.localtime())
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt11)
 
 
-stmt12 = select(func.next_value(column("x")))
+stmt12 = select(func.localtimestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt12)
 
 
-stmt13 = select(func.now())
+stmt13 = select(func.next_value(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt13)
 
 
-stmt14 = select(func.percent_rank())
+stmt14 = select(func.now())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt14)
 
 
-stmt15 = select(func.rank())
+stmt15 = select(func.percent_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt15)
 
 
-stmt16 = select(func.session_user())
+stmt16 = select(func.rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt16)
 
 
-stmt17 = select(func.sysdate())
+stmt17 = select(func.session_user())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt17)
 
 
-stmt18 = select(func.user())
+stmt18 = select(func.sysdate())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt18)
 
+
+stmt19 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt19)
+
 # END GENERATED FUNCTION TYPING TESTS