]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement literal stringification for arrays
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 15 Jun 2022 16:42:44 +0000 (12:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 15 Jun 2022 18:32:53 +0000 (14:32 -0400)
as we already implement stringification for the contents,
provide a bracketed syntax for default and ARRAY literal
for PG specifically.   ARRAY literal seems much simpler to
render than their quoted syntax which requires double quotes
for strings.

also open up testing for pg8000 which has likely been
fine with arrays for awhile now, bump the version pin
also.

Fixes: #8138
Change-Id: Id85b052b0a9564d6aa1489160e58b7359f130fdd

doc/build/changelog/unreleased_20/8138.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/suite/test_types.py
setup.cfg
test/dialect/postgresql/test_types.py
test/requirements.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/8138.rst b/doc/build/changelog/unreleased_20/8138.rst
new file mode 100644 (file)
index 0000000..510e8f9
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 8138
+
+    Added literal type rendering for the :class:`_sqltypes.ARRAY` and
+    :class:`_postgresql.ARRAY` datatypes. The generic stringify will render
+    using brackets, e.g. ``[1, 2, 3]`` and the PostgreSQL specific will use the
+    ARRAY literal e.g. ``ARRAY[1, 2, 3]``.   Multiple dimensions and quoting
+    are also taken into account.
index 3b5eaed30e3c9107ee3b95b0badb480429fff362..515eb2d1503bf16bf207fef269225a35abfce7bc 100644 (file)
@@ -310,35 +310,6 @@ class ARRAY(sqltypes.ARRAY):
     def compare_values(self, x, y):
         return x == y
 
-    def _proc_array(self, arr, itemproc, dim, collection):
-        if dim is None:
-            arr = list(arr)
-        if (
-            dim == 1
-            or dim is None
-            and (
-                # this has to be (list, tuple), or at least
-                # not hasattr('__iter__'), since Py3K strings
-                # etc. have __iter__
-                not arr
-                or not isinstance(arr[0], (list, tuple))
-            )
-        ):
-            if itemproc:
-                return collection(itemproc(x) for x in arr)
-            else:
-                return collection(arr)
-        else:
-            return collection(
-                self._proc_array(
-                    x,
-                    itemproc,
-                    dim - 1 if dim is not None else None,
-                    collection,
-                )
-                for x in arr
-            )
-
     @util.memoized_property
     def _against_native_enum(self):
         return (
@@ -346,6 +317,24 @@ class ARRAY(sqltypes.ARRAY):
             and self.item_type.native_enum
         )
 
+    def literal_processor(self, dialect):
+        item_proc = self.item_type.dialect_impl(dialect).literal_processor(
+            dialect
+        )
+        if item_proc is None:
+            return None
+
+        def to_str(elements):
+            return f"ARRAY[{', '.join(elements)}]"
+
+        def process(value):
+            inner = self._apply_item_processor(
+                value, item_proc, self.dimensions, to_str
+            )
+            return inner
+
+        return process
+
     def bind_processor(self, dialect):
         item_proc = self.item_type.dialect_impl(dialect).bind_processor(
             dialect
@@ -355,7 +344,7 @@ class ARRAY(sqltypes.ARRAY):
             if value is None:
                 return value
             else:
-                return self._proc_array(
+                return self._apply_item_processor(
                     value, item_proc, self.dimensions, list
                 )
 
@@ -370,7 +359,7 @@ class ARRAY(sqltypes.ARRAY):
             if value is None:
                 return value
             else:
-                return self._proc_array(
+                return self._apply_item_processor(
                     value,
                     item_proc,
                     self.dimensions,
index 32f0813f5d9e95e50e5b28cac7e4235d97687f22..b4b444f23f08aecf205a557abc2834c1b7b7e344 100644 (file)
@@ -2964,6 +2964,64 @@ class ARRAY(
         if isinstance(self.item_type, SchemaEventTarget):
             self.item_type._set_parent_with_dispatch(parent)
 
+    def literal_processor(self, dialect):
+        item_proc = self.item_type.dialect_impl(dialect).literal_processor(
+            dialect
+        )
+        if item_proc is None:
+            return None
+
+        def to_str(elements):
+            return f"[{', '.join(elements)}]"
+
+        def process(value):
+            inner = self._apply_item_processor(
+                value, item_proc, self.dimensions, to_str
+            )
+            return inner
+
+        return process
+
+    def _apply_item_processor(self, arr, itemproc, dim, collection_callable):
+        """Helper method that can be used by bind_processor(),
+        literal_processor(), etc. to apply an item processor to elements of
+        an array value, taking into account the 'dimensions' for this
+        array type.
+
+        See the Postgresql ARRAY datatype for usage examples.
+
+        .. versionadded:: 2.0
+
+        """
+
+        if dim is None:
+            arr = list(arr)
+        if (
+            dim == 1
+            or dim is None
+            and (
+                # this has to be (list, tuple), or at least
+                # not hasattr('__iter__'), since Py3K strings
+                # etc. have __iter__
+                not arr
+                or not isinstance(arr[0], (list, tuple))
+            )
+        ):
+            if itemproc:
+                return collection_callable(itemproc(x) for x in arr)
+            else:
+                return collection_callable(arr)
+        else:
+            return collection_callable(
+                self._apply_item_processor(
+                    x,
+                    itemproc,
+                    dim - 1 if dim is not None else None,
+                    collection_callable,
+                )
+                for x in arr
+            )
+
 
 class TupleType(TypeEngine[Tuple[Any, ...]]):
     """represent the composite type of a Tuple."""
index 3913799569caf510bb565acc0e0a3516e123622f..9461298b9fcffeb182ebfc3d6e77de45b5851c99 100644 (file)
@@ -17,6 +17,7 @@ from ..config import requirements
 from ..schema import Column
 from ..schema import Table
 from ... import and_
+from ... import ARRAY
 from ... import BigInteger
 from ... import bindparam
 from ... import Boolean
@@ -222,6 +223,61 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
         self._test_null_strings(connection)
 
 
+class ArrayTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+    """Add ARRAY test suite, #8138.
+
+    This only works on PostgreSQL right now.
+
+    """
+
+    __requires__ = ("array_type",)
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "array_table",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("single_dim", ARRAY(Integer)),
+            Column("multi_dim", ARRAY(String, dimensions=2)),
+        )
+
+    def test_array_roundtrip(self, connection):
+        array_table = self.tables.array_table
+
+        connection.execute(
+            array_table.insert(),
+            {
+                "id": 1,
+                "single_dim": [1, 2, 3],
+                "multi_dim": [["one", "two"], ["thr'ee", "réve🐍 illé"]],
+            },
+        )
+        row = connection.execute(
+            select(array_table.c.single_dim, array_table.c.multi_dim)
+        ).first()
+        eq_(row, ([1, 2, 3], [["one", "two"], ["thr'ee", "réve🐍 illé"]]))
+
+    def test_literal_simple(self, literal_round_trip):
+        literal_round_trip(
+            ARRAY(Integer),
+            ([1, 2, 3],),
+            ([1, 2, 3],),
+            support_whereclause=False,
+        )
+
+    def test_literal_complex(self, literal_round_trip):
+        literal_round_trip(
+            ARRAY(String, dimensions=2),
+            ([["one", "two"], ["thr'ee", "réve🐍 illé"]],),
+            ([["one", "two"], ["thr'ee", "réve🐍 illé"]],),
+            support_whereclause=False,
+        )
+
+
 class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     __requires__ = ("binary_literals",)
     __backend__ = True
@@ -1779,6 +1835,7 @@ class NativeUUIDTest(UuidTest):
 
 
 __all__ = (
+    "ArrayTest",
     "BinaryTest",
     "UnicodeVarcharTest",
     "UnicodeTextTest",
index 80406827523e5b440166cc22cf3cb2b595977b8c..1be662d2afb7eb4b86a55cfa30122e6b32b142f4 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -59,7 +59,7 @@ oracle =
 oracle_oracledb =
     oracledb>=1.0.1
 postgresql = psycopg2>=2.7
-postgresql_pg8000 = pg8000>=1.16.6,!=1.29.0
+postgresql_pg8000 = pg8000>=1.29.1
 postgresql_asyncpg =
     %(asyncio)s
     asyncpg
index 266263d5fb9e3b6968faf7d43c9cb2407c9cffdf..fd4b91db1c40ab1bed0961fc34690eb1cf570b8e 100644 (file)
@@ -19,6 +19,7 @@ from sqlalchemy import Float
 from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import null
 from sqlalchemy import Numeric
@@ -52,6 +53,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy.sql import bindparam
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
@@ -64,6 +66,7 @@ from sqlalchemy.testing.assertsql import RegexSQL
 from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.suite import test_types as suite
 from sqlalchemy.testing.util import round_decimal
+from sqlalchemy.types import UserDefinedType
 
 
 class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
@@ -1230,6 +1233,23 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
             render_postcompile=True,
         )
 
+    def test_array_literal_render_no_inner_render(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            def get_col_spec(self, **kw):
+                return "MYTYPE"
+
+        with expect_raises_message(
+            NotImplementedError,
+            r"Don't know how to literal-quote value \[1, 2, 3\]",
+        ):
+            self.assert_compile(
+                select(literal([1, 2, 3], ARRAY(MyType()))),
+                "nothing",
+                literal_binds=True,
+            )
+
     def test_array_in_str_psycopg2_cast(self):
         expr = column("x", postgresql.ARRAY(String(15))).in_(
             [["one", "two"], ["three", "four"]]
index 2d0876158d0b1aec58a03fac4fdb4f7fc76e041c..bea861a83ff1d9c4162a94da9beb0a92e5e1db98 100644 (file)
@@ -969,12 +969,7 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def array_type(self):
-        return only_on(
-            [
-                lambda config: against(config, "postgresql")
-                and not against(config, "+pg8000")
-            ]
-        )
+        return only_on([lambda config: against(config, "postgresql")])
 
     @property
     def json_type(self):
@@ -1356,10 +1351,7 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def postgresql_jsonb(self):
-        return only_on("postgresql >= 9.4") + skip_if(
-            lambda config: config.db.dialect.driver == "pg8000"
-            and config.db.dialect._dbapi_version <= (1, 10, 1)
-        )
+        return only_on("postgresql >= 9.4")
 
     @property
     def native_hstore(self):
index ef39157269c9064a81585fa46406561ad1590b78..04aa4e000e32e87ce3ce9eedaafda1d26903004a 100644 (file)
@@ -93,6 +93,7 @@ from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import picklers
+from sqlalchemy.types import UserDefinedType
 
 
 def _all_dialect_modules():
@@ -2904,7 +2905,7 @@ class JSONTest(fixtures.TestBase):
         eq_(bindproc(expr.right.value), "'five'")
 
 
-class ArrayTest(fixtures.TestBase):
+class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
     def _myarray_fixture(self):
         class MyArray(ARRAY):
             pass
@@ -2957,6 +2958,44 @@ class ArrayTest(fixtures.TestBase):
         assert isinstance(arrtable.c.intarr[1:3].type, MyArray)
         assert isinstance(arrtable.c.strarr[1:3].type, MyArray)
 
+    def test_array_literal_simple(self):
+        self.assert_compile(
+            select(literal([1, 2, 3], ARRAY(Integer))),
+            "SELECT [1, 2, 3] AS anon_1",
+            literal_binds=True,
+            dialect="default",
+        )
+
+    def test_array_literal_complex(self):
+        self.assert_compile(
+            select(
+                literal(
+                    [["one", "two"], ["thr'ee", "réve🐍 illé"]],
+                    ARRAY(String, dimensions=2),
+                )
+            ),
+            "SELECT [['one', 'two'], ['thr''ee', 'réve🐍 illé']] AS anon_1",
+            literal_binds=True,
+            dialect="default",
+        )
+
+    def test_array_literal_render_no_inner_render(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            def get_col_spec(self, **kw):
+                return "MYTYPE"
+
+        with expect_raises_message(
+            NotImplementedError,
+            r"Don't know how to literal-quote value \[1, 2, 3\]",
+        ):
+            self.assert_compile(
+                select(literal([1, 2, 3], ARRAY(MyType()))),
+                "nothing",
+                literal_binds=True,
+            )
+
 
 MyCustomType = MyTypeDec = None