]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
The operators for the Postgresql ARRAY type supports
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 22 Apr 2013 20:57:15 +0000 (16:57 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 22 Apr 2013 20:57:15 +0000 (16:57 -0400)
input types of sets, generators, etc. but only when a dimension
is specified for the ARRAY; otherwise, the dialect
needs to peek inside of "arr[0]" to guess how many
dimensions are in use.  If this occurs with a non
list/tuple type, the error message is now informative
and directs to specify a dimension for the ARRAY.
[ticket:2681]

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/test_postgresql.py

index 7f98c4b0b1af25157ecf380c8734bc24594e5f8c..b64a188f79cb6f840fa076a1a888d73e873494fe 100644 (file)
@@ -6,6 +6,18 @@
 .. changelog::
     :version: 0.8.1
 
+    .. change::
+      :tags: bug, postgresql
+      :tickets: 2681
+
+      The operators for the Postgresql ARRAY type supports
+      input types of sets, generators, etc. but only when a dimension
+      is specified for the ARRAY; otherwise, the dialect
+      needs to peek inside of "arr[0]" to guess how many
+      dimensions are in use.  If this occurs with a non
+      list/tuple type, the error message is now informative
+      and directs to specify a dimension for the ARRAY.
+
     .. change::
       :tags: bug, mysql
       :pullreq: 55
index c59caff8d25cc5c5707d39225c23ed080b9feba9..f3a88ff70970a84e2182a1c8edf9852cbea478aa 100644 (file)
@@ -669,10 +669,22 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
     def compare_values(self, x, y):
         return x == y
 
+    def _test_array_of_scalars(self, arr):
+        if not arr:
+            return True
+        else:
+            try:
+                return not isinstance(arr[0], (list, tuple))
+            except TypeError:
+                raise TypeError(
+                            "Cannot auto-coerce ARRAY value of type "
+                            "%s unless dimensions are specified "
+                            "for ARRAY type" % type(arr))
+
     def _proc_array(self, arr, itemproc, dim, collection):
         if dim == 1 or (
                     dim is None and
-                    (not arr or not isinstance(arr[0], (list, tuple)))
+                    self._test_array_of_scalars(arr)
                 ):
             if itemproc:
                 return collection(itemproc(x) for x in arr)
index 5f1ed7604351ac654d5e673a001e30bc5d3aec9e..3455b8ff1028c20fd1dfb6d53639dbae4b9cf696 100644 (file)
@@ -2116,16 +2116,14 @@ class TimePrecisionTest(fixtures.TestBase, AssertsCompiledSQL):
         eq_(t2.c.c5.type.timezone, False)
         eq_(t2.c.c6.type.timezone, True)
 
-class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
+class ArrayTest(fixtures.TablesTest, AssertsExecutionResults):
 
     __only_on__ = 'postgresql'
 
     __unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc'
 
     @classmethod
-    def setup_class(cls):
-        global metadata, arrtable
-        metadata = MetaData(testing.db)
+    def define_tables(cls, metadata):
 
         class ProcValue(TypeDecorator):
             impl = postgresql.ARRAY(Integer, dimensions=2)
@@ -2146,20 +2144,25 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
                     for v in value
                 ]
 
-        arrtable = Table('arrtable', metadata,
+        Table('arrtable', metadata,
                         Column('id', Integer, primary_key=True),
-                        Column('intarr',postgresql.ARRAY(Integer)),
-                         Column('strarr',postgresql.ARRAY(Unicode())),
+                        Column('intarr', postgresql.ARRAY(Integer)),
+                         Column('strarr', postgresql.ARRAY(Unicode())),
                         Column('dimarr', ProcValue)
                     )
-        metadata.create_all()
 
-    def teardown(self):
-        arrtable.delete().execute()
+        Table('dim_arrtable', metadata,
+                        Column('id', Integer, primary_key=True),
+                        Column('intarr', postgresql.ARRAY(Integer, dimensions=1)),
+                         Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)),
+                        Column('dimarr', ProcValue)
+                    )
 
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
+    def _fixture_456(self, table):
+        testing.db.execute(
+                table.insert(),
+                intarr=[4, 5, 6]
+        )
 
     def test_reflect_array_column(self):
         metadata2 = MetaData(testing.db)
@@ -2170,6 +2173,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
         assert isinstance(tbl.c.strarr.type.item_type, String)
 
     def test_insert_array(self):
+        arrtable = self.tables.arrtable
         arrtable.insert().execute(intarr=[1, 2, 3], strarr=[u'abc',
                                   u'def'])
         results = arrtable.select().execute().fetchall()
@@ -2178,6 +2182,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(results[0]['strarr'], ['abc', 'def'])
 
     def test_array_where(self):
+        arrtable = self.tables.arrtable
         arrtable.insert().execute(intarr=[1, 2, 3], strarr=[u'abc',
                                   u'def'])
         arrtable.insert().execute(intarr=[4, 5, 6], strarr=u'ABC')
@@ -2187,6 +2192,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(results[0]['intarr'], [1, 2, 3])
 
     def test_array_concat(self):
+        arrtable = self.tables.arrtable
         arrtable.insert().execute(intarr=[1, 2, 3],
                     strarr=[u'abc', u'def'])
         results = select([arrtable.c.intarr + [4, 5,
@@ -2195,6 +2201,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(results[0][0], [ 1, 2, 3, 4, 5, 6, ])
 
     def test_array_subtype_resultprocessor(self):
+        arrtable = self.tables.arrtable
         arrtable.insert().execute(intarr=[4, 5, 6],
                                   strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6'
                                   ]])
@@ -2216,66 +2223,100 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
         )
 
     def test_array_getitem_single_type(self):
+        arrtable = self.tables.arrtable
         is_(arrtable.c.intarr[1].type._type_affinity, Integer)
         is_(arrtable.c.strarr[1].type._type_affinity, String)
 
     def test_array_getitem_slice_type(self):
+        arrtable = self.tables.arrtable
         is_(arrtable.c.intarr[1:3].type._type_affinity, postgresql.ARRAY)
         is_(arrtable.c.strarr[1:3].type._type_affinity, postgresql.ARRAY)
 
     def test_array_getitem_single_exec(self):
-        with testing.db.connect() as conn:
-            conn.execute(
-                arrtable.insert(),
-                intarr=[4, 5, 6],
-                strarr=[u'abc', u'def']
-            )
-            eq_(
-                conn.scalar(select([arrtable.c.intarr[2]])),
-                5
-            )
-            conn.execute(
-                arrtable.update().values({arrtable.c.intarr[2]: 7})
-            )
-            eq_(
-                conn.scalar(select([arrtable.c.intarr[2]])),
-                7
-            )
+        arrtable = self.tables.arrtable
+        self._fixture_456(arrtable)
+        eq_(
+            testing.db.scalar(select([arrtable.c.intarr[2]])),
+            5
+        )
+        testing.db.execute(
+            arrtable.update().values({arrtable.c.intarr[2]: 7})
+        )
+        eq_(
+            testing.db.scalar(select([arrtable.c.intarr[2]])),
+            7
+        )
 
     def test_array_getitem_slice_exec(self):
-        with testing.db.connect() as conn:
-            conn.execute(
-                arrtable.insert(),
-                intarr=[4, 5, 6],
-                strarr=[u'abc', u'def']
-            )
-            eq_(
-                conn.scalar(select([arrtable.c.intarr[2:3]])),
-                [5, 6]
-            )
-            conn.execute(
-                arrtable.update().values({arrtable.c.intarr[2:3]: [7, 8]})
-            )
-            eq_(
-                conn.scalar(select([arrtable.c.intarr[2:3]])),
-                [7, 8]
-            )
+        arrtable = self.tables.arrtable
+        testing.db.execute(
+            arrtable.insert(),
+            intarr=[4, 5, 6],
+            strarr=[u'abc', u'def']
+        )
+        eq_(
+            testing.db.scalar(select([arrtable.c.intarr[2:3]])),
+            [5, 6]
+        )
+        testing.db.execute(
+            arrtable.update().values({arrtable.c.intarr[2:3]: [7, 8]})
+        )
+        eq_(
+            testing.db.scalar(select([arrtable.c.intarr[2:3]])),
+            [7, 8]
+        )
 
-    def test_array_contains_exec(self):
-        with testing.db.connect() as conn:
-            conn.execute(
-                arrtable.insert(),
-                intarr=[4, 5, 6]
-            )
-            eq_(
-                conn.scalar(
-                    select([arrtable.c.intarr]).
-                        where(arrtable.c.intarr.contains([4, 5]))
-                ),
-                [4, 5, 6]
-            )
+
+    def _test_undim_array_contains_typed_exec(self, struct):
+        arrtable = self.tables.arrtable
+        self._fixture_456(arrtable)
+        eq_(
+            testing.db.scalar(
+                select([arrtable.c.intarr]).
+                    where(arrtable.c.intarr.contains(struct([4, 5])))
+            ),
+            [4, 5, 6]
+        )
+
+    def test_undim_array_contains_set_exec(self):
+        assert_raises_message(
+            exc.StatementError,
+            "Cannot auto-coerce ARRAY value of type",
+            self._test_undim_array_contains_typed_exec, set
+        )
+
+    def test_undim_array_contains_list_exec(self):
+        self._test_undim_array_contains_typed_exec(list)
+
+    def test_undim_array_contains_generator_exec(self):
+        assert_raises_message(
+            exc.StatementError,
+            "Cannot auto-coerce ARRAY value of type",
+            self._test_undim_array_contains_typed_exec, lambda elem: (x for x in elem)
+        )
+
+    def _test_dim_array_contains_typed_exec(self, struct):
+        dim_arrtable = self.tables.dim_arrtable
+        self._fixture_456(dim_arrtable)
+        eq_(
+            testing.db.scalar(
+                select([dim_arrtable.c.intarr]).
+                    where(dim_arrtable.c.intarr.contains(struct([4, 5])))
+            ),
+            [4, 5, 6]
+        )
+
+    def test_dim_array_contains_set_exec(self):
+        self._test_dim_array_contains_typed_exec(set)
+
+    def test_dim_array_contains_list_exec(self):
+        self._test_dim_array_contains_typed_exec(list)
+
+    def test_dim_array_contains_generator_exec(self):
+        self._test_dim_array_contains_typed_exec(lambda elem: (x for x in elem))
 
     def test_array_contained_by_exec(self):
+        arrtable = self.tables.arrtable
         with testing.db.connect() as conn:
             conn.execute(
                 arrtable.insert(),
@@ -2289,6 +2330,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
             )
 
     def test_array_overlap_exec(self):
+        arrtable = self.tables.arrtable
         with testing.db.connect() as conn:
             conn.execute(
                 arrtable.insert(),
@@ -2303,6 +2345,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
             )
 
     def test_array_any_exec(self):
+        arrtable = self.tables.arrtable
         with testing.db.connect() as conn:
             conn.execute(
                 arrtable.insert(),
@@ -2317,6 +2360,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
             )
 
     def test_array_all_exec(self):
+        arrtable = self.tables.arrtable
         with testing.db.connect() as conn:
             conn.execute(
                 arrtable.insert(),
@@ -2330,6 +2374,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
                 [4, 5, 6]
             )
 
+
     @testing.provide_metadata
     def test_tuple_flag(self):
         metadata = self.metadata
@@ -2361,6 +2406,7 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
         )
 
     def test_dimension(self):
+        arrtable = self.tables.arrtable
         testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4,5, 6]])
         eq_(
             testing.db.scalar(select([arrtable.c.dimarr])),