]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support multidimensional array literals in Postgresql
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Jul 2019 19:46:35 +0000 (15:46 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Jul 2019 20:21:54 +0000 (16:21 -0400)
Added support for multidimensional Postgresql array literals via nesting
the :class:`.postgresql.array` object within another one.  The
multidimensional array type is detected automatically.

Fixes: #4756
Change-Id: Ie2107ad3cf291112f6ca330dc90dc15a0a940cee

doc/build/changelog/unreleased_13/4756.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/4756.rst b/doc/build/changelog/unreleased_13/4756.rst
new file mode 100644 (file)
index 0000000..d27c392
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 4756
+
+    Added support for multidimensional Postgresql array literals via nesting
+    the :class:`.postgresql.array` object within another one.  The
+    multidimensional array type is detected automatically.
+
+    .. seealso::
+
+        :class:`.postgresql.array`
index 594cd3a0c1ef6c91f62a3d2e2053ab49bbb8cf6a..81bde2a02cc40c4c4ddbc4f57fb20d11a3bca0cb 100644 (file)
@@ -60,7 +60,7 @@ class array(expression.Tuple):
                         array([1,2]) + array([3,4,5])
                     ])
 
-        print stmt.compile(dialect=postgresql.dialect())
+        print(stmt.compile(dialect=postgresql.dialect()))
 
     Produces the SQL::
 
@@ -73,6 +73,24 @@ class array(expression.Tuple):
 
         array(['foo', 'bar'], type_=CHAR)
 
+    Multidimensional arrays are produced by nesting :class:`.array` constructs.
+    The dimensionality of the final :class:`.ARRAY` type is calculated by
+    recursively adding the dimensions of the inner :class:`.ARRAY` type::
+
+        stmt = select([
+            array([
+                array([1, 2]), array([3, 4]), array([column('q'), column('x')])
+            ])
+        ])
+        print(stmt.compile(dialect=postgresql.dialect()))
+
+    Produces::
+
+        SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
+        ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
+
+    .. versionadded:: 1.3.6 added support for multidimensional array literals
+
     .. seealso::
 
         :class:`.postgresql.ARRAY`
@@ -83,7 +101,15 @@ class array(expression.Tuple):
 
     def __init__(self, clauses, **kw):
         super(array, self).__init__(*clauses, **kw)
-        self.type = ARRAY(self.type)
+        if isinstance(self.type, ARRAY):
+            self.type = ARRAY(
+                self.type.item_type,
+                dimensions=self.type.dimensions + 1
+                if self.type.dimensions is not None
+                else 2,
+            )
+        else:
+            self.type = ARRAY(self.type)
 
     def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
         if _assume_scalar or operator is operators.getitem:
index 72335ebe3978562c4b8b901e4eebe095f5deeae9..557b916222eca70bc400219e190ba3210cca7e19 100644 (file)
@@ -950,6 +950,55 @@ class TimePrecisionTest(fixtures.TestBase, AssertsCompiledSQL):
 class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
     __dialect__ = "postgresql"
 
+    def test_array_literal(self):
+        obj = postgresql.array([1, 2]) + postgresql.array([3, 4, 5])
+
+        self.assert_compile(
+            obj,
+            "ARRAY[%(param_1)s, %(param_2)s] || "
+            "ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]",
+            params={
+                "param_1": 1,
+                "param_2": 2,
+                "param_3": 3,
+                "param_4": 4,
+                "param_5": 5,
+            },
+        )
+        self.assert_compile(
+            obj[1],
+            "(ARRAY[%(param_1)s, %(param_2)s] || ARRAY[%(param_3)s, "
+            "%(param_4)s, %(param_5)s])[%(param_6)s]",
+            params={
+                "param_1": 1,
+                "param_2": 2,
+                "param_3": 3,
+                "param_4": 4,
+                "param_5": 5,
+            },
+        )
+
+    def test_array_literal_getitem_multidim(self):
+        obj = postgresql.array(
+            [postgresql.array([1, 2]), postgresql.array([3, 4])]
+        )
+
+        self.assert_compile(
+            obj,
+            "ARRAY[ARRAY[%(param_1)s, %(param_2)s], "
+            "ARRAY[%(param_3)s, %(param_4)s]]",
+        )
+        self.assert_compile(
+            obj[1],
+            "(ARRAY[ARRAY[%(param_1)s, %(param_2)s], "
+            "ARRAY[%(param_3)s, %(param_4)s]])[%(param_5)s]",
+        )
+        self.assert_compile(
+            obj[1][0],
+            "(ARRAY[ARRAY[%(param_1)s, %(param_2)s], "
+            "ARRAY[%(param_3)s, %(param_4)s]])[%(param_5)s][%(param_6)s]",
+        )
+
     def test_array_type_render_str(self):
         self.assert_compile(postgresql.ARRAY(Unicode(30)), "VARCHAR(30)[]")
 
@@ -1350,7 +1399,7 @@ class ArrayRoundTripTest(object):
             [[util.ue("m\xe4\xe4")], [util.ue("m\xf6\xf6")]],
         )
 
-    def test_array_literal(self):
+    def test_array_literal_roundtrip(self):
         eq_(
             testing.db.scalar(
                 select(
@@ -1360,6 +1409,67 @@ class ArrayRoundTripTest(object):
             [1, 2, 3, 4, 5],
         )
 
+        eq_(
+            testing.db.scalar(
+                select(
+                    [
+                        (
+                            postgresql.array([1, 2])
+                            + postgresql.array([3, 4, 5])
+                        )[3]
+                    ]
+                )
+            ),
+            3,
+        )
+
+        eq_(
+            testing.db.scalar(
+                select(
+                    [
+                        (
+                            postgresql.array([1, 2])
+                            + postgresql.array([3, 4, 5])
+                        )[2:4]
+                    ]
+                )
+            ),
+            [2, 3, 4],
+        )
+
+    def test_array_literal_multidimensional_roundtrip(self):
+        eq_(
+            testing.db.scalar(
+                select(
+                    [
+                        postgresql.array(
+                            [
+                                postgresql.array([1, 2]),
+                                postgresql.array([3, 4]),
+                            ]
+                        )
+                    ]
+                )
+            ),
+            [[1, 2], [3, 4]],
+        )
+
+        eq_(
+            testing.db.scalar(
+                select(
+                    [
+                        postgresql.array(
+                            [
+                                postgresql.array([1, 2]),
+                                postgresql.array([3, 4]),
+                            ]
+                        )[2][1]
+                    ]
+                )
+            ),
+            3,
+        )
+
     def test_array_literal_compare(self):
         eq_(
             testing.db.scalar(select([postgresql.array([1, 2]) < [3, 4, 5]])),