From: Federico Caselli Date: Sun, 19 Apr 2020 18:09:39 +0000 (+0200) Subject: Support `ARRAY` of `Enum`, `JSON` or `JSONB` X-Git-Tag: rel_1_4_0b1~371^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aaec1bdedfc73ead3aef3a3e4d835a8df339e2dd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support `ARRAY` of `Enum`, `JSON` or `JSONB` Added support for columns or type :class:`.ARRAY` of :class:`.Enum`, :class:`.JSON` or :class:`_postgresql.JSONB` in PostgreSQL. Previously a workaround was required in these use cases. Raise an explicit :class:`.exc.CompileError` when adding a table with a column of type :class:`.ARRAY` of :class:`.Enum` configured with :paramref:`.Enum.native_enum` set to ``False`` when :paramref:`.Enum.create_constraint` is not set to ``False`` Fixes: #5265 Fixes: #5266 Change-Id: I83a2d20a599232b7066d0839f3e55ff8b78cd8fc --- diff --git a/doc/build/changelog/unreleased_13/5265.rst b/doc/build/changelog/unreleased_13/5265.rst new file mode 100644 index 0000000000..f2b32f5fb9 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5265.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 5265 + + Added support for columns or type :class:`.ARRAY` of :class:`.Enum`, + :class:`.JSON` or :class:`_postgresql.JSONB` in PostgreSQL. + Previously a workaround was required in these use cases. + diff --git a/doc/build/changelog/unreleased_13/5266.rst b/doc/build/changelog/unreleased_13/5266.rst new file mode 100644 index 0000000000..4aec9f44bb --- /dev/null +++ b/doc/build/changelog/unreleased_13/5266.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 5266 + + Raise an explicit :class:`.exc.CompileError` when adding a table with a + column of type :class:`.ARRAY` of :class:`.Enum` configured with + :paramref:`.Enum.native_enum` set to ``False`` when + :paramref:`.Enum.create_constraint` is not set to ``False`` diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index a3537ba601..84fbd2e501 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -5,19 +5,14 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .base import colspecs -from .base import ischema_names +import re + from ... import types as sqltypes +from ... import util from ...sql import expression from ...sql import operators -try: - from uuid import UUID as _python_UUID # noqa -except ImportError: - _python_UUID = None - - def Any(other, arrexpr, operator=operators.eq): """A synonym for the :meth:`.ARRAY.Comparator.any` method. @@ -318,6 +313,25 @@ class ARRAY(sqltypes.ARRAY): for x in arr ) + @util.memoized_property + def _require_cast(self): + return self._against_native_enum or isinstance( + self.item_type, sqltypes.JSON + ) + + @util.memoized_property + def _against_native_enum(self): + return ( + isinstance(self.item_type, sqltypes.Enum) + and self.item_type.native_enum + ) + + def bind_expression(self, bindvalue): + if self._require_cast: + return expression.cast(bindvalue, self) + else: + return bindvalue + def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor( dialect @@ -349,8 +363,23 @@ class ARRAY(sqltypes.ARRAY): tuple if self.as_tuple else list, ) - return process - + if self._against_native_enum: + super_rp = process + + def handle_raw_string(value): + inner = re.match(r"^{(.*)}$", value).group(1) + return inner.split(",") if inner else [] + + def process(value): + if value is None: + return value + # isinstance(value, util.string_types) is required to handle + # the # case where a TypeDecorator for and Array of Enum is + # used like was required in sa < 1.3.17 + return super_rp( + handle_raw_string(value) + if isinstance(value, util.string_types) + else value + ) -colspecs[sqltypes.ARRAY] = ARRAY -ischema_names["_array"] = ARRAY + return process diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 20540ac020..962642e0ab 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -870,9 +870,11 @@ Using ENUM with ARRAY ^^^^^^^^^^^^^^^^^^^^^ The combination of ENUM and ARRAY is not directly supported by backend -DBAPIs at this time. In order to send and receive an ARRAY of ENUM, -use the following workaround type, which decorates the -:class:`_postgresql.ARRAY` datatype. +DBAPIs at this time. Prior to SQLAlchemy 1.3.17, a special workaround +was needed in order to allow this combination to work, described below. + +.. versionchanged:: 1.3.17 The combination of ENUM and ARRAY is now directly + handled by SQLAlchemy's implementation without any workarounds needed. .. sourcecode:: python @@ -917,10 +919,15 @@ a new version. Using JSON/JSONB with ARRAY ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Similar to using ENUM, for an ARRAY of JSON/JSONB we need to render the -appropriate CAST, however current psycopg2 drivers seem to handle the result -for ARRAY of JSON automatically, so the type is simpler:: +Similar to using ENUM, prior to SQLAlchemy 1.3.17, for an ARRAY of JSON/JSONB +we need to render the appropriate CAST. Current psycopg2 drivers accomodate +the result set correctly without any special steps. + +.. versionchanged:: 1.3.17 The combination of JSON/JSONB and ARRAY is now + directly handled by SQLAlchemy's implementation without any workarounds + needed. +.. sourcecode:: python class CastingArray(ARRAY): def bind_expression(self, bindvalue): @@ -940,6 +947,10 @@ from collections import defaultdict import datetime as dt import re +from . import array as _array +from . import hstore as _hstore +from . import json as _json +from . import ranges as _ranges from ... import exc from ... import schema from ... import sql @@ -1523,9 +1534,25 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): self.drop(bind=bind, checkfirst=checkfirst) -colspecs = {sqltypes.Interval: INTERVAL, sqltypes.Enum: ENUM} +colspecs = { + sqltypes.ARRAY: _array.ARRAY, + sqltypes.Interval: INTERVAL, + sqltypes.Enum: ENUM, + sqltypes.JSON.JSONPathType: _json.JSONPathType, + sqltypes.JSON: _json.JSON, +} ischema_names = { + "_array": _array.ARRAY, + "hstore": _hstore.HSTORE, + "json": _json.JSON, + "jsonb": _json.JSONB, + "int4range": _ranges.INT4RANGE, + "int8range": _ranges.INT8RANGE, + "numrange": _ranges.NUMRANGE, + "daterange": _ranges.DATERANGE, + "tsrange": _ranges.TSRANGE, + "tstzrange": _ranges.TSTZRANGE, "integer": INTEGER, "bigint": BIGINT, "smallint": SMALLINT, @@ -1917,6 +1944,22 @@ class PGDDLCompiler(compiler.DDLCompiler): colspec += " NOT NULL" return colspec + def visit_check_constraint(self, constraint): + if constraint._type_bound: + typ = list(constraint.columns)[0].type + if ( + isinstance(typ, sqltypes.ARRAY) + and isinstance(typ.item_type, sqltypes.Enum) + and not typ.item_type.native_enum + ): + raise exc.CompileError( + "PostgreSQL dialect cannot produce the CHECK constraint " + "for ARRAY of non-native ENUM; please specify " + "create_constraint=False on this Enum datatype." + ) + + return super(PGDDLCompiler, self).visit_check_constraint(constraint) + def visit_drop_table_comment(self, drop): return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( drop.element diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 6798051836..4e048feb05 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -8,7 +8,6 @@ import re from .array import ARRAY -from .base import ischema_names from ... import types as sqltypes from ... import util from ...sql import functions as sqlfunc @@ -268,9 +267,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return process -ischema_names["hstore"] = HSTORE - - class hstore(sqlfunc.GenericFunction): """Construct an hstore value within a SQL expression using the PostgreSQL ``hstore()`` function. diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 953ad9993a..ea7b04d4f8 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -6,8 +6,6 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from __future__ import absolute_import -from .base import colspecs -from .base import ischema_names from ... import types as sqltypes from ... import util from ...sql import operators @@ -96,9 +94,6 @@ class JSONPathType(sqltypes.JSON.JSONPathType): return process -colspecs[sqltypes.JSON.JSONPathType] = JSONPathType - - class JSON(sqltypes.JSON): """Represent the PostgreSQL JSON type. @@ -236,10 +231,6 @@ class JSON(sqltypes.JSON): comparator_factory = Comparator -colspecs[sqltypes.JSON] = JSON -ischema_names["json"] = JSON - - class JSONB(JSON): """Represent the PostgreSQL JSONB type. @@ -324,6 +315,3 @@ class JSONB(JSON): ) comparator_factory = Comparator - - -ischema_names["jsonb"] = JSONB diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 76d0aeaebc..d4f75b4948 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -4,7 +4,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .base import ischema_names from ... import types as sqltypes @@ -108,9 +107,6 @@ class INT4RANGE(RangeOperators, sqltypes.TypeEngine): __visit_name__ = "INT4RANGE" -ischema_names["int4range"] = INT4RANGE - - class INT8RANGE(RangeOperators, sqltypes.TypeEngine): """Represent the PostgreSQL INT8RANGE type. @@ -119,9 +115,6 @@ class INT8RANGE(RangeOperators, sqltypes.TypeEngine): __visit_name__ = "INT8RANGE" -ischema_names["int8range"] = INT8RANGE - - class NUMRANGE(RangeOperators, sqltypes.TypeEngine): """Represent the PostgreSQL NUMRANGE type. @@ -130,9 +123,6 @@ class NUMRANGE(RangeOperators, sqltypes.TypeEngine): __visit_name__ = "NUMRANGE" -ischema_names["numrange"] = NUMRANGE - - class DATERANGE(RangeOperators, sqltypes.TypeEngine): """Represent the PostgreSQL DATERANGE type. @@ -141,9 +131,6 @@ class DATERANGE(RangeOperators, sqltypes.TypeEngine): __visit_name__ = "DATERANGE" -ischema_names["daterange"] = DATERANGE - - class TSRANGE(RangeOperators, sqltypes.TypeEngine): """Represent the PostgreSQL TSRANGE type. @@ -152,15 +139,9 @@ class TSRANGE(RangeOperators, sqltypes.TypeEngine): __visit_name__ = "TSRANGE" -ischema_names["tsrange"] = TSRANGE - - class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): """Represent the PostgreSQL TSTZRANGE type. """ __visit_name__ = "TSTZRANGE" - - -ischema_names["tstzrange"] = TSTZRANGE diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 34ea4d7ed9..2adde8edc3 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1,6 +1,7 @@ # coding: utf-8 import datetime import decimal +import re import uuid import sqlalchemy as sa @@ -44,6 +45,7 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import NUMRANGE from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE +from sqlalchemy.exc import CompileError from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import operators @@ -1720,6 +1722,144 @@ class PGArrayRoundTripTest( ) +class _ArrayOfEnum(TypeDecorator): + # previous workaround for array of enum + impl = postgresql.ARRAY + + def bind_expression(self, bindvalue): + return sa.cast(bindvalue, self) + + def result_processor(self, dialect, coltype): + super_rp = super(_ArrayOfEnum, self).result_processor(dialect, coltype) + + def handle_raw_string(value): + inner = re.match(r"^{(.*)}$", value).group(1) + return inner.split(",") if inner else [] + + def process(value): + if value is None: + return None + return super_rp(handle_raw_string(value)) + + return process + + +class ArrayEnum(fixtures.TestBase): + __backend__ = True + __only_on__ = "postgresql" + __unsupported_on__ = ("postgresql+pg8000",) + + @testing.combinations( + sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls" + ) + @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") + @testing.provide_metadata + def test_raises_non_native_enums(self, array_cls, enum_cls): + Table( + "my_table", + self.metadata, + Column( + "my_col", + array_cls( + enum_cls( + "foo", "bar", "baz", name="my_enum", native_enum=False + ) + ), + ), + ) + + testing.assert_raises_message( + CompileError, + "PostgreSQL dialect cannot produce the CHECK constraint " + "for ARRAY of non-native ENUM; please specify " + "create_constraint=False on this Enum datatype.", + self.metadata.create_all, + testing.db, + ) + + @testing.combinations( + sqltypes.ARRAY, postgresql.ARRAY, _ArrayOfEnum, argnames="array_cls" + ) + @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") + @testing.provide_metadata + def test_array_of_enums(self, array_cls, enum_cls, connection): + tbl = Table( + "enum_table", + self.metadata, + Column("id", Integer, primary_key=True), + Column( + "enum_col", + array_cls(enum_cls("foo", "bar", "baz", name="an_enum")), + ), + ) + + if util.py3k: + from enum import Enum + + class MyEnum(Enum): + a = "aaa" + b = "bbb" + c = "ccc" + + tbl.append_column( + Column("pyenum_col", array_cls(enum_cls(MyEnum)),), + ) + + self.metadata.create_all(connection) + + connection.execute( + tbl.insert(), [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}] + ) + + sel = select([tbl.c.enum_col]).order_by(tbl.c.id) + eq_( + connection.execute(sel).fetchall(), [(["foo"],), (["foo", "bar"],)] + ) + + if util.py3k: + connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]}) + sel = select([tbl.c.pyenum_col]).order_by(tbl.c.id.desc()) + eq_(connection.scalar(sel), [MyEnum.a]) + + +class ArrayJSON(fixtures.TestBase): + __backend__ = True + __only_on__ = "postgresql" + __unsupported_on__ = ("postgresql+pg8000",) + + @testing.combinations( + sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls" + ) + @testing.combinations( + sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls" + ) + @testing.provide_metadata + def test_array_of_json(self, array_cls, json_cls, connection): + tbl = Table( + "json_table", + self.metadata, + Column("id", Integer, primary_key=True), + Column("json_col", array_cls(json_cls),), + ) + + self.metadata.create_all(connection) + + connection.execute( + tbl.insert(), + [ + {"json_col": ["foo"]}, + {"json_col": [{"foo": "bar"}, [1]]}, + {"json_col": [None]}, + ], + ) + + sel = select([tbl.c.json_col]).order_by(tbl.c.id) + eq_( + connection.execute(sel).fetchall(), + [(["foo"],), ([{"foo": "bar"}, [1]],), ([None],)], + ) + + class HashableFlagORMTest(fixtures.TestBase): """test the various 'collection' types that they flip the 'hashable' flag appropriately. [ticket:3499]"""