Allow SchemaType and Variant to work together
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 7 Mar 2017 17:53:00 +0000 (12:53 -0500)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 7 Mar 2017 21:24:18 +0000 (16:24 -0500)
Added support for the :class:`.Variant` and the :class:`.SchemaType`
objects to be compatible with each other.  That is, a variant
can be created against a type like :class:`.Enum`, and the instructions
to create constraints and/or database-specific type objects will
propagate correctly as per the variant's dialect mapping.

Also added testing for some potential double-event scenarios
on TypeDecorator, although in practice these don't usually occur.

Change-Id: I4a7e7c26b4133cd14e870f5bc34a1b2f0f19a14a
Fixes: #2892
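
For context, a minimal usage sketch of what this change enables (illustrative
only, not part of the commit; the table, column, and enum names below are made
up):

    from sqlalchemy import Column, Enum, Integer, MetaData, Table
    from sqlalchemy.dialects import postgresql

    metadata = MetaData()

    # generic Enum as the default, plus a PostgreSQL-specific variant that
    # maps to a native CREATE TYPE
    status = Enum(
        "pending", "complete", name="status_generic", native_enum=False
    ).with_variant(
        postgresql.ENUM("pending", "complete", name="status_pg"),
        "postgresql",
    )

    jobs = Table(
        "jobs", metadata,
        Column("id", Integer, primary_key=True),
        Column("status", status),
    )

    # with this change, metadata.create_all() consults the variant's dialect
    # mapping: on PostgreSQL the variant's CREATE TYPE / DROP TYPE events
    # fire, while on other backends only the generic Enum's CHECK constraint
    # is rendered.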
doc/build/changelog/changelog_11.rst
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/langhelpers.py
test/dialect/postgresql/test_types.py
test/requirements.py
test/sql/test_metadata.py
test/sql/test_types.py

diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst
index 22f4e2bc4cfc17cbdb54267bbcf08de65b8884d4..9d43d3ae163436a698601b1b9785808fe6dc810e 100644
 .. changelog::
     :version: 1.1.7
 
+    .. change::
+        :tags: bug, sql, postgresql
+        :tickets: 2892
+
+        Added support for the :class:`.Variant` and the :class:`.SchemaType`
+        objects to be compatible with each other.  That is, a variant
+        can be created against a type like :class:`.Enum`, and the instructions
+        to create constraints and/or database-specific type objects will
+        propagate correctly as per the variant's dialect mapping.
+
     .. change::
         :tags: bug, sql
         :tickets: 3931
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index bb39388ab4d49f3f1af531d1150408a99e9edfce..8a114ece60ec7a6fc8bb2dc662caba8441644640 100644
@@ -15,7 +15,7 @@ import collections
 import json
 
 from . import elements
-from .type_api import TypeEngine, TypeDecorator, to_instance
+from .type_api import TypeEngine, TypeDecorator, to_instance, Variant
 from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \
     Slice, _literal_as_binds
 from .. import exc, util, processors
@@ -1003,6 +1003,14 @@ class SchemaType(SchemaEventTarget):
     def _set_parent(self, column):
         column._on_table_attach(util.portable_instancemethod(self._set_table))
 
+    def _variant_mapping_for_set_table(self, column):
+        if isinstance(column.type, Variant):
+            variant_mapping = column.type.mapping.copy()
+            variant_mapping['_default'] = column.type.impl
+        else:
+            variant_mapping = None
+        return variant_mapping
+
     def _set_table(self, column, table):
         if self.inherit_schema:
             self.schema = table.schema
@@ -1010,16 +1018,21 @@ class SchemaType(SchemaEventTarget):
         if not self._create_events:
             return
 
+        variant_mapping = self._variant_mapping_for_set_table(column)
+
         event.listen(
             table,
             "before_create",
             util.portable_instancemethod(
-                self._on_table_create)
+                self._on_table_create,
+                {"variant_mapping": variant_mapping})
         )
         event.listen(
             table,
             "after_drop",
-            util.portable_instancemethod(self._on_table_drop)
+            util.portable_instancemethod(
+                self._on_table_drop,
+                {"variant_mapping": variant_mapping})
         )
         if self.metadata is None:
             # TODO: what's the difference between self.metadata
@@ -1027,12 +1040,16 @@ class SchemaType(SchemaEventTarget):
             event.listen(
                 table.metadata,
                 "before_create",
-                util.portable_instancemethod(self._on_metadata_create)
+                util.portable_instancemethod(
+                    self._on_metadata_create,
+                    {"variant_mapping": variant_mapping})
             )
             event.listen(
                 table.metadata,
                 "after_drop",
-                util.portable_instancemethod(self._on_metadata_drop)
+                util.portable_instancemethod(
+                    self._on_metadata_drop,
+                    {"variant_mapping": variant_mapping})
             )
 
     def copy(self, **kw):
@@ -1073,25 +1090,48 @@ class SchemaType(SchemaEventTarget):
             t.drop(bind=bind, checkfirst=checkfirst)
 
     def _on_table_create(self, target, bind, **kw):
+        if not self._is_impl_for_variant(bind.dialect, kw):
+            return
+
         t = self.dialect_impl(bind.dialect)
         if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_table_create(target, bind, **kw)
 
     def _on_table_drop(self, target, bind, **kw):
+        if not self._is_impl_for_variant(bind.dialect, kw):
+            return
+
         t = self.dialect_impl(bind.dialect)
         if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_table_drop(target, bind, **kw)
 
     def _on_metadata_create(self, target, bind, **kw):
+        if not self._is_impl_for_variant(bind.dialect, kw):
+            return
+
         t = self.dialect_impl(bind.dialect)
         if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_metadata_create(target, bind, **kw)
 
     def _on_metadata_drop(self, target, bind, **kw):
+        if not self._is_impl_for_variant(bind.dialect, kw):
+            return
+
         t = self.dialect_impl(bind.dialect)
         if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_metadata_drop(target, bind, **kw)
 
+    def _is_impl_for_variant(self, dialect, kw):
+        variant_mapping = kw.pop('variant_mapping', None)
+        if variant_mapping is None:
+            return True
+
+        if dialect.name in variant_mapping and \
+                variant_mapping[dialect.name] is self:
+            return True
+        elif dialect.name not in variant_mapping:
+            return variant_mapping['_default'] is self
+
 
 class Enum(String, SchemaType):
 
@@ -1339,7 +1379,9 @@ class Enum(String, SchemaType):
                                  to_inspect=[Enum, SchemaType],
                                  )
 
-    def _should_create_constraint(self, compiler):
+    def _should_create_constraint(self, compiler, **kw):
+        if not self._is_impl_for_variant(compiler.dialect, kw):
+            return False
         return not self.native_enum or \
             not compiler.dialect.supports_native_enum
 
@@ -1351,11 +1393,14 @@ class Enum(String, SchemaType):
         if not self.create_constraint:
             return
 
+        variant_mapping = self._variant_mapping_for_set_table(column)
+
         e = schema.CheckConstraint(
             type_coerce(column, self).in_(self.enums),
             name=_defer_name(self.name),
             _create_rule=util.portable_instancemethod(
-                self._should_create_constraint),
+                self._should_create_constraint,
+                {"variant_mapping": variant_mapping}),
             _type_bound=True
         )
         assert e.table is table
@@ -1534,7 +1579,9 @@ class Boolean(TypeEngine, SchemaType):
         self.name = name
         self._create_events = _create_events
 
-    def _should_create_constraint(self, compiler):
+    def _should_create_constraint(self, compiler, **kw):
+        if not self._is_impl_for_variant(compiler.dialect, kw):
+            return False
         return not compiler.dialect.supports_native_boolean
 
     @util.dependencies("sqlalchemy.sql.schema")
@@ -1542,11 +1589,14 @@ class Boolean(TypeEngine, SchemaType):
         if not self.create_constraint:
             return
 
+        variant_mapping = self._variant_mapping_for_set_table(column)
+
         e = schema.CheckConstraint(
             type_coerce(column, self).in_([0, 1]),
             name=_defer_name(self.name),
             _create_rule=util.portable_instancemethod(
-                self._should_create_constraint),
+                self._should_create_constraint,
+                {"variant_mapping": variant_mapping}),
             _type_bound=True
         )
         assert e.table is table
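
In plain terms, the new _is_impl_for_variant hook above decides which member
of a variant mapping is allowed to emit DDL or constraints for the active
dialect. A simplified restatement of that logic (a sketch, not the exact
code from the diff):

    def is_impl_for_variant(schema_type, dialect, variant_mapping):
        # no variant involved: always act
        if variant_mapping is None:
            return True
        # the active dialect is named explicitly: act only if it names us
        if dialect.name in variant_mapping:
            return variant_mapping[dialect.name] is schema_type
        # otherwise only the "_default" (the Variant's base impl) acts
        return variant_mapping["_default"] is schema_type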
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 2b697480dac353cb95948664cf74b2b169ac2d2a..d537e49f023d31405c43d70cc12544e5453a491e 100644
@@ -858,7 +858,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         return self.impl._type_affinity
 
     def _set_parent(self, column):
-        """Support SchemaEentTarget"""
+        """Support SchemaEventTarget"""
 
         super(TypeDecorator, self)._set_parent(column)
 
@@ -866,7 +866,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
             self.impl._set_parent(column)
 
     def _set_parent_with_dispatch(self, parent):
-        """Support SchemaEentTarget"""
+        """Support SchemaEventTarget"""
 
         super(TypeDecorator, self)._set_parent_with_dispatch(parent)
 
@@ -1222,6 +1222,24 @@ class Variant(TypeDecorator):
         else:
             return self.impl
 
+    def _set_parent(self, column):
+        """Support SchemaEventTarget"""
+
+        if isinstance(self.impl, SchemaEventTarget):
+            self.impl._set_parent(column)
+        for impl in self.mapping.values():
+            if isinstance(impl, SchemaEventTarget):
+                impl._set_parent(column)
+
+    def _set_parent_with_dispatch(self, parent):
+        """Support SchemaEventTarget"""
+
+        if isinstance(self.impl, SchemaEventTarget):
+            self.impl._set_parent_with_dispatch(parent)
+        for impl in self.mapping.values():
+            if isinstance(impl, SchemaEventTarget):
+                impl._set_parent_with_dispatch(parent)
+
     def with_variant(self, type_, dialect_name):
         """Return a new :class:`.Variant` which adds the given
         type + dialect name to the mapping, in addition to the
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 973de426c34ad2d0703f63421fea204f9c740a9d..68c0f885b58c4c1659d90eb2f29687b19d9488b3 100644
@@ -529,20 +529,24 @@ class portable_instancemethod(object):
 
     """
 
-    __slots__ = 'target', 'name', '__weakref__'
+    __slots__ = 'target', 'name', 'kwargs', '__weakref__'
 
     def __getstate__(self):
-        return {'target': self.target, 'name': self.name}
+        return {'target': self.target, 'name': self.name,
+                'kwargs': self.kwargs}
 
     def __setstate__(self, state):
         self.target = state['target']
         self.name = state['name']
+        self.kwargs = state.get('kwargs', ())
 
-    def __init__(self, meth):
+    def __init__(self, meth, kwargs=()):
         self.target = meth.__self__
         self.name = meth.__name__
+        self.kwargs = kwargs
 
     def __call__(self, *arg, **kw):
+        kw.update(self.kwargs)
         return getattr(self.target, self.name)(*arg, **kw)
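
The langhelpers change above lets a portable_instancemethod (a private
SQLAlchemy utility) carry bound keyword arguments that are merged into every
call; this is how variant_mapping reaches the _on_table_create /
_on_table_drop handlers. A rough sketch of the behavior, using a throwaway
class purely for illustration:

    from sqlalchemy import util

    class Target(object):
        def handler(self, table, **kw):
            return table, kw

    fn = util.portable_instancemethod(
        Target().handler, {"variant_mapping": {"postgresql": "pg impl"}})

    # the stored kwargs are applied on top of whatever the caller passes,
    # and they survive pickling since __getstate__ now includes them
    table, kw = fn("some_table")
    assert kw == {"variant_mapping": {"postgresql": "pg impl"}}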
 
 
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 3f2f6db3f7035efa168a02189a0540b8827398d3..807eeb60c42ac8eff4b6074ea42866ce83fb0614 100644
@@ -523,6 +523,77 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
                 "twoHITHERE"
             )
 
+    @testing.provide_metadata
+    def test_generic_w_pg_variant(self):
+        some_table = Table(
+            'some_table', self.metadata,
+            Column(
+                'data',
+                Enum(
+                    "one", "two", "three",
+                    native_enum=True   # make sure this is True because
+                                       # it should *not* take effect due to
+                                       # the variant
+                ).with_variant(
+                    postgresql.ENUM("four", "five", "six", name="my_enum"),
+                    "postgresql"
+                )
+            )
+        )
+
+        with testing.db.begin() as conn:
+            assert 'my_enum' not in [
+                e['name'] for e in inspect(conn).get_enums()]
+
+            self.metadata.create_all(conn)
+
+            assert 'my_enum' in [
+                e['name'] for e in inspect(conn).get_enums()]
+
+            conn.execute(
+                some_table.insert(), {"data": "five"}
+            )
+
+            self.metadata.drop_all(conn)
+
+            assert 'my_enum' not in [
+                e['name'] for e in inspect(conn).get_enums()]
+
+    @testing.provide_metadata
+    def test_generic_w_some_other_variant(self):
+        some_table = Table(
+            'some_table', self.metadata,
+            Column(
+                'data',
+                Enum(
+                    "one", "two", "three",
+                    name="my_enum",
+                    native_enum=True
+                ).with_variant(
+                    Enum("four", "five", "six"),
+                    "mysql"
+                )
+            )
+        )
+
+        with testing.db.begin() as conn:
+            assert 'my_enum' not in [
+                e['name'] for e in inspect(conn).get_enums()]
+
+            self.metadata.create_all(conn)
+
+            assert 'my_enum' in [
+                e['name'] for e in inspect(conn).get_enums()]
+
+            conn.execute(
+                some_table.insert(), {"data": "two"}
+            )
+
+            self.metadata.drop_all(conn)
+
+            assert 'my_enum' not in [
+                e['name'] for e in inspect(conn).get_enums()]
+
 
 class OIDTest(fixtures.TestBase):
     __only_on__ = 'postgresql'
diff --git a/test/requirements.py b/test/requirements.py
index f91a426e77398ca5255332a0513b324e8531a00a..3b9059bdaf83314955848640c8d93c515a3e584d 100644
@@ -47,6 +47,14 @@ class DefaultRequirements(SuiteRequirements):
 
         return exclusions.open()
 
+    @property
+    def enforces_check_constraints(self):
+        """Target database must also enforce check constraints."""
+
+        return self.check_constraints + fails_on(
+            ['mysql'], "check constraints don't enforce"
+        )
+
     @property
     def named_constraints(self):
         """target database must support names for constraints."""
diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py
index bd67b6f69b25720fc20678ccb6919a4957fbbc1c..6d0df3b5f096f4dd0d5a60ab4bb56fc593d84bdc 100644
@@ -1607,7 +1607,7 @@ class SchemaTypeTest(fixtures.TestBase):
             impl = target_typ
 
         typ = MyType()
-        self._test_before_parent_attach(typ, target_typ)
+        self._test_before_parent_attach(typ, target_typ, double=True)
 
     def test_before_parent_attach_typedec_of_schematype(self):
         class MyType(TypeDecorator, sqltypes.SchemaType):
@@ -1623,17 +1623,52 @@ class SchemaTypeTest(fixtures.TestBase):
         typ = MyType()
         self._test_before_parent_attach(typ)
 
-    def _test_before_parent_attach(self, typ, evt_target=None):
+    def _test_before_parent_attach(self, typ, evt_target=None, double=False):
         canary = mock.Mock()
 
         if evt_target is None:
             evt_target = typ
 
-        event.listen(evt_target, "before_parent_attach", canary.go)
+        orig_set_parent = evt_target._set_parent
+        orig_set_parent_w_dispatch = evt_target._set_parent_with_dispatch
 
-        c = Column('q', typ)
+        def _set_parent(parent):
+            orig_set_parent(parent)
+            canary._set_parent(parent)
 
-        eq_(canary.mock_calls, [mock.call.go(evt_target, c)])
+        def _set_parent_w_dispatch(parent):
+            orig_set_parent_w_dispatch(parent)
+            canary._set_parent_with_dispatch(parent)
+
+        with mock.patch.object(evt_target, '_set_parent', _set_parent):
+            with mock.patch.object(
+                    evt_target, '_set_parent_with_dispatch',
+                    _set_parent_w_dispatch):
+                event.listen(evt_target, "before_parent_attach", canary.go)
+
+                c = Column('q', typ)
+
+        if double:
+            # no clean way yet to fix this, inner schema type is called
+            # twice, but this is a very unusual use case.
+            eq_(
+                canary.mock_calls,
+                [
+                    mock.call._set_parent(c),
+                    mock.call.go(evt_target, c),
+                    mock.call._set_parent(c),
+                    mock.call._set_parent_with_dispatch(c)
+                ]
+            )
+        else:
+            eq_(
+                canary.mock_calls,
+                [
+                    mock.call.go(evt_target, c),
+                    mock.call._set_parent(c),
+                    mock.call._set_parent_with_dispatch(c)
+                ]
+            )
 
     def test_independent_schema(self):
         m = MetaData()
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 8fbc65ef27a697e4ad030b60ae80032b1197aa15..b417e696409d22c473e25e37ae1db6e08b13729b 100644
@@ -1315,11 +1315,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
                 non_native_enum_table.insert(), {"id": 1, "someenum": None})
             eq_(conn.scalar(select([non_native_enum_table.c.someenum])), None)
 
-    @testing.fails_on(
-        'mysql',
-        "The CHECK clause is parsed but ignored by all storage engines.")
-    @testing.fails_on(
-        'mssql', "FIXME: MS-SQL 2005 doesn't honor CHECK ?!?")
+    @testing.requires.enforces_check_constraints
     def test_check_constraint(self):
         assert_raises(
             (exc.IntegrityError, exc.ProgrammingError),
@@ -1327,6 +1323,62 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             "insert into non_native_enum_table "
             "(id, someenum) values(1, 'four')")
 
+    @testing.requires.enforces_check_constraints
+    @testing.provide_metadata
+    def test_variant_we_are_default(self):
+        # test that the "variant" does not create a constraint
+        t = Table(
+            'my_table', self.metadata,
+            Column(
+                'data', Enum("one", "two", "three", name="e1").with_variant(
+                    Enum("four", "five", "six", name="e2"), "some_other_db"
+                )
+            )
+        )
+
+        eq_(
+            len([c for c in t.constraints if isinstance(c, CheckConstraint)]),
+            2
+        )
+
+        with testing.db.connect() as conn:
+            self.metadata.create_all(conn)
+            assert_raises(
+                (exc.IntegrityError, exc.ProgrammingError, exc.DataError),
+                conn.execute,
+                "insert into my_table "
+                "(data) values('four')")
+            conn.execute("insert into my_table (data) values ('two')")
+
+    @testing.requires.enforces_check_constraints
+    @testing.provide_metadata
+    def test_variant_we_are_not_default(self):
+        # test that the "variant" does not create a constraint
+        t = Table(
+            'my_table', self.metadata,
+            Column(
+                'data', Enum("one", "two", "three", name="e1").with_variant(
+                    Enum("four", "five", "six", name="e2"),
+                    testing.db.dialect.name
+                )
+            )
+        )
+
+        # ensure Variant isn't exploding the constraints
+        eq_(
+            len([c for c in t.constraints if isinstance(c, CheckConstraint)]),
+            2
+        )
+
+        with testing.db.connect() as conn:
+            self.metadata.create_all(conn)
+            assert_raises(
+                (exc.IntegrityError, exc.ProgrammingError, exc.DataError),
+                conn.execute,
+                "insert into my_table "
+                "(data) values('two')")
+            conn.execute("insert into my_table (data) values ('four')")
+
     def test_skip_check_constraint(self):
         with testing.db.connect() as conn:
             conn.execute(