--- /dev/null
+.. change::
+ :tags: usecase, sql
+ :tickets: 9873
+
+ Added new generic SQL function :class:`_functions.aggregate_strings`, which
+ accepts a SQL expression and a delimiter, concatenating strings on multiple
+ rows into a single aggregate value. The function is compiled on a
+ per-backend basis, into functions such as ``group_concat()``,
+ ``string_agg()``, or ``LISTAGG()``.
+ Pull request courtesy Joshua Morris.
\ No newline at end of file
describes those functions where SQLAlchemy already knows what argument and
return types are in use.
+.. autoclass:: aggregate_strings
+ :no-members:
+
.. autoclass:: array_agg
:no-members:
def visit_char_length_func(self, fn, **kw):
    """Render the character-length function as ``LEN`` plus its arguments."""
    argspec = self.function_argspec(fn, **kw)
    return "LEN%s" % argspec
+ def visit_aggregate_strings_func(self, fn, **kw):
+     """Render ``aggregate_strings`` as ``string_agg(expr, delimiter)``.
+
+     The aggregated expression is compiled with the caller's keyword
+     arguments; the delimiter is compiled with ``literal_execute`` set
+     to True in addition to those arguments.
+     """
+     expr_sql = fn.clauses.clauses[0]._compiler_dispatch(self, **kw)
+     delimiter_kw = {**kw, "literal_execute": True}
+     delimiter_sql = fn.clauses.clauses[1]._compiler_dispatch(
+         self, **delimiter_kw
+     )
+     return "string_agg({}, {})".format(expr_sql, delimiter_sql)
+
def visit_concat_op_expression_clauselist(
self, clauselist, operator, **kw
):
)
return f"{clause} WITH ROLLUP"
+ def visit_aggregate_strings_func(self, fn, **kw):
+     """Render ``aggregate_strings`` as ``group_concat(expr SEPARATOR sep)``.
+
+     ``fn.clauses`` must yield exactly two elements (the expression and
+     the separator); the unpacking below raises ``ValueError`` otherwise.
+     """
+     compiled = [
+         clause._compiler_dispatch(self, **kw) for clause in fn.clauses
+     ]
+     expr_sql, separator_sql = compiled
+     return "group_concat({} SEPARATOR {})".format(expr_sql, separator_sql)
+
def visit_sequence(self, seq, **kw):
    """Render a ``nextval(<sequence>)`` call for ``seq``."""
    sequence_name = self.preparer.format_sequence(seq)
    return "nextval(%s)" % sequence_name
self.render_literal_value(flags, sqltypes.STRINGTYPE),
)
+ def visit_aggregate_strings_func(self, fn, **kw):
+     """Render ``aggregate_strings`` as ``LISTAGG`` plus its arguments."""
+     argspec = self.function_argspec(fn, **kw)
+     return "LISTAGG%s" % argspec
+
class OracleDDLCompiler(compiler.DDLCompiler):
def define_constraint_cascades(self, constraint):
value = value.replace("\\", "\\\\")
return value
+ def visit_aggregate_strings_func(self, fn, **kw):
+     """Render ``aggregate_strings`` as ``string_agg`` plus its arguments.
+
+     Compiler keyword arguments are forwarded to ``function_argspec()`` so
+     that options supplied by the caller also apply when the function's
+     arguments are compiled.
+     """
+     return "string_agg%s" % self.function_argspec(fn, **kw)
+
def visit_sequence(self, seq, **kw):
    """Render a ``nextval('<sequence>')`` call for ``seq``."""
    sequence_name = self.preparer.format_sequence(seq)
    return "nextval('%s')" % sequence_name
def visit_char_length_func(self, fn, **kw):
    """Render the character-length function as ``length`` plus its arguments.

    Compiler keyword arguments are forwarded to ``function_argspec()`` so
    that options supplied by the caller also apply when the function's
    arguments are compiled.
    """
    return "length%s" % self.function_argspec(fn, **kw)
+ def visit_aggregate_strings_func(self, fn, **kw):
+     """Render ``aggregate_strings`` as ``group_concat`` plus its arguments.
+
+     Compiler keyword arguments are forwarded to ``function_argspec()`` so
+     that options supplied by the caller also apply when the function's
+     arguments are compiled.
+     """
+     return "group_concat%s" % self.function_argspec(fn, **kw)
+
def visit_cast(self, cast, **kwargs):
if self.dialect.supports_cast:
return super().visit_cast(cast, **kwargs)
# code within this block is **programmatically,
# statically generated** by tools/generate_sql_functions.py
+ @property
+ def aggregate_strings(self) -> Type[aggregate_strings]:
+ ...
+
@property
def ansifunction(self) -> Type[AnsiFunction[Any]]:
...
"""
_has_args = True
inherit_cache = True
+
+
+class aggregate_strings(GenericFunction[str]):
+    """Implement a generic string aggregation function.
+
+    This function will concatenate non-null values into a string and
+    separate the values by a delimiter.
+
+    This function is compiled on a per-backend basis, into functions
+    such as ``group_concat()``, ``string_agg()``, or ``LISTAGG()``.
+
+    Example usage with delimiter ``"."``::
+
+        stmt = select(func.aggregate_strings(table.c.str_col, "."))
+
+    The return type of this function is :class:`.String`.
+
+    .. versionadded:: 2.0.21
+
+    """
+
+    # return type of the function, as stated in the class docstring
+    type = sqltypes.String()
+    _has_args = True
+    inherit_cache = True
+
+    def __init__(self, clause, separator):
+        """Construct the function from exactly two positional arguments.
+
+        :param clause: the SQL expression whose values are concatenated.
+        :param separator: the delimiter placed between the values.
+        """
+        super().__init__(clause, separator)
lambda: (table_a.c.a, table_b.c.a),
lambda: (tuple_(1, 2), tuple_(3, 4)),
lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
+ lambda: (
+ func.aggregate_strings(table_a.c.b, ","),
+ func.aggregate_strings(table_b_like_a.c.b, ","),
+ ),
lambda: (
func.percentile_cont(0.5).within_group(table_a.c.a),
func.percentile_cont(0.5).within_group(table_a.c.b),
from sqlalchemy import testing
from sqlalchemy import Text
from sqlalchemy import true
+from sqlalchemy import Unicode
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects import postgresql
]:
self.assert_compile(func.random(), ret, dialect=dialect)
+ def test_return_type_aggregate_strings(self):
+     """``aggregate_strings`` should report a ``String`` type affinity."""
+     value_table = table("t", column("value", String))
+     agg_type = func.aggregate_strings(value_table.c.value, ",").type
+     is_(agg_type._type_affinity, String)
+
+ @testing.combinations(
+     (
+         "SELECT group_concat(t.value, ?) AS aggregate_strings_1 FROM t",
+         "sqlite",
+     ),
+     (
+         "SELECT string_agg(t.value, %(aggregate_strings_2)s) AS "
+         "aggregate_strings_1 FROM t",
+         "postgresql",
+     ),
+     (
+         "SELECT string_agg(t.value, "
+         "__[POSTCOMPILE_aggregate_strings_2]) AS "
+         "aggregate_strings_1 FROM t",
+         "mssql",
+     ),
+     (
+         "SELECT group_concat(t.value SEPARATOR %s) "
+         "AS aggregate_strings_1 FROM t",
+         "mysql",
+     ),
+     (
+         "SELECT LISTAGG(t.value, :aggregate_strings_2) AS"
+         " aggregate_strings_1 FROM t",
+         "oracle",
+     ),
+ )
+ def test_aggregate_strings(self, expected_sql, dialect):
+     """Check the SQL rendered for ``aggregate_strings`` on each dialect.
+
+     Each combination pairs the expected SQL string with the name of the
+     dialect used to compile the statement.
+     """
+     t = table("t", column("value", String))
+     stmt = select(func.aggregate_strings(t.c.value, ","))
+
+     self.assert_compile(stmt, expected_sql, dialect=dialect)
+
def test_cube_operators(self):
t = table(
"t",
(9, "foo"),
)
+ @testing.variation("unicode_value", [True, False])
+ @testing.variation("unicode_separator", [True, False])
+ def test_aggregate_strings_execute(
+     self, connection, metadata, unicode_value, unicode_separator
+ ):
+     """Execute ``aggregate_strings`` and compare with ``str.join``.
+
+     Runs for each combination of plain/unicode source column and
+     plain/unicode separator.
+     """
+     plain_parts = ["a", "b", "c"]
+     unicode_parts = ["測試", "téble2", "🐍 su"]
+
+     strings_table = Table(
+         "values",
+         metadata,
+         Column("value", String(42)),
+         Column("unicode_value", Unicode(42)),
+     )
+     metadata.create_all(connection)
+
+     rows = [
+         {"value": "a", "unicode_value": "測試"},
+         {"value": "b", "unicode_value": "téble2"},
+         # all-NULL row; the expected string does not include it
+         {"value": None, "unicode_value": None},
+         {"value": "c", "unicode_value": "🐍 su"},
+     ]
+     connection.execute(strings_table.insert(), rows)
+
+     separator = " 🐍試 " if unicode_separator else " and "
+
+     if unicode_value:
+         source_col, parts = strings_table.c.unicode_value, unicode_parts
+     else:
+         source_col, parts = strings_table.c.value, plain_parts
+     expected = separator.join(parts)
+
+     # joining on a unicode separator requires a unicode source string,
+     # hence the cast(); SQL Server will raise otherwise
+     if unicode_separator:
+         source_col = cast(source_col, Unicode(42))
+
+     result = connection.execute(
+         select(func.aggregate_strings(source_col, separator))
+     )
+
+     eq_(result.scalar_one(), expected)
+
@testing.fails_on_everything_except("postgresql")
def test_as_from(self, connection):
# TODO: shouldn't this work on oracle too ?
# code within this block is **programmatically,
# statically generated** by tools/generate_sql_functions.py
-stmt1 = select(func.char_length(column("x")))
+stmt1 = select(func.aggregate_strings(column("x"), column("x")))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt1)
-stmt2 = select(func.concat())
+stmt2 = select(func.char_length(column("x")))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt2)
-stmt3 = select(func.count(column("x")))
+stmt3 = select(func.concat())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt3)
-stmt4 = select(func.cume_dist())
+stmt4 = select(func.count(column("x")))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt4)
-stmt5 = select(func.current_date())
+stmt5 = select(func.cume_dist())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
reveal_type(stmt5)
-stmt6 = select(func.current_time())
+stmt6 = select(func.current_date())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
reveal_type(stmt6)
-stmt7 = select(func.current_timestamp())
+stmt7 = select(func.current_time())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
reveal_type(stmt7)
-stmt8 = select(func.current_user())
+stmt8 = select(func.current_timestamp())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt8)
-stmt9 = select(func.dense_rank())
+stmt9 = select(func.current_user())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt9)
-stmt10 = select(func.localtime())
+stmt10 = select(func.dense_rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt10)
-stmt11 = select(func.localtimestamp())
+stmt11 = select(func.localtime())
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt11)
-stmt12 = select(func.next_value(column("x")))
+stmt12 = select(func.localtimestamp())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt12)
-stmt13 = select(func.now())
+stmt13 = select(func.next_value(column("x")))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt13)
-stmt14 = select(func.percent_rank())
+stmt14 = select(func.now())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt14)
-stmt15 = select(func.rank())
+stmt15 = select(func.percent_rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
reveal_type(stmt15)
-stmt16 = select(func.session_user())
+stmt16 = select(func.rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt16)
-stmt17 = select(func.sysdate())
+stmt17 = select(func.session_user())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt17)
-stmt18 = select(func.user())
+stmt18 = select(func.sysdate())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt18)
+
+stmt19 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt19)
+
# END GENERATED FUNCTION TYPING TESTS