From 1b463058e3282c73d0fb361f78e96ecaa23ce9f4 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 11 Apr 2017 10:26:38 -0400 Subject: [PATCH] Set up base ARRAY to be compatible with postgresql.ARRAY. For some reason, when ARRAY was added to the base it was never linked to postgresql.ARRAY. Link the two types and also make base ARRAY the schema event target so that it supports the same features as postgresql.ARRAY. Change-Id: I82fa6c9d2b8c5028dba3a009715f7bc296b2bc0b Fixes: #3964 --- doc/build/changelog/changelog_12.rst | 7 + lib/sqlalchemy/dialects/postgresql/array.py | 17 +- lib/sqlalchemy/sql/sqltypes.py | 19 +- test/dialect/postgresql/test_types.py | 231 +++++++++++--------- 4 files changed, 160 insertions(+), 114 deletions(-) diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst index 815f587f81..7c04210192 100644 --- a/doc/build/changelog/changelog_12.rst +++ b/doc/build/changelog/changelog_12.rst @@ -13,6 +13,13 @@ .. changelog:: :version: 1.2.0b1 + .. change:: 3964 + :tags: bug, postgresql + :tickets: 3964 + + Fixed bug where the base :class:`.sqltypes.ARRAY` datatype would not + invoke the bind/result processors of :class:`.postgresql.ARRAY`. + .. change:: 3963 :tags: bug, orm :tickets: 3963 diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 98cab95626..009c83c0d4 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -5,7 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .base import ischema_names +from .base import ischema_names, colspecs from ...sql import expression, operators from ...sql.base import SchemaEventTarget from ... import types as sqltypes @@ -114,7 +114,7 @@ CONTAINED_BY = operators.custom_op("<@", precedence=5) OVERLAP = operators.custom_op("&&", precedence=5) -class ARRAY(SchemaEventTarget, sqltypes.ARRAY): +class ARRAY(sqltypes.ARRAY): """PostgreSQL ARRAY type. @@ -248,18 +248,6 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY): def compare_values(self, x, y): return x == y - def _set_parent(self, column): - """Support SchemaEventTarget""" - - if isinstance(self.item_type, SchemaEventTarget): - self.item_type._set_parent(column) - - def _set_parent_with_dispatch(self, parent): - """Support SchemaEventTarget""" - - if isinstance(self.item_type, SchemaEventTarget): - self.item_type._set_parent_with_dispatch(parent) - def _proc_array(self, arr, itemproc, dim, collection): if dim is None: arr = list(arr) @@ -311,4 +299,5 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY): tuple if self.as_tuple else list) return process +colspecs[sqltypes.ARRAY] = ARRAY ischema_names['_array'] = ARRAY diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 8a114ece60..b8117e3ca1 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2061,7 +2061,7 @@ class JSON(Indexable, TypeEngine): return process -class ARRAY(Indexable, Concatenable, TypeEngine): +class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """Represent a SQL Array type. .. note:: This type serves as the basis for all ARRAY operations. @@ -2199,6 +2199,11 @@ class ARRAY(Indexable, Concatenable, TypeEngine): return operators.getitem, index, return_type + def contains(self, *arg, **kw): + raise NotImplementedError( + "ARRAY.contains() not implemented for the base " + "ARRAY type; please use the dialect-specific ARRAY type") + @util.dependencies("sqlalchemy.sql.elements") def any(self, elements, other, operator=None): """Return ``other operator ANY (array)`` clause. @@ -2325,6 +2330,18 @@ class ARRAY(Indexable, Concatenable, TypeEngine): def compare_values(self, x, y): return x == y + def _set_parent(self, column): + """Support SchemaEventTarget""" + + if isinstance(self.item_type, SchemaEventTarget): + self.item_type._set_parent(column) + + def _set_parent_with_dispatch(self, parent): + """Support SchemaEventTarget""" + + if isinstance(self.item_type, SchemaEventTarget): + self.item_type._set_parent_with_dispatch(parent) + class REAL(Float): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 807eeb60c4..d2e19a04a2 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4,6 +4,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises, \ AssertsCompiledSQL, ComparesTables from sqlalchemy.testing import engines, fixtures from sqlalchemy import testing +from sqlalchemy.sql import sqltypes import datetime from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \ func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \ @@ -85,7 +86,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') @testing.provide_metadata - def test_arrays(self): + def test_arrays_pg(self): metadata = self.metadata t1 = Table('t', metadata, Column('x', postgresql.ARRAY(Float)), @@ -101,6 +102,25 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): ([5], [5], [6], [decimal.Decimal("6.4")]) ) + @testing.fails_on('postgresql+zxjdbc', + 'zxjdbc has no support for PG arrays') + @testing.provide_metadata + def test_arrays_base(self): + metadata = self.metadata + t1 = Table('t', metadata, + Column('x', sqltypes.ARRAY(Float)), + Column('y', sqltypes.ARRAY(REAL)), + Column('z', sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)), + Column('q', sqltypes.ARRAY(Numeric)) + ) + metadata.create_all() + t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")]) + row = t1.select().execute().first() + eq_( + row, + ([5], [5], [6], [decimal.Decimal("6.4")]) + ) + class EnumTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @@ -987,17 +1007,19 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): is_(expr.type.item_type.__class__, Integer) -class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): +class ArrayRoundTripTest(object): __only_on__ = 'postgresql' __backend__ = True __unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc' + ARRAY = postgresql.ARRAY + @classmethod def define_tables(cls, metadata): class ProcValue(TypeDecorator): - impl = postgresql.ARRAY(Integer, dimensions=2) + impl = cls.ARRAY(Integer, dimensions=2) def process_bind_param(self, value, dialect): if value is None: @@ -1017,15 +1039,15 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): Table('arrtable', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgresql.ARRAY(Integer)), - Column('strarr', postgresql.ARRAY(Unicode())), + Column('intarr', cls.ARRAY(Integer)), + Column('strarr', cls.ARRAY(Unicode())), Column('dimarr', ProcValue) ) Table('dim_arrtable', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgresql.ARRAY(Integer, dimensions=1)), - Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)), + Column('intarr', cls.ARRAY(Integer, dimensions=1)), + Column('strarr', cls.ARRAY(Unicode(), dimensions=1)), Column('dimarr', ProcValue) ) @@ -1038,8 +1060,8 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): def test_reflect_array_column(self): metadata2 = MetaData(testing.db) tbl = Table('arrtable', metadata2, autoload=True) - assert isinstance(tbl.c.intarr.type, postgresql.ARRAY) - assert isinstance(tbl.c.strarr.type, postgresql.ARRAY) + assert isinstance(tbl.c.intarr.type, self.ARRAY) + assert isinstance(tbl.c.strarr.type, self.ARRAY) assert isinstance(tbl.c.intarr.type.item_type, Integer) assert isinstance(tbl.c.strarr.type.item_type, String) @@ -1107,19 +1129,19 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): func.array_cat( array([1, 2, 3]), array([4, 5, 6]), - type_=postgresql.ARRAY(Integer) + type_=self.ARRAY(Integer) )[2:5] ]) eq_( testing.db.execute(stmt).scalar(), [2, 3, 4, 5] ) - def test_any_all_exprs(self): + def test_any_all_exprs_array(self): stmt = select([ 3 == any_(func.array_cat( array([1, 2, 3]), array([4, 5, 6]), - type_=postgresql.ARRAY(Integer) + type_=self.ARRAY(Integer) )) ]) eq_( @@ -1225,17 +1247,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): 7 ) - def test_undim_array_empty(self): - arrtable = self.tables.arrtable - self._fixture_456(arrtable) - eq_( - testing.db.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.contains([])) - ), - [4, 5, 6] - ) - def test_array_getitem_slice_exec(self): arrtable = self.tables.arrtable testing.db.execute( @@ -1255,49 +1266,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): [7, 8] ) - def _test_undim_array_contains_typed_exec(self, struct): - arrtable = self.tables.arrtable - self._fixture_456(arrtable) - eq_( - testing.db.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.contains(struct([4, 5]))) - ), - [4, 5, 6] - ) - - def test_undim_array_contains_set_exec(self): - self._test_undim_array_contains_typed_exec(set) - - def test_undim_array_contains_list_exec(self): - self._test_undim_array_contains_typed_exec(list) - - def test_undim_array_contains_generator_exec(self): - self._test_undim_array_contains_typed_exec( - lambda elem: (x for x in elem)) - - def _test_dim_array_contains_typed_exec(self, struct): - dim_arrtable = self.tables.dim_arrtable - self._fixture_456(dim_arrtable) - eq_( - testing.db.scalar( - select([dim_arrtable.c.intarr]). - where(dim_arrtable.c.intarr.contains(struct([4, 5]))) - ), - [4, 5, 6] - ) - - def test_dim_array_contains_set_exec(self): - self._test_dim_array_contains_typed_exec(set) - - def test_dim_array_contains_list_exec(self): - self._test_dim_array_contains_typed_exec(list) - - def test_dim_array_contains_generator_exec(self): - self._test_dim_array_contains_typed_exec( - lambda elem: ( - x for x in elem)) - def test_multi_dim_roundtrip(self): arrtable = self.tables.arrtable testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]]) @@ -1306,35 +1274,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): [[-1, 0, 1], [2, 3, 4]] ) - def test_array_contained_by_exec(self): - arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[6, 5, 4] - ) - eq_( - conn.scalar( - select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) - ), - True - ) - - def test_array_overlap_exec(self): - arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[4, 5, 6] - ) - eq_( - conn.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.overlap([7, 6])) - ), - [4, 5, 6] - ) - def test_array_any_exec(self): arrtable = self.tables.arrtable with testing.db.connect() as conn: @@ -1372,10 +1311,10 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): t1 = Table( 't1', metadata, Column('id', Integer, primary_key=True), - Column('data', postgresql.ARRAY(String(5), as_tuple=True)), + Column('data', self.ARRAY(String(5), as_tuple=True)), Column( 'data2', - postgresql.ARRAY( + self.ARRAY( Numeric(asdecimal=False), as_tuple=True) ) ) @@ -1416,13 +1355,13 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): 't', m, Column( 'data_1', - postgresql.ARRAY( + self.ARRAY( postgresql.ENUM('a', 'b', 'c', name='my_enum_1') ) ), Column( 'data_2', - postgresql.ARRAY( + self.ARRAY( types.Enum('a', 'b', 'c', name='my_enum_2') ) ) @@ -1437,6 +1376,100 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): eq_(inspect(testing.db).get_enums(), []) +class CoreArrayRoundTripTest(ArrayRoundTripTest, + fixtures.TablesTest, AssertsExecutionResults): + + ARRAY = sqltypes.ARRAY + + +class PGArrayRoundTripTest(ArrayRoundTripTest, + fixtures.TablesTest, AssertsExecutionResults): + ARRAY = postgresql.ARRAY + + def _test_undim_array_contains_typed_exec(self, struct): + arrtable = self.tables.arrtable + self._fixture_456(arrtable) + eq_( + testing.db.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.contains(struct([4, 5]))) + ), + [4, 5, 6] + ) + + def test_undim_array_contains_set_exec(self): + self._test_undim_array_contains_typed_exec(set) + + def test_undim_array_contains_list_exec(self): + self._test_undim_array_contains_typed_exec(list) + + def test_undim_array_contains_generator_exec(self): + self._test_undim_array_contains_typed_exec( + lambda elem: (x for x in elem)) + + def _test_dim_array_contains_typed_exec(self, struct): + dim_arrtable = self.tables.dim_arrtable + self._fixture_456(dim_arrtable) + eq_( + testing.db.scalar( + select([dim_arrtable.c.intarr]). + where(dim_arrtable.c.intarr.contains(struct([4, 5]))) + ), + [4, 5, 6] + ) + + def test_dim_array_contains_set_exec(self): + self._test_dim_array_contains_typed_exec(set) + + def test_dim_array_contains_list_exec(self): + self._test_dim_array_contains_typed_exec(list) + + def test_dim_array_contains_generator_exec(self): + self._test_dim_array_contains_typed_exec( + lambda elem: ( + x for x in elem)) + + def test_array_contained_by_exec(self): + arrtable = self.tables.arrtable + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[6, 5, 4] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) + ), + True + ) + + def test_undim_array_empty(self): + arrtable = self.tables.arrtable + self._fixture_456(arrtable) + eq_( + testing.db.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.contains([])) + ), + [4, 5, 6] + ) + + def test_array_overlap_exec(self): + arrtable = self.tables.arrtable + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[4, 5, 6] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.overlap([7, 6])) + ), + [4, 5, 6] + ) + + class HashableFlagORMTest(fixtures.TestBase): """test the various 'collection' types that they flip the 'hashable' flag appropriately. [ticket:3499]""" -- 2.47.2