]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve array_agg and Array processing
authorFederico Caselli <cfederico87@gmail.com>
Sat, 17 Sep 2022 11:12:35 +0000 (13:12 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 17 Sep 2022 11:12:35 +0000 (13:12 +0200)
The :class:`_functions.array_agg` will now set the array dimensions to 1.
Improved :class:`_types.ARRAY` processing to accept ``None`` values as
value of a multi-array.

Fixes: #7083
Change-Id: Iafec4f77fde9719ccc7c8535bf6235dbfbc62102

doc/build/changelog/unreleased_20/7083.rst [new file with mode: 0644]
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/config.py
test/dialect/postgresql/test_types.py
test/sql/test_functions.py

diff --git a/doc/build/changelog/unreleased_20/7083.rst b/doc/build/changelog/unreleased_20/7083.rst
new file mode 100644 (file)
index 0000000..6b3836a
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 7083
+
+    The :class:`_functions.array_agg` will now set the array dimensions to 1.
+    Improved :class:`_types.ARRAY` processing to accept ``None`` values as
+    value of a multi-array.
index a028e7fedbd3ca36e12a2b413c665a34900702ed..c04f5fa1d245eaa7c39dfe87ac33954385d37aba 100644 (file)
@@ -1407,7 +1407,9 @@ class array_agg(GenericFunction[_T]):
             if isinstance(type_from_args, sqltypes.ARRAY):
                 kwargs["type_"] = type_from_args
             else:
-                kwargs["type_"] = default_array_type(type_from_args)
+                kwargs["type_"] = default_array_type(
+                    type_from_args, dimensions=1
+                )
         kwargs["_parsed_args"] = fn_args
         super(array_agg, self).__init__(*fn_args, **kwargs)
 
index fd52ec6ea15802a8dbd1cf9e7f9662bbcf36228e..414ff03c3e0fb8da6bc00ec0155e17a5df43f287 100644 (file)
@@ -3065,6 +3065,8 @@ class ARRAY(
                     dim - 1 if dim is not None else None,
                     collection_callable,
                 )
+                if x is not None
+                else None
                 for x in arr
             )
 
index e418b48be726a886bee6a4d1720a72eedf0219b5..1cb463977b2858a9122889844f5e89cb77ac0b22 100644 (file)
@@ -13,6 +13,8 @@ import collections
 import typing
 from typing import Any
 from typing import Iterable
+from typing import Optional
+from typing import overload
 from typing import Tuple
 from typing import Union
 
@@ -37,6 +39,15 @@ else:
     _fixture_functions = None  # installed by plugin_base
 
 
+@overload
+def combinations(
+    *comb: Union[Any, Tuple[Any, ...]],
+    argnames: Optional[str] = None,
+    id_: Optional[str] = None,
+):
+    ...
+
+
 def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str):
     r"""Deliver multiple versions of a test based on positional combinations.
 
index 92fcfbcab39f49e6aa2831444db851d6fb2f3fb7..b5c20bd8d42c8cee53059111d7147df6bbd0ea9c 100644 (file)
@@ -36,6 +36,7 @@ from sqlalchemy import types
 from sqlalchemy import Unicode
 from sqlalchemy import util
 from sqlalchemy.dialects import postgresql
+from sqlalchemy.dialects.postgresql import aggregate_order_by
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import DATEMULTIRANGE
 from sqlalchemy.dialects.postgresql import DATERANGE
@@ -1901,6 +1902,48 @@ class ArrayRoundTripTest:
         stmt = select(func.array_agg(values_table.c.value)[2:4])
         eq_(connection.execute(stmt).scalar(), [2, 3, 4])
 
+    def test_array_agg_json(self, metadata, connection):
+        table = Table(
+            "values", metadata, Column("id", Integer), Column("bar", JSON)
+        )
+        metadata.create_all(connection)
+        connection.execute(
+            table.insert(),
+            [{"id": 1, "bar": [{"buz": 1}]}, {"id": 2, "bar": None}],
+        )
+
+        arg = aggregate_order_by(table.c.bar, table.c.id)
+        stmt = select(sa.func.array_agg(arg))
+        eq_(connection.execute(stmt).scalar(), [[{"buz": 1}], None])
+
+        arg = aggregate_order_by(table.c.bar, table.c.id.desc())
+        stmt = select(sa.func.array_agg(arg))
+        eq_(connection.execute(stmt).scalar(), [None, [{"buz": 1}]])
+
+    @testing.combinations(ARRAY, postgresql.ARRAY, argnames="cls")
+    def test_array_none(self, connection, metadata, cls):
+        table = Table(
+            "values", metadata, Column("id", Integer), Column("bar", cls(JSON))
+        )
+        metadata.create_all(connection)
+        connection.execute(
+            table.insert().values(
+                [
+                    {
+                        "id": 1,
+                        "bar": sa.text("""array['[{"x": 1}]'::json, null]"""),
+                    },
+                    {"id": 2, "bar": None},
+                ]
+            )
+        )
+
+        stmt = select(table.c.bar).order_by(table.c.id)
+        eq_(connection.scalars(stmt).all(), [[[{"x": 1}], None], None])
+
+        stmt = select(table.c.bar).order_by(table.c.id.desc())
+        eq_(connection.scalars(stmt).all(), [None, [[{"x": 1}], None]])
+
     def test_array_index_slice_exprs(self, connection):
         """test a variety of expressions that sometimes need parenthesizing"""
 
index 6c00660ffea9658c1979804ebc9421450f5bffed..a48d42f50763e28bb09c8b2557eb5f4f92416e56 100644 (file)
@@ -948,11 +948,14 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
         expr = func.array_agg(column("data", Integer))
         is_(expr.type._type_affinity, ARRAY)
         is_(expr.type.item_type._type_affinity, Integer)
+        is_(expr.type.dimensions, 1)
 
     def test_array_agg_array_datatype(self):
-        expr = func.array_agg(column("data", ARRAY(Integer)))
+        col = column("data", ARRAY(Integer))
+        expr = func.array_agg(col)
         is_(expr.type._type_affinity, ARRAY)
         is_(expr.type.item_type._type_affinity, Integer)
+        eq_(expr.type.dimensions, col.type.dimensions)
 
     def test_array_agg_array_literal_implicit_type(self):
         from sqlalchemy.dialects.postgresql import array, ARRAY as PG_ARRAY