]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Cast empty PostgreSQL ARRAY from the type specified to array() 12435/head
authorDenis Laxalde <denis@laxalde.org>
Fri, 14 Mar 2025 11:02:20 +0000 (12:02 +0100)
committerDenis Laxalde <denis@laxalde.org>
Wed, 19 Mar 2025 07:18:32 +0000 (08:18 +0100)
When using `postgresql.array()` with an empty list as `clauses` and a
`type_` argument specified, we now cast the resulting `ARRAY[]` SQL
expression using that type information thus removing the need for an
explicit cast on user side.

We need to add the 'type' attribute to the cache key of
`postgresql.array` in order to distinguish statements produced with an
empty list clause and with or without a `type_` argument.

doc/build/changelog/unreleased_20/12432.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_query.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_20/12432.rst b/doc/build/changelog/unreleased_20/12432.rst
new file mode 100644 (file)
index 0000000..547539d
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 12432
+
+    When building a PostgreSQL ``ARRAY`` literal using
+    :class:`_postgresql.array` with an empty ``clauses`` argument, use the
+    ``type_`` argument to cast the resulting ``ARRAY[]`` SQL expression.
+    Pull request courtesy Denis Laxalde.
index f32f146664209f7d0f6cb5f3234c60d0e6e8edaa..c6ee8ebf92036c6966d512ebeff068024fe0d9dc 100644 (file)
@@ -24,6 +24,7 @@ from ... import types as sqltypes
 from ... import util
 from ...sql import expression
 from ...sql import operators
+from ...sql.visitors import InternalTraversal
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
@@ -96,6 +97,20 @@ class array(expression.ExpressionClauseList[_T]):
 
         array(["foo", "bar"], type_=CHAR)
 
+    In particular, when constructing an empty array, the ``type_`` argument
+    will be used as a type cast so that::
+
+        stmt = array([], type_=Integer)
+        print(stmt.compile(dialect=postgresql.dialect()))
+
+    Produces:
+
+    .. sourcecode:: sql
+
+        ARRAY[]::INTEGER[]
+
+    As required by PostgreSQL for empty arrays.
+
     Multidimensional arrays are produced by nesting :class:`.array` constructs.
     The dimensionality of the final :class:`_types.ARRAY`
     type is calculated by
@@ -129,6 +144,10 @@ class array(expression.ExpressionClauseList[_T]):
 
     stringify_dialect = "postgresql"
     inherit_cache = True
+    _traverse_internals = (
+        expression.ExpressionClauseList._traverse_internals
+        + [("type", InternalTraversal.dp_type)]
+    )
 
     def __init__(
         self,
index 6852080303ab75446693948063bbd703ff83d8ff..38812d911253917738d5edbd210103695e0b5d1f 100644 (file)
@@ -1807,6 +1807,8 @@ class PGCompiler(compiler.SQLCompiler):
         }"""
 
     def visit_array(self, element, **kw):
+        if not element.clauses and not element.type.item_type._isnull:
+            return "ARRAY[]::%s" % element.type.compile(self.dialect)
         return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
 
     def visit_slice(self, element, **kw):
index ac49f6f4b5105f481d3b32d497b9d0d224fff9f0..0ec008e0301b9d0e02e86ace83b143cbe8e51d12 100644 (file)
@@ -1988,6 +1988,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             String,
         )
 
+    @testing.combinations(
+        ("with type_", Date, "ARRAY[]::DATE[]"),
+        ("no type_", None, "ARRAY[]"),
+        id_="iaa",
+    )
+    def test_array_literal_empty(self, type_, expected):
+        self.assert_compile(postgresql.array([], type_=type_), expected)
+
     def test_array_literal(self):
         self.assert_compile(
             func.array_dims(
@@ -4238,3 +4246,23 @@ class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase):
             ),
             compare_values=False,
         )
+
+    def test_array(self):
+        self._run_cache_key_equal_fixture(
+            lambda: (
+                array([0]),
+                array([0], type_=Integer),
+                array([1], type_=Integer),
+            ),
+            compare_values=False,
+        )
+
+    def test_array_empty(self):
+        self._run_cache_key_fixture(
+            lambda: (
+                array([], type_=Integer),
+                array([], type_=Text),
+                array([0]),
+            ),
+            compare_values=True,
+        )
index f8bb9dbc79d35ca926da6f209b17e1e2242f3817..c55cd0a5d7ce38a3bc555c9817ba586916fa9ffb 100644 (file)
@@ -1640,6 +1640,10 @@ class TableValuedRoundTripTest(fixtures.TestBase):
 
         eq_(connection.execute(stmt).all(), [(4, 1), (3, 2), (2, 3), (1, 4)])
 
+    def test_array_empty_with_type(self, connection):
+        stmt = select(postgresql.array([], type_=Integer))
+        eq_(connection.execute(stmt).all(), [([],)])
+
     def test_plain_old_unnest(self, connection):
         fn = func.unnest(
             postgresql.array(["one", "two", "three", "four"])
index 8b1869e8d0d97aa57b9daecb5b85e7b4a24170bc..79c520ed3b5812772876ec89e753af808a388655 100644 (file)
@@ -1661,6 +1661,7 @@ class HasCacheKeySubclass(fixtures.TestBase):
             {"_with_options", "_raw_columns", "_setup_joins"},
             {"args"},
         ),
+        "array": ({"operator", "type", "clauses"}, {"clauses", "type_"}),
         "next_value": ({"sequence"}, {"seq"}),
     }