]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Cast empty PostgreSQL ARRAY from the type specified to array()
authorDenis Laxalde <denis@laxalde.org>
Wed, 19 Mar 2025 08:17:27 +0000 (04:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Mar 2025 23:32:09 +0000 (19:32 -0400)
When building a PostgreSQL ``ARRAY`` literal using
:class:`_postgresql.array` with an empty ``clauses`` argument, the
:paramref:`_postgresql.array.type_` parameter is now significant in that it
will be used to render the resulting ``ARRAY[]`` SQL expression with a
cast, such as ``ARRAY[]::INTEGER``. Pull request courtesy Denis Laxalde.

Fixes: #12432
Closes: #12435
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12435
Pull-request-sha: 9633d3c15d42026f8f45f5a4d201a5d72e57b8d4

Change-Id: I29ed7bd0562b82351d22de0658fb46c31cfe44f6

doc/build/changelog/unreleased_20/12432.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_query.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_20/12432.rst b/doc/build/changelog/unreleased_20/12432.rst
new file mode 100644 (file)
index 0000000..ff781fb
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 12432
+
+    When building a PostgreSQL ``ARRAY`` literal using
+    :class:`_postgresql.array` with an empty ``clauses`` argument, the
+    :paramref:`_postgresql.array.type_` parameter is now significant in that it
+    will be used to render the resulting ``ARRAY[]`` SQL expression with a
+    cast, such as ``ARRAY[]::INTEGER``. Pull request courtesy Denis Laxalde.
index f32f146664209f7d0f6cb5f3234c60d0e6e8edaa..9d6212f473229dc088cce251cdb3d94e3a6d73fc 100644 (file)
@@ -24,6 +24,7 @@ from ... import types as sqltypes
 from ... import util
 from ...sql import expression
 from ...sql import operators
+from ...sql.visitors import InternalTraversal
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
     from ...sql.type_api import _LiteralProcessorType
     from ...sql.type_api import _ResultProcessorType
     from ...sql.type_api import TypeEngine
+    from ...sql.visitors import _TraverseInternalsType
     from ...util.typing import Self
 
 
@@ -91,11 +93,32 @@ class array(expression.ExpressionClauseList[_T]):
             ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1
 
     An instance of :class:`.array` will always have the datatype
-    :class:`_types.ARRAY`.  The "inner" type of the array is inferred from
-    the values present, unless the ``type_`` keyword argument is passed::
+    :class:`_types.ARRAY`.  The "inner" type of the array is inferred from the
+    values present, unless the :paramref:`_postgresql.array.type_` keyword
+    argument is passed::
 
         array(["foo", "bar"], type_=CHAR)
 
+    When constructing an empty array, the :paramref:`_postgresql.array.type_`
+    argument is particularly important as PostgreSQL server typically requires
+    a cast to be rendered for the inner type in order to render an empty array.
+    SQLAlchemy's compilation for the empty array will produce this cast so
+    that::
+
+        stmt = array([], type_=Integer)
+        print(stmt.compile(dialect=postgresql.dialect()))
+
+    Produces:
+
+    .. sourcecode:: sql
+
+        ARRAY[]::INTEGER[]
+
+    As required by PostgreSQL for empty arrays.
+
+    .. versionadded:: 2.0.40 added support to render empty PostgreSQL array
+       literals with a required cast.
+
     Multidimensional arrays are produced by nesting :class:`.array` constructs.
     The dimensionality of the final :class:`_types.ARRAY`
     type is calculated by
@@ -128,7 +151,11 @@ class array(expression.ExpressionClauseList[_T]):
     __visit_name__ = "array"
 
     stringify_dialect = "postgresql"
-    inherit_cache = True
+
+    _traverse_internals: _TraverseInternalsType = [
+        ("clauses", InternalTraversal.dp_clauseelement_tuple),
+        ("type", InternalTraversal.dp_type),
+    ]
 
     def __init__(
         self,
@@ -137,6 +164,14 @@ class array(expression.ExpressionClauseList[_T]):
         type_: Optional[_TypeEngineArgument[_T]] = None,
         **kw: typing_Any,
     ):
+        r"""Construct an ARRAY literal.
+
+        :param clauses: iterable, such as a list, containing elements to be
+         rendered in the array
+        :param type\_: optional type.  If omitted, the type is inferred
+         from the contents of the array.
+
+        """
         super().__init__(operators.comma_op, *clauses, **kw)
 
         main_type = (
index 28348af15c426041d08898711b6994177b78fba8..b9bb796e2ad6842d2bd6e8dd26697d5d7aa4330c 100644 (file)
@@ -1807,6 +1807,8 @@ class PGCompiler(compiler.SQLCompiler):
         }"""
 
     def visit_array(self, element, **kw):
+        if not element.clauses and not element.type.item_type._isnull:
+            return "ARRAY[]::%s" % element.type.compile(self.dialect)
         return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
 
     def visit_slice(self, element, **kw):
index 058c51145ea8c0208f50c22284e6a8135a91a3d8..370981e19db4f2f377e52120ce297e32fc91b65a 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy import tuple_
 from sqlalchemy import types as sqltypes
 from sqlalchemy import UniqueConstraint
 from sqlalchemy import update
+from sqlalchemy import VARCHAR
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects.postgresql import aggregate_order_by
 from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY
@@ -1991,6 +1992,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             String,
         )
 
+    @testing.combinations(
+        ("with type_", Date, "ARRAY[]::DATE[]"),
+        ("no type_", None, "ARRAY[]"),
+        id_="iaa",
+    )
+    def test_array_literal_empty(self, type_, expected):
+        self.assert_compile(postgresql.array([], type_=type_), expected)
+
     def test_array_literal(self):
         self.assert_compile(
             func.array_dims(
@@ -4351,3 +4360,49 @@ class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase):
             ),
             compare_values=False,
         )
+
+    def test_array_equivalent_keys_one_element(self):
+        self._run_cache_key_equal_fixture(
+            lambda: (
+                array([random.randint(0, 10)]),
+                array([random.randint(0, 10)], type_=Integer),
+                array([random.randint(0, 10)], type_=Integer),
+            ),
+            compare_values=False,
+        )
+
+    def test_array_equivalent_keys_two_elements(self):
+        self._run_cache_key_equal_fixture(
+            lambda: (
+                array([random.randint(0, 10), random.randint(0, 10)]),
+                array(
+                    [random.randint(0, 10), random.randint(0, 10)],
+                    type_=Integer,
+                ),
+                array(
+                    [random.randint(0, 10), random.randint(0, 10)],
+                    type_=Integer,
+                ),
+            ),
+            compare_values=False,
+        )
+
+    def test_array_heterogeneous(self):
+        self._run_cache_key_fixture(
+            lambda: (
+                array([], type_=Integer),
+                array([], type_=Text),
+                array([]),
+                array([random.choice(["t1", "t2", "t3"])]),
+                array(
+                    [
+                        random.choice(["t1", "t2", "t3"]),
+                        random.choice(["t1", "t2", "t3"]),
+                    ]
+                ),
+                array([random.choice(["t1", "t2", "t3"])], type_=Text),
+                array([random.choice(["t1", "t2", "t3"])], type_=VARCHAR(30)),
+                array([random.randint(0, 10), random.randint(0, 10)]),
+            ),
+            compare_values=False,
+        )
index f8bb9dbc79d35ca926da6f209b17e1e2242f3817..c55cd0a5d7ce38a3bc555c9817ba586916fa9ffb 100644 (file)
@@ -1640,6 +1640,10 @@ class TableValuedRoundTripTest(fixtures.TestBase):
 
         eq_(connection.execute(stmt).all(), [(4, 1), (3, 2), (2, 3), (1, 4)])
 
+    def test_array_empty_with_type(self, connection):
+        stmt = select(postgresql.array([], type_=Integer))
+        eq_(connection.execute(stmt).all(), [([],)])
+
     def test_plain_old_unnest(self, connection):
         fn = func.unnest(
             postgresql.array(["one", "two", "three", "four"])
index 8b1869e8d0d97aa57b9daecb5b85e7b4a24170bc..c42bdac7c1454e44cd600f46203bda4295b2a309 100644 (file)
@@ -1479,6 +1479,7 @@ class HasCacheKeySubclass(fixtures.TestBase):
             "modifiers",
         },
         "next_value": {"sequence"},
+        "array": ({"type", "clauses"}),
     }
 
     ignore_keys = {
@@ -1661,6 +1662,7 @@ class HasCacheKeySubclass(fixtures.TestBase):
             {"_with_options", "_raw_columns", "_setup_joins"},
             {"args"},
         ),
+        "array": ({"type", "clauses"}, {"clauses", "type_"}),
         "next_value": ({"sequence"}, {"seq"}),
     }