]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Set up base ARRAY to be compatible with postgresql.ARRAY.
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Apr 2017 14:26:38 +0000 (10:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Apr 2017 14:49:30 +0000 (10:49 -0400)
For some reason, when ARRAY was added to the base it was never linked
to postgresql.ARRAY.   Link the two types and also make base
ARRAY the schema event target so that it supports the same
features as postgresql.ARRAY.

Change-Id: I82fa6c9d2b8c5028dba3a009715f7bc296b2bc0b
Fixes: #3964
doc/build/changelog/changelog_12.rst
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/sql/sqltypes.py
test/dialect/postgresql/test_types.py

index 815f587f811d382db23ca2f20a5d0df69924f408..7c04210192e8e1e36e446479afad708c62a78a5f 100644 (file)
 .. changelog::
     :version: 1.2.0b1
 
+    .. change:: 3964
+        :tags: bug, postgresql
+        :tickets: 3964
+
+        Fixed bug where the base :class:`.sqltypes.ARRAY` datatype would not
+        invoke the bind/result processors of :class:`.postgresql.ARRAY`.
+
     .. change:: 3963
         :tags: bug, orm
         :tickets: 3963
index 98cab95626adf53f022042757984794d7c4a32ad..009c83c0d49e117ca5c919ec9ebf1f621323bb6b 100644 (file)
@@ -5,7 +5,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from .base import ischema_names
+from .base import ischema_names, colspecs
 from ...sql import expression, operators
 from ...sql.base import SchemaEventTarget
 from ... import types as sqltypes
@@ -114,7 +114,7 @@ CONTAINED_BY = operators.custom_op("<@", precedence=5)
 OVERLAP = operators.custom_op("&&", precedence=5)
 
 
-class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
+class ARRAY(sqltypes.ARRAY):
 
     """PostgreSQL ARRAY type.
 
@@ -248,18 +248,6 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
     def compare_values(self, x, y):
         return x == y
 
-    def _set_parent(self, column):
-        """Support SchemaEventTarget"""
-
-        if isinstance(self.item_type, SchemaEventTarget):
-            self.item_type._set_parent(column)
-
-    def _set_parent_with_dispatch(self, parent):
-        """Support SchemaEventTarget"""
-
-        if isinstance(self.item_type, SchemaEventTarget):
-            self.item_type._set_parent_with_dispatch(parent)
-
     def _proc_array(self, arr, itemproc, dim, collection):
         if dim is None:
             arr = list(arr)
@@ -311,4 +299,5 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
                     tuple if self.as_tuple else list)
         return process
 
+colspecs[sqltypes.ARRAY] = ARRAY
 ischema_names['_array'] = ARRAY
index 8a114ece60ec7a6fc8bb2dc662caba8441644640..b8117e3ca1ef93fdc85aa6ba730f25e44eae05d7 100644 (file)
@@ -2061,7 +2061,7 @@ class JSON(Indexable, TypeEngine):
         return process
 
 
-class ARRAY(Indexable, Concatenable, TypeEngine):
+class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
     """Represent a SQL Array type.
 
     .. note::  This type serves as the basis for all ARRAY operations.
@@ -2199,6 +2199,11 @@ class ARRAY(Indexable, Concatenable, TypeEngine):
 
             return operators.getitem, index, return_type
 
+        def contains(self, *arg, **kw):
+            raise NotImplementedError(
+                "ARRAY.contains() not implemented for the base "
+                "ARRAY type; please use the dialect-specific ARRAY type")
+
         @util.dependencies("sqlalchemy.sql.elements")
         def any(self, elements, other, operator=None):
             """Return ``other operator ANY (array)`` clause.
@@ -2325,6 +2330,18 @@ class ARRAY(Indexable, Concatenable, TypeEngine):
     def compare_values(self, x, y):
         return x == y
 
+    def _set_parent(self, column):
+        """Support SchemaEventTarget"""
+
+        if isinstance(self.item_type, SchemaEventTarget):
+            self.item_type._set_parent(column)
+
+    def _set_parent_with_dispatch(self, parent):
+        """Support SchemaEventTarget"""
+
+        if isinstance(self.item_type, SchemaEventTarget):
+            self.item_type._set_parent_with_dispatch(parent)
+
 
 class REAL(Float):
 
index 807eeb60c42ac8eff4b6074ea42866ce83fb0614..d2e19a04a2451703b03ea328bdf1d3d270ddf124 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises, \
     AssertsCompiledSQL, ComparesTables
 from sqlalchemy.testing import engines, fixtures
 from sqlalchemy import testing
+from sqlalchemy.sql import sqltypes
 import datetime
 from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \
     func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \
@@ -85,7 +86,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
     @testing.fails_on('postgresql+zxjdbc',
                       'zxjdbc has no support for PG arrays')
     @testing.provide_metadata
-    def test_arrays(self):
+    def test_arrays_pg(self):
         metadata = self.metadata
         t1 = Table('t', metadata,
                    Column('x', postgresql.ARRAY(Float)),
@@ -101,6 +102,25 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
             ([5], [5], [6], [decimal.Decimal("6.4")])
         )
 
+    @testing.fails_on('postgresql+zxjdbc',
+                      'zxjdbc has no support for PG arrays')
+    @testing.provide_metadata
+    def test_arrays_base(self):
+        metadata = self.metadata
+        t1 = Table('t', metadata,
+                   Column('x', sqltypes.ARRAY(Float)),
+                   Column('y', sqltypes.ARRAY(REAL)),
+                   Column('z', sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)),
+                   Column('q', sqltypes.ARRAY(Numeric))
+                   )
+        metadata.create_all()
+        t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")])
+        row = t1.select().execute().first()
+        eq_(
+            row,
+            ([5], [5], [6], [decimal.Decimal("6.4")])
+        )
+
 
 class EnumTest(fixtures.TestBase, AssertsExecutionResults):
     __backend__ = True
@@ -987,17 +1007,19 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
         is_(expr.type.item_type.__class__, Integer)
 
 
-class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
+class ArrayRoundTripTest(object):
 
     __only_on__ = 'postgresql'
     __backend__ = True
     __unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc'
 
+    ARRAY = postgresql.ARRAY
+
     @classmethod
     def define_tables(cls, metadata):
 
         class ProcValue(TypeDecorator):
-            impl = postgresql.ARRAY(Integer, dimensions=2)
+            impl = cls.ARRAY(Integer, dimensions=2)
 
             def process_bind_param(self, value, dialect):
                 if value is None:
@@ -1017,15 +1039,15 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
 
         Table('arrtable', metadata,
               Column('id', Integer, primary_key=True),
-              Column('intarr', postgresql.ARRAY(Integer)),
-              Column('strarr', postgresql.ARRAY(Unicode())),
+              Column('intarr', cls.ARRAY(Integer)),
+              Column('strarr', cls.ARRAY(Unicode())),
               Column('dimarr', ProcValue)
               )
 
         Table('dim_arrtable', metadata,
               Column('id', Integer, primary_key=True),
-              Column('intarr', postgresql.ARRAY(Integer, dimensions=1)),
-              Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)),
+              Column('intarr', cls.ARRAY(Integer, dimensions=1)),
+              Column('strarr', cls.ARRAY(Unicode(), dimensions=1)),
               Column('dimarr', ProcValue)
               )
 
@@ -1038,8 +1060,8 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
     def test_reflect_array_column(self):
         metadata2 = MetaData(testing.db)
         tbl = Table('arrtable', metadata2, autoload=True)
-        assert isinstance(tbl.c.intarr.type, postgresql.ARRAY)
-        assert isinstance(tbl.c.strarr.type, postgresql.ARRAY)
+        assert isinstance(tbl.c.intarr.type, self.ARRAY)
+        assert isinstance(tbl.c.strarr.type, self.ARRAY)
         assert isinstance(tbl.c.intarr.type.item_type, Integer)
         assert isinstance(tbl.c.strarr.type.item_type, String)
 
@@ -1107,19 +1129,19 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
             func.array_cat(
                 array([1, 2, 3]),
                 array([4, 5, 6]),
-                type_=postgresql.ARRAY(Integer)
+                type_=self.ARRAY(Integer)
             )[2:5]
         ])
         eq_(
             testing.db.execute(stmt).scalar(), [2, 3, 4, 5]
         )
 
-    def test_any_all_exprs(self):
+    def test_any_all_exprs_array(self):
         stmt = select([
             3 == any_(func.array_cat(
                 array([1, 2, 3]),
                 array([4, 5, 6]),
-                type_=postgresql.ARRAY(Integer)
+                type_=self.ARRAY(Integer)
             ))
         ])
         eq_(
@@ -1225,17 +1247,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
             7
         )
 
-    def test_undim_array_empty(self):
-        arrtable = self.tables.arrtable
-        self._fixture_456(arrtable)
-        eq_(
-            testing.db.scalar(
-                select([arrtable.c.intarr]).
-                where(arrtable.c.intarr.contains([]))
-            ),
-            [4, 5, 6]
-        )
-
     def test_array_getitem_slice_exec(self):
         arrtable = self.tables.arrtable
         testing.db.execute(
@@ -1255,49 +1266,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
             [7, 8]
         )
 
-    def _test_undim_array_contains_typed_exec(self, struct):
-        arrtable = self.tables.arrtable
-        self._fixture_456(arrtable)
-        eq_(
-            testing.db.scalar(
-                select([arrtable.c.intarr]).
-                where(arrtable.c.intarr.contains(struct([4, 5])))
-            ),
-            [4, 5, 6]
-        )
-
-    def test_undim_array_contains_set_exec(self):
-        self._test_undim_array_contains_typed_exec(set)
-
-    def test_undim_array_contains_list_exec(self):
-        self._test_undim_array_contains_typed_exec(list)
-
-    def test_undim_array_contains_generator_exec(self):
-        self._test_undim_array_contains_typed_exec(
-            lambda elem: (x for x in elem))
-
-    def _test_dim_array_contains_typed_exec(self, struct):
-        dim_arrtable = self.tables.dim_arrtable
-        self._fixture_456(dim_arrtable)
-        eq_(
-            testing.db.scalar(
-                select([dim_arrtable.c.intarr]).
-                where(dim_arrtable.c.intarr.contains(struct([4, 5])))
-            ),
-            [4, 5, 6]
-        )
-
-    def test_dim_array_contains_set_exec(self):
-        self._test_dim_array_contains_typed_exec(set)
-
-    def test_dim_array_contains_list_exec(self):
-        self._test_dim_array_contains_typed_exec(list)
-
-    def test_dim_array_contains_generator_exec(self):
-        self._test_dim_array_contains_typed_exec(
-            lambda elem: (
-                x for x in elem))
-
     def test_multi_dim_roundtrip(self):
         arrtable = self.tables.arrtable
         testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]])
@@ -1306,35 +1274,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
             [[-1, 0, 1], [2, 3, 4]]
         )
 
-    def test_array_contained_by_exec(self):
-        arrtable = self.tables.arrtable
-        with testing.db.connect() as conn:
-            conn.execute(
-                arrtable.insert(),
-                intarr=[6, 5, 4]
-            )
-            eq_(
-                conn.scalar(
-                    select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
-                ),
-                True
-            )
-
-    def test_array_overlap_exec(self):
-        arrtable = self.tables.arrtable
-        with testing.db.connect() as conn:
-            conn.execute(
-                arrtable.insert(),
-                intarr=[4, 5, 6]
-            )
-            eq_(
-                conn.scalar(
-                    select([arrtable.c.intarr]).
-                    where(arrtable.c.intarr.overlap([7, 6]))
-                ),
-                [4, 5, 6]
-            )
-
     def test_array_any_exec(self):
         arrtable = self.tables.arrtable
         with testing.db.connect() as conn:
@@ -1372,10 +1311,10 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
         t1 = Table(
             't1', metadata,
             Column('id', Integer, primary_key=True),
-            Column('data', postgresql.ARRAY(String(5), as_tuple=True)),
+            Column('data', self.ARRAY(String(5), as_tuple=True)),
             Column(
                 'data2',
-                postgresql.ARRAY(
+                self.ARRAY(
                     Numeric(asdecimal=False), as_tuple=True)
             )
         )
@@ -1416,13 +1355,13 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
             't', m,
             Column(
                 'data_1',
-                postgresql.ARRAY(
+                self.ARRAY(
                     postgresql.ENUM('a', 'b', 'c', name='my_enum_1')
                 )
             ),
             Column(
                 'data_2',
-                postgresql.ARRAY(
+                self.ARRAY(
                     types.Enum('a', 'b', 'c', name='my_enum_2')
                 )
             )
@@ -1437,6 +1376,100 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
         eq_(inspect(testing.db).get_enums(), [])
 
 
+class CoreArrayRoundTripTest(ArrayRoundTripTest,
+                             fixtures.TablesTest, AssertsExecutionResults):
+
+    ARRAY = sqltypes.ARRAY
+
+
+class PGArrayRoundTripTest(ArrayRoundTripTest,
+                           fixtures.TablesTest, AssertsExecutionResults):
+    ARRAY = postgresql.ARRAY
+
+    def _test_undim_array_contains_typed_exec(self, struct):
+        arrtable = self.tables.arrtable
+        self._fixture_456(arrtable)
+        eq_(
+            testing.db.scalar(
+                select([arrtable.c.intarr]).
+                where(arrtable.c.intarr.contains(struct([4, 5])))
+            ),
+            [4, 5, 6]
+        )
+
+    def test_undim_array_contains_set_exec(self):
+        self._test_undim_array_contains_typed_exec(set)
+
+    def test_undim_array_contains_list_exec(self):
+        self._test_undim_array_contains_typed_exec(list)
+
+    def test_undim_array_contains_generator_exec(self):
+        self._test_undim_array_contains_typed_exec(
+            lambda elem: (x for x in elem))
+
+    def _test_dim_array_contains_typed_exec(self, struct):
+        dim_arrtable = self.tables.dim_arrtable
+        self._fixture_456(dim_arrtable)
+        eq_(
+            testing.db.scalar(
+                select([dim_arrtable.c.intarr]).
+                where(dim_arrtable.c.intarr.contains(struct([4, 5])))
+            ),
+            [4, 5, 6]
+        )
+
+    def test_dim_array_contains_set_exec(self):
+        self._test_dim_array_contains_typed_exec(set)
+
+    def test_dim_array_contains_list_exec(self):
+        self._test_dim_array_contains_typed_exec(list)
+
+    def test_dim_array_contains_generator_exec(self):
+        self._test_dim_array_contains_typed_exec(
+            lambda elem: (
+                x for x in elem))
+
+    def test_array_contained_by_exec(self):
+        arrtable = self.tables.arrtable
+        with testing.db.connect() as conn:
+            conn.execute(
+                arrtable.insert(),
+                intarr=[6, 5, 4]
+            )
+            eq_(
+                conn.scalar(
+                    select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
+                ),
+                True
+            )
+
+    def test_undim_array_empty(self):
+        arrtable = self.tables.arrtable
+        self._fixture_456(arrtable)
+        eq_(
+            testing.db.scalar(
+                select([arrtable.c.intarr]).
+                where(arrtable.c.intarr.contains([]))
+            ),
+            [4, 5, 6]
+        )
+
+    def test_array_overlap_exec(self):
+        arrtable = self.tables.arrtable
+        with testing.db.connect() as conn:
+            conn.execute(
+                arrtable.insert(),
+                intarr=[4, 5, 6]
+            )
+            eq_(
+                conn.scalar(
+                    select([arrtable.c.intarr]).
+                    where(arrtable.c.intarr.overlap([7, 6]))
+                ),
+                [4, 5, 6]
+            )
+
+
 class HashableFlagORMTest(fixtures.TestBase):
     """test the various 'collection' types that they flip the 'hashable' flag
     appropriately.  [ticket:3499]"""