]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support `ARRAY` of `Enum`, `JSON` or `JSONB`
authorFederico Caselli <cfederico87@gmail.com>
Sun, 19 Apr 2020 18:09:39 +0000 (20:09 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Apr 2020 15:54:20 +0000 (11:54 -0400)
Added support for columns or type :class:`.ARRAY` of :class:`.Enum`,
:class:`.JSON` or :class:`_postgresql.JSONB` in PostgreSQL.
Previously a workaround was required in these use cases.
Raise an explicit :class:`.exc.CompileError` when adding a table with a
column of type :class:`.ARRAY` of :class:`.Enum` configured with
:paramref:`.Enum.native_enum` set to ``False`` when
:paramref:`.Enum.create_constraint` is not set to ``False``

Fixes: #5265
Fixes: #5266
Change-Id: I83a2d20a599232b7066d0839f3e55ff8b78cd8fc

doc/build/changelog/unreleased_13/5265.rst [new file with mode: 0644]
doc/build/changelog/unreleased_13/5266.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/hstore.py
lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/5265.rst b/doc/build/changelog/unreleased_13/5265.rst
new file mode 100644 (file)
index 0000000..f2b32f5
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 5265
+
+    Added support for columns or type :class:`.ARRAY` of :class:`.Enum`,
+    :class:`.JSON` or :class:`_postgresql.JSONB` in PostgreSQL.
+    Previously a workaround was required in these use cases.
+
diff --git a/doc/build/changelog/unreleased_13/5266.rst b/doc/build/changelog/unreleased_13/5266.rst
new file mode 100644 (file)
index 0000000..4aec9f4
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 5266
+
+    Raise an explicit :class:`.exc.CompileError` when adding a table with a
+    column of type :class:`.ARRAY` of :class:`.Enum` configured with
+    :paramref:`.Enum.native_enum` set to ``False`` when
+    :paramref:`.Enum.create_constraint` is not set to ``False``
index a3537ba601e44ee5a3064fe5b00f1a6044487adf..84fbd2e5019d149c1761e6e5327b312a292d6270 100644 (file)
@@ -5,19 +5,14 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from .base import colspecs
-from .base import ischema_names
+import re
+
 from ... import types as sqltypes
+from ... import util
 from ...sql import expression
 from ...sql import operators
 
 
-try:
-    from uuid import UUID as _python_UUID  # noqa
-except ImportError:
-    _python_UUID = None
-
-
 def Any(other, arrexpr, operator=operators.eq):
     """A synonym for the :meth:`.ARRAY.Comparator.any` method.
 
@@ -318,6 +313,25 @@ class ARRAY(sqltypes.ARRAY):
                 for x in arr
             )
 
+    @util.memoized_property
+    def _require_cast(self):
+        return self._against_native_enum or isinstance(
+            self.item_type, sqltypes.JSON
+        )
+
+    @util.memoized_property
+    def _against_native_enum(self):
+        return (
+            isinstance(self.item_type, sqltypes.Enum)
+            and self.item_type.native_enum
+        )
+
+    def bind_expression(self, bindvalue):
+        if self._require_cast:
+            return expression.cast(bindvalue, self)
+        else:
+            return bindvalue
+
     def bind_processor(self, dialect):
         item_proc = self.item_type.dialect_impl(dialect).bind_processor(
             dialect
@@ -349,8 +363,23 @@ class ARRAY(sqltypes.ARRAY):
                     tuple if self.as_tuple else list,
                 )
 
-        return process
-
+        if self._against_native_enum:
+            super_rp = process
+
+            def handle_raw_string(value):
+                inner = re.match(r"^{(.*)}$", value).group(1)
+                return inner.split(",") if inner else []
+
+            def process(value):
+                if value is None:
+                    return value
+                # isinstance(value, util.string_types) is required to handle
+                # the # case where a TypeDecorator for and Array of Enum is
+                # used like was required in sa < 1.3.17
+                return super_rp(
+                    handle_raw_string(value)
+                    if isinstance(value, util.string_types)
+                    else value
+                )
 
-colspecs[sqltypes.ARRAY] = ARRAY
-ischema_names["_array"] = ARRAY
+        return process
index 20540ac0200c78c71b308f103d735443d4351cc1..962642e0ab49f4e22b46a890f058cbb2d5829c3e 100644 (file)
@@ -870,9 +870,11 @@ Using ENUM with ARRAY
 ^^^^^^^^^^^^^^^^^^^^^
 
 The combination of ENUM and ARRAY is not directly supported by backend
-DBAPIs at this time.   In order to send and receive an ARRAY of ENUM,
-use the following workaround type, which decorates the
-:class:`_postgresql.ARRAY` datatype.
+DBAPIs at this time.   Prior to SQLAlchemy 1.3.17, a special workaround
+was needed in order to allow this combination to work, described below.
+
+.. versionchanged:: 1.3.17 The combination of ENUM and ARRAY is now directly
+   handled by SQLAlchemy's implementation without any workarounds needed.
 
 .. sourcecode:: python
 
@@ -917,10 +919,15 @@ a new version.
 Using JSON/JSONB with ARRAY
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Similar to using ENUM, for an ARRAY of JSON/JSONB we need to render the
-appropriate CAST, however current psycopg2 drivers seem to handle the result
-for ARRAY of JSON automatically, so the type is simpler::
+Similar to using ENUM, prior to SQLAlchemy 1.3.17, for an ARRAY of JSON/JSONB
+we need to render the appropriate CAST.   Current psycopg2 drivers accomodate
+the result set correctly without any special steps.
+
+.. versionchanged:: 1.3.17 The combination of JSON/JSONB and ARRAY is now
+   directly handled by SQLAlchemy's implementation without any workarounds
+   needed.
 
+.. sourcecode:: python
 
     class CastingArray(ARRAY):
         def bind_expression(self, bindvalue):
@@ -940,6 +947,10 @@ from collections import defaultdict
 import datetime as dt
 import re
 
+from . import array as _array
+from . import hstore as _hstore
+from . import json as _json
+from . import ranges as _ranges
 from ... import exc
 from ... import schema
 from ... import sql
@@ -1523,9 +1534,25 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
             self.drop(bind=bind, checkfirst=checkfirst)
 
 
-colspecs = {sqltypes.Interval: INTERVAL, sqltypes.Enum: ENUM}
+colspecs = {
+    sqltypes.ARRAY: _array.ARRAY,
+    sqltypes.Interval: INTERVAL,
+    sqltypes.Enum: ENUM,
+    sqltypes.JSON.JSONPathType: _json.JSONPathType,
+    sqltypes.JSON: _json.JSON,
+}
 
 ischema_names = {
+    "_array": _array.ARRAY,
+    "hstore": _hstore.HSTORE,
+    "json": _json.JSON,
+    "jsonb": _json.JSONB,
+    "int4range": _ranges.INT4RANGE,
+    "int8range": _ranges.INT8RANGE,
+    "numrange": _ranges.NUMRANGE,
+    "daterange": _ranges.DATERANGE,
+    "tsrange": _ranges.TSRANGE,
+    "tstzrange": _ranges.TSTZRANGE,
     "integer": INTEGER,
     "bigint": BIGINT,
     "smallint": SMALLINT,
@@ -1917,6 +1944,22 @@ class PGDDLCompiler(compiler.DDLCompiler):
             colspec += " NOT NULL"
         return colspec
 
+    def visit_check_constraint(self, constraint):
+        if constraint._type_bound:
+            typ = list(constraint.columns)[0].type
+            if (
+                isinstance(typ, sqltypes.ARRAY)
+                and isinstance(typ.item_type, sqltypes.Enum)
+                and not typ.item_type.native_enum
+            ):
+                raise exc.CompileError(
+                    "PostgreSQL dialect cannot produce the CHECK constraint "
+                    "for ARRAY of non-native ENUM; please specify "
+                    "create_constraint=False on this Enum datatype."
+                )
+
+        return super(PGDDLCompiler, self).visit_check_constraint(constraint)
+
     def visit_drop_table_comment(self, drop):
         return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
             drop.element
index 6798051836a615ffb5ca6af8fe42672c1dc50f17..4e048feb05ec9ca88cd014f8dae0981db8aaccda 100644 (file)
@@ -8,7 +8,6 @@
 import re
 
 from .array import ARRAY
-from .base import ischema_names
 from ... import types as sqltypes
 from ... import util
 from ...sql import functions as sqlfunc
@@ -268,9 +267,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
         return process
 
 
-ischema_names["hstore"] = HSTORE
-
-
 class hstore(sqlfunc.GenericFunction):
     """Construct an hstore value within a SQL expression using the
     PostgreSQL ``hstore()`` function.
index 953ad9993a186622fff9c6d887c518501dc4dbee..ea7b04d4f86d1cc964ff33811f25e943c9414600 100644 (file)
@@ -6,8 +6,6 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 from __future__ import absolute_import
 
-from .base import colspecs
-from .base import ischema_names
 from ... import types as sqltypes
 from ... import util
 from ...sql import operators
@@ -96,9 +94,6 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
         return process
 
 
-colspecs[sqltypes.JSON.JSONPathType] = JSONPathType
-
-
 class JSON(sqltypes.JSON):
     """Represent the PostgreSQL JSON type.
 
@@ -236,10 +231,6 @@ class JSON(sqltypes.JSON):
     comparator_factory = Comparator
 
 
-colspecs[sqltypes.JSON] = JSON
-ischema_names["json"] = JSON
-
-
 class JSONB(JSON):
     """Represent the PostgreSQL JSONB type.
 
@@ -324,6 +315,3 @@ class JSONB(JSON):
             )
 
     comparator_factory = Comparator
-
-
-ischema_names["jsonb"] = JSONB
index 76d0aeaebcda490e017667a88c63e694a4920613..d4f75b4948cb70a3b3f7d8e19f28c4cbc8a5d93a 100644 (file)
@@ -4,7 +4,6 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from .base import ischema_names
 from ... import types as sqltypes
 
 
@@ -108,9 +107,6 @@ class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
     __visit_name__ = "INT4RANGE"
 
 
-ischema_names["int4range"] = INT4RANGE
-
-
 class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
     """Represent the PostgreSQL INT8RANGE type.
 
@@ -119,9 +115,6 @@ class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
     __visit_name__ = "INT8RANGE"
 
 
-ischema_names["int8range"] = INT8RANGE
-
-
 class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
     """Represent the PostgreSQL NUMRANGE type.
 
@@ -130,9 +123,6 @@ class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
     __visit_name__ = "NUMRANGE"
 
 
-ischema_names["numrange"] = NUMRANGE
-
-
 class DATERANGE(RangeOperators, sqltypes.TypeEngine):
     """Represent the PostgreSQL DATERANGE type.
 
@@ -141,9 +131,6 @@ class DATERANGE(RangeOperators, sqltypes.TypeEngine):
     __visit_name__ = "DATERANGE"
 
 
-ischema_names["daterange"] = DATERANGE
-
-
 class TSRANGE(RangeOperators, sqltypes.TypeEngine):
     """Represent the PostgreSQL TSRANGE type.
 
@@ -152,15 +139,9 @@ class TSRANGE(RangeOperators, sqltypes.TypeEngine):
     __visit_name__ = "TSRANGE"
 
 
-ischema_names["tsrange"] = TSRANGE
-
-
 class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
     """Represent the PostgreSQL TSTZRANGE type.
 
     """
 
     __visit_name__ = "TSTZRANGE"
-
-
-ischema_names["tstzrange"] = TSTZRANGE
index 34ea4d7ed91cfa5150ab4d5717302b16a3f7b462..2adde8edc35172654a5752c838489f829797d27a 100644 (file)
@@ -1,6 +1,7 @@
 # coding: utf-8
 import datetime
 import decimal
+import re
 import uuid
 
 import sqlalchemy as sa
@@ -44,6 +45,7 @@ from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import NUMRANGE
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql import TSTZRANGE
+from sqlalchemy.exc import CompileError
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import operators
@@ -1720,6 +1722,144 @@ class PGArrayRoundTripTest(
         )
 
 
+class _ArrayOfEnum(TypeDecorator):
+    # previous workaround for array of enum
+    impl = postgresql.ARRAY
+
+    def bind_expression(self, bindvalue):
+        return sa.cast(bindvalue, self)
+
+    def result_processor(self, dialect, coltype):
+        super_rp = super(_ArrayOfEnum, self).result_processor(dialect, coltype)
+
+        def handle_raw_string(value):
+            inner = re.match(r"^{(.*)}$", value).group(1)
+            return inner.split(",") if inner else []
+
+        def process(value):
+            if value is None:
+                return None
+            return super_rp(handle_raw_string(value))
+
+        return process
+
+
+class ArrayEnum(fixtures.TestBase):
+    __backend__ = True
+    __only_on__ = "postgresql"
+    __unsupported_on__ = ("postgresql+pg8000",)
+
+    @testing.combinations(
+        sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls"
+    )
+    @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
+    @testing.provide_metadata
+    def test_raises_non_native_enums(self, array_cls, enum_cls):
+        Table(
+            "my_table",
+            self.metadata,
+            Column(
+                "my_col",
+                array_cls(
+                    enum_cls(
+                        "foo", "bar", "baz", name="my_enum", native_enum=False
+                    )
+                ),
+            ),
+        )
+
+        testing.assert_raises_message(
+            CompileError,
+            "PostgreSQL dialect cannot produce the CHECK constraint "
+            "for ARRAY of non-native ENUM; please specify "
+            "create_constraint=False on this Enum datatype.",
+            self.metadata.create_all,
+            testing.db,
+        )
+
+    @testing.combinations(
+        sqltypes.ARRAY, postgresql.ARRAY, _ArrayOfEnum, argnames="array_cls"
+    )
+    @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
+    @testing.provide_metadata
+    def test_array_of_enums(self, array_cls, enum_cls, connection):
+        tbl = Table(
+            "enum_table",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "enum_col",
+                array_cls(enum_cls("foo", "bar", "baz", name="an_enum")),
+            ),
+        )
+
+        if util.py3k:
+            from enum import Enum
+
+            class MyEnum(Enum):
+                a = "aaa"
+                b = "bbb"
+                c = "ccc"
+
+            tbl.append_column(
+                Column("pyenum_col", array_cls(enum_cls(MyEnum)),),
+            )
+
+        self.metadata.create_all(connection)
+
+        connection.execute(
+            tbl.insert(), [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}]
+        )
+
+        sel = select([tbl.c.enum_col]).order_by(tbl.c.id)
+        eq_(
+            connection.execute(sel).fetchall(), [(["foo"],), (["foo", "bar"],)]
+        )
+
+        if util.py3k:
+            connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]})
+            sel = select([tbl.c.pyenum_col]).order_by(tbl.c.id.desc())
+            eq_(connection.scalar(sel), [MyEnum.a])
+
+
+class ArrayJSON(fixtures.TestBase):
+    __backend__ = True
+    __only_on__ = "postgresql"
+    __unsupported_on__ = ("postgresql+pg8000",)
+
+    @testing.combinations(
+        sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls"
+    )
+    @testing.combinations(
+        sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls"
+    )
+    @testing.provide_metadata
+    def test_array_of_json(self, array_cls, json_cls, connection):
+        tbl = Table(
+            "json_table",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("json_col", array_cls(json_cls),),
+        )
+
+        self.metadata.create_all(connection)
+
+        connection.execute(
+            tbl.insert(),
+            [
+                {"json_col": ["foo"]},
+                {"json_col": [{"foo": "bar"}, [1]]},
+                {"json_col": [None]},
+            ],
+        )
+
+        sel = select([tbl.c.json_col]).order_by(tbl.c.id)
+        eq_(
+            connection.execute(sel).fetchall(),
+            [(["foo"],), ([{"foo": "bar"}, [1]],), ([None],)],
+        )
+
+
 class HashableFlagORMTest(fixtures.TestBase):
     """test the various 'collection' types that they flip the 'hashable' flag
     appropriately.  [ticket:3499]"""