]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
TypeDecorator passes "outer" flag to itself for set_parent accounting
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Mar 2021 21:57:20 +0000 (17:57 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Mar 2021 22:03:55 +0000 (18:03 -0400)
Fixed bug first introduced in as some combination of :ticket:`2892`,
:ticket:`2919` nnd :ticket:`3832` where the attachment events for a
:class:`_types.TypeDecorator` would be doubled up against the "impl" class,
if the "impl" were also a :class:`_types.SchemaType`. The real-world case
is any :class:`_types.TypeDecorator` against :class:`_types.Enum` or
:class:`_types.Boolean` would get a doubled
:class:`_schema.CheckConstraint` when the ``create_constraint=True`` flag
is set.

Fixes: #6152
Change-Id: I3218b7081297270c132421f6765b5c3673d10a5c
(cherry picked from commit 3b18c9db3a81dfeec6de33e7e2ffbd8d265d1d79)

doc/build/changelog/unreleased_13/6152.rst [new file with mode: 0644]
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/sql/test_metadata.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/6152.rst b/doc/build/changelog/unreleased_13/6152.rst
new file mode 100644 (file)
index 0000000..b6dae22
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, schema
+    :tickets: 6152
+
+    Fixed bug first introduced in as some combination of :ticket:`2892`,
+    :ticket:`2919` nnd :ticket:`3832` where the attachment events for a
+    :class:`_types.TypeDecorator` would be doubled up against the "impl" class,
+    if the "impl" were also a :class:`_types.SchemaType`. The real-world case
+    is any :class:`_types.TypeDecorator` against :class:`_types.Enum` or
+    :class:`_types.Boolean` would get a doubled
+    :class:`_schema.CheckConstraint` when the ``create_constraint=True`` flag
+    is set.
+
index 0d3d12e66a6812aeef365b10dcca67f6b9d0751c..ec8cbcf2256c1fd9833f9ae68a40c7431bd3a38d 100644 (file)
@@ -456,9 +456,9 @@ class SchemaEventTarget(object):
     def _set_parent(self, parent):
         """Associate with this SchemaEvent's parent object."""
 
-    def _set_parent_with_dispatch(self, parent):
+    def _set_parent_with_dispatch(self, parent, **kw):
         self.dispatch.before_parent_attach(self, parent)
-        self._set_parent(parent)
+        self._set_parent(parent, **kw)
         self.dispatch.after_parent_attach(self, parent)
 
 
index d47d47cfbbf2fc7e9726c5727f6eb0110a87fbf2..fe2ca9a09cd99c7dcba445fa859362160c645798 100644 (file)
@@ -1050,7 +1050,7 @@ class SchemaType(SchemaEventTarget):
     def _translate_schema(self, effective_schema, map_):
         return map_.get(effective_schema, effective_schema)
 
-    def _set_parent(self, column):
+    def _set_parent(self, column, **kw):
         column._on_table_attach(util.portable_instancemethod(self._set_table))
 
     def _variant_mapping_for_set_table(self, column):
@@ -2745,13 +2745,13 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
     def compare_values(self, x, y):
         return x == y
 
-    def _set_parent(self, column):
+    def _set_parent(self, column, **kw):
         """Support SchemaEventTarget"""
 
         if isinstance(self.item_type, SchemaEventTarget):
-            self.item_type._set_parent(column)
+            self.item_type._set_parent(column, **kw)
 
-    def _set_parent_with_dispatch(self, parent):
+    def _set_parent_with_dispatch(self, parent, **kw):
         """Support SchemaEventTarget"""
 
         super(ARRAY, self)._set_parent_with_dispatch(parent)
index a29e222cbca250b2bebe6cafae44b6be91fa27c5..7cf6971cde85972d8c872f86eaaa27cd85bc0cdf 100644 (file)
@@ -992,18 +992,20 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         """
         return self.impl._type_affinity
 
-    def _set_parent(self, column):
+    def _set_parent(self, column, outer=False, **kw):
         """Support SchemaEventTarget"""
 
         super(TypeDecorator, self)._set_parent(column)
 
-        if isinstance(self.impl, SchemaEventTarget):
-            self.impl._set_parent(column)
+        if not outer and isinstance(self.impl, SchemaEventTarget):
+            self.impl._set_parent(column, outer=False, **kw)
 
     def _set_parent_with_dispatch(self, parent):
         """Support SchemaEventTarget"""
 
-        super(TypeDecorator, self)._set_parent_with_dispatch(parent)
+        super(TypeDecorator, self)._set_parent_with_dispatch(
+            parent, outer=True
+        )
 
         if isinstance(self.impl, SchemaEventTarget):
             self.impl._set_parent_with_dispatch(parent)
@@ -1410,14 +1412,14 @@ class Variant(TypeDecorator):
         else:
             return self.impl
 
-    def _set_parent(self, column):
+    def _set_parent(self, column, **kw):
         """Support SchemaEventTarget"""
 
         if isinstance(self.impl, SchemaEventTarget):
-            self.impl._set_parent(column)
+            self.impl._set_parent(column, **kw)
         for impl in self.mapping.values():
             if isinstance(impl, SchemaEventTarget):
-                impl._set_parent(column)
+                impl._set_parent(column, **kw)
 
     def _set_parent_with_dispatch(self, parent):
         """Support SchemaEventTarget"""
index 43f987e95bee71d872c5ab4d4e5f283359658c6b..65c5def205001b1b11fe1fc132b93168fd1657d1 100644 (file)
@@ -2012,6 +2012,7 @@ class SchemaTypeTest(fixtures.TestBase):
     def test_before_parent_attach_typedec_enclosing_schematype(self):
         # additional test for [ticket:2919] as part of test for
         # [ticket:3832]
+        # this also serves as the test for [ticket:6152]
 
         class MySchemaType(sqltypes.TypeEngine, sqltypes.SchemaType):
             pass
@@ -2022,7 +2023,7 @@ class SchemaTypeTest(fixtures.TestBase):
             impl = target_typ
 
         typ = MyType()
-        self._test_before_parent_attach(typ, target_typ, double=True)
+        self._test_before_parent_attach(typ, target_typ)
 
     def test_before_parent_attach_array_enclosing_schematype(self):
         # test for [ticket:4141] which is the same idea as [ticket:3832]
@@ -2046,7 +2047,7 @@ class SchemaTypeTest(fixtures.TestBase):
         typ = MyType()
         self._test_before_parent_attach(typ)
 
-    def _test_before_parent_attach(self, typ, evt_target=None, double=False):
+    def _test_before_parent_attach(self, typ, evt_target=None):
         canary = mock.Mock()
 
         if evt_target is None:
@@ -2055,8 +2056,8 @@ class SchemaTypeTest(fixtures.TestBase):
         orig_set_parent = evt_target._set_parent
         orig_set_parent_w_dispatch = evt_target._set_parent_with_dispatch
 
-        def _set_parent(parent):
-            orig_set_parent(parent)
+        def _set_parent(parent, **kw):
+            orig_set_parent(parent, **kw)
             canary._set_parent(parent)
 
         def _set_parent_w_dispatch(parent):
@@ -2071,27 +2072,14 @@ class SchemaTypeTest(fixtures.TestBase):
 
                 c = Column("q", typ)
 
-        if double:
-            # no clean way yet to fix this, inner schema type is called
-            # twice, but this is a very unusual use case.
-            eq_(
-                canary.mock_calls,
-                [
-                    mock.call._set_parent(c),
-                    mock.call.go(evt_target, c),
-                    mock.call._set_parent(c),
-                    mock.call._set_parent_with_dispatch(c),
-                ],
-            )
-        else:
-            eq_(
-                canary.mock_calls,
-                [
-                    mock.call.go(evt_target, c),
-                    mock.call._set_parent(c),
-                    mock.call._set_parent_with_dispatch(c),
-                ],
-            )
+        eq_(
+            canary.mock_calls,
+            [
+                mock.call.go(evt_target, c),
+                mock.call._set_parent(c),
+                mock.call._set_parent_with_dispatch(c),
+            ],
+        )
 
     def test_independent_schema(self):
         m = MetaData()
index 6782f262b65cf7c52632ae1325c5dfd2d057a2d4..0d4dae93117270f96f32ba8591b14c7d2d633013 100644 (file)
@@ -737,6 +737,17 @@ class UserDefinedTest(
             Float().dialect_impl(pg).__class__,
         )
 
+    @testing.combinations((Boolean,), (Enum,))
+    def test_typedecorator_schematype_constraint(self, typ):
+        class B(TypeDecorator):
+            impl = typ
+
+        t1 = Table("t1", MetaData(), Column("q", B(create_constraint=True)))
+        eq_(
+            len([c for c in t1.constraints if isinstance(c, CheckConstraint)]),
+            1,
+        )
+
     def test_type_decorator_repr(self):
         class MyType(TypeDecorator):
             impl = VARCHAR