From 9633d3c15d42026f8f45f5a4d201a5d72e57b8d4 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Fri, 14 Mar 2025 12:02:20 +0100 Subject: [PATCH] Cast empty PostgreSQL ARRAY from the type specified to array() When using `postgresql.array()` with an empty list as `clauses` and a `type_` argument specified, we now cast the resulting `ARRAY[]` SQL expression using that type information thus removing the need for an explicit cast on user side. We need to add the 'type' attribute to the cache key of `postgresql.array` in order to distinguish statements produced with an empty list clause and with or without a `type_` argument. --- doc/build/changelog/unreleased_20/12432.rst | 8 ++++++ lib/sqlalchemy/dialects/postgresql/array.py | 19 ++++++++++++++ lib/sqlalchemy/dialects/postgresql/base.py | 2 ++ test/dialect/postgresql/test_compiler.py | 28 +++++++++++++++++++++ test/dialect/postgresql/test_query.py | 4 +++ test/sql/test_compare.py | 1 + 6 files changed, 62 insertions(+) create mode 100644 doc/build/changelog/unreleased_20/12432.rst diff --git a/doc/build/changelog/unreleased_20/12432.rst b/doc/build/changelog/unreleased_20/12432.rst new file mode 100644 index 0000000000..547539d878 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12432.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 12432 + + When building a PostgreSQL ``ARRAY`` literal using + :class:`_postgresql.array` with an empty ``clauses`` argument, use the + ``type_`` argument to cast the resulting ``ARRAY[]`` SQL expression. + Pull request courtesy Denis Laxalde. diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index f32f146664..c6ee8ebf92 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -24,6 +24,7 @@ from ... import types as sqltypes from ... import util from ...sql import expression from ...sql import operators +from ...sql.visitors import InternalTraversal if TYPE_CHECKING: from ...engine.interfaces import Dialect @@ -96,6 +97,20 @@ class array(expression.ExpressionClauseList[_T]): array(["foo", "bar"], type_=CHAR) + In particular, when constructing an empty array, the ``type_`` argument + will be used as a type cast so that:: + + stmt = array([], type_=Integer) + print(stmt.compile(dialect=postgresql.dialect())) + + Produces: + + .. sourcecode:: sql + + ARRAY[]::INTEGER[] + + As required by PostgreSQL for empty arrays. + Multidimensional arrays are produced by nesting :class:`.array` constructs. The dimensionality of the final :class:`_types.ARRAY` type is calculated by @@ -129,6 +144,10 @@ class array(expression.ExpressionClauseList[_T]): stringify_dialect = "postgresql" inherit_cache = True + _traverse_internals = ( + expression.ExpressionClauseList._traverse_internals + + [("type", InternalTraversal.dp_type)] + ) def __init__( self, diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 6852080303..38812d9112 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1807,6 +1807,8 @@ class PGCompiler(compiler.SQLCompiler): }""" def visit_array(self, element, **kw): + if not element.clauses and not element.type.item_type._isnull: + return "ARRAY[]::%s" % element.type.compile(self.dialect) return "ARRAY[%s]" % self.visit_clauselist(element, **kw) def visit_slice(self, element, **kw): diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index ac49f6f4b5..0ec008e030 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1988,6 +1988,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): String, ) + @testing.combinations( + ("with type_", Date, "ARRAY[]::DATE[]"), + ("no type_", None, "ARRAY[]"), + id_="iaa", + ) + def test_array_literal_empty(self, type_, expected): + self.assert_compile(postgresql.array([], type_=type_), expected) + def test_array_literal(self): self.assert_compile( func.array_dims( @@ -4238,3 +4246,23 @@ class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase): ), compare_values=False, ) + + def test_array(self): + self._run_cache_key_equal_fixture( + lambda: ( + array([0]), + array([0], type_=Integer), + array([1], type_=Integer), + ), + compare_values=False, + ) + + def test_array_empty(self): + self._run_cache_key_fixture( + lambda: ( + array([], type_=Integer), + array([], type_=Text), + array([0]), + ), + compare_values=True, + ) diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index f8bb9dbc79..c55cd0a5d7 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -1640,6 +1640,10 @@ class TableValuedRoundTripTest(fixtures.TestBase): eq_(connection.execute(stmt).all(), [(4, 1), (3, 2), (2, 3), (1, 4)]) + def test_array_empty_with_type(self, connection): + stmt = select(postgresql.array([], type_=Integer)) + eq_(connection.execute(stmt).all(), [([],)]) + def test_plain_old_unnest(self, connection): fn = func.unnest( postgresql.array(["one", "two", "three", "four"]) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 8b1869e8d0..79c520ed3b 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1661,6 +1661,7 @@ class HasCacheKeySubclass(fixtures.TestBase): {"_with_options", "_raw_columns", "_setup_joins"}, {"args"}, ), + "array": ({"operator", "type", "clauses"}, {"clauses", "type_"}), "next_value": ({"sequence"}, {"seq"}), } -- 2.47.2