From 02fe382d6bfc5e8ccab6e2024a5241379a02b7e0 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sat, 17 Sep 2022 13:12:35 +0200 Subject: [PATCH] Improve array_agg and Array processing The :class:`_functions.array_agg` will now set the array dimensions to 1. Improved :class:`_types.ARRAY` processing to accept ``None`` values as value of a multi-array. Fixes: #7083 Change-Id: Iafec4f77fde9719ccc7c8535bf6235dbfbc62102 --- doc/build/changelog/unreleased_20/7083.rst | 7 ++++ lib/sqlalchemy/sql/functions.py | 4 +- lib/sqlalchemy/sql/sqltypes.py | 2 + lib/sqlalchemy/testing/config.py | 11 ++++++ test/dialect/postgresql/test_types.py | 43 ++++++++++++++++++++++ test/sql/test_functions.py | 5 ++- 6 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/7083.rst diff --git a/doc/build/changelog/unreleased_20/7083.rst b/doc/build/changelog/unreleased_20/7083.rst new file mode 100644 index 0000000000..6b3836a59e --- /dev/null +++ b/doc/build/changelog/unreleased_20/7083.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, sql + :tickets: 7083 + + The :class:`_functions.array_agg` will now set the array dimensions to 1. + Improved :class:`_types.ARRAY` processing to accept ``None`` values as + value of a multi-array. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index a028e7fedb..c04f5fa1d2 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1407,7 +1407,9 @@ class array_agg(GenericFunction[_T]): if isinstance(type_from_args, sqltypes.ARRAY): kwargs["type_"] = type_from_args else: - kwargs["type_"] = default_array_type(type_from_args) + kwargs["type_"] = default_array_type( + type_from_args, dimensions=1 + ) kwargs["_parsed_args"] = fn_args super(array_agg, self).__init__(*fn_args, **kwargs) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index fd52ec6ea1..414ff03c3e 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -3065,6 +3065,8 @@ class ARRAY( dim - 1 if dim is not None else None, collection_callable, ) + if x is not None + else None for x in arr ) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index e418b48be7..1cb463977b 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -13,6 +13,8 @@ import collections import typing from typing import Any from typing import Iterable +from typing import Optional +from typing import overload from typing import Tuple from typing import Union @@ -37,6 +39,15 @@ else: _fixture_functions = None # installed by plugin_base +@overload +def combinations( + *comb: Union[Any, Tuple[Any, ...]], + argnames: Optional[str] = None, + id_: Optional[str] = None, +): + ... + + def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str): r"""Deliver multiple versions of a test based on positional combinations. diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 92fcfbcab3..b5c20bd8d4 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -36,6 +36,7 @@ from sqlalchemy import types from sqlalchemy import Unicode from sqlalchemy import util from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import DATEMULTIRANGE from sqlalchemy.dialects.postgresql import DATERANGE @@ -1901,6 +1902,48 @@ class ArrayRoundTripTest: stmt = select(func.array_agg(values_table.c.value)[2:4]) eq_(connection.execute(stmt).scalar(), [2, 3, 4]) + def test_array_agg_json(self, metadata, connection): + table = Table( + "values", metadata, Column("id", Integer), Column("bar", JSON) + ) + metadata.create_all(connection) + connection.execute( + table.insert(), + [{"id": 1, "bar": [{"buz": 1}]}, {"id": 2, "bar": None}], + ) + + arg = aggregate_order_by(table.c.bar, table.c.id) + stmt = select(sa.func.array_agg(arg)) + eq_(connection.execute(stmt).scalar(), [[{"buz": 1}], None]) + + arg = aggregate_order_by(table.c.bar, table.c.id.desc()) + stmt = select(sa.func.array_agg(arg)) + eq_(connection.execute(stmt).scalar(), [None, [{"buz": 1}]]) + + @testing.combinations(ARRAY, postgresql.ARRAY, argnames="cls") + def test_array_none(self, connection, metadata, cls): + table = Table( + "values", metadata, Column("id", Integer), Column("bar", cls(JSON)) + ) + metadata.create_all(connection) + connection.execute( + table.insert().values( + [ + { + "id": 1, + "bar": sa.text("""array['[{"x": 1}]'::json, null]"""), + }, + {"id": 2, "bar": None}, + ] + ) + ) + + stmt = select(table.c.bar).order_by(table.c.id) + eq_(connection.scalars(stmt).all(), [[[{"x": 1}], None], None]) + + stmt = select(table.c.bar).order_by(table.c.id.desc()) + eq_(connection.scalars(stmt).all(), [None, [[{"x": 1}], None]]) + def test_array_index_slice_exprs(self, connection): """test a variety of expressions that sometimes need parenthesizing""" diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 6c00660ffe..a48d42f507 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -948,11 +948,14 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase): expr = func.array_agg(column("data", Integer)) is_(expr.type._type_affinity, ARRAY) is_(expr.type.item_type._type_affinity, Integer) + is_(expr.type.dimensions, 1) def test_array_agg_array_datatype(self): - expr = func.array_agg(column("data", ARRAY(Integer))) + col = column("data", ARRAY(Integer)) + expr = func.array_agg(col) is_(expr.type._type_affinity, ARRAY) is_(expr.type.item_type._type_affinity, Integer) + eq_(expr.type.dimensions, col.type.dimensions) def test_array_agg_array_literal_implicit_type(self): from sqlalchemy.dialects.postgresql import array, ARRAY as PG_ARRAY -- 2.47.2