From: Mike Bayer Date: Tue, 7 Mar 2017 17:53:00 +0000 (-0500) Subject: Allow SchemaType and Variant to work together X-Git-Tag: rel_1_1_7~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c04870ba7b8098c7d408ad66f60efe7229496fde;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Allow SchemaType and Variant to work together Added support for the :class:`.Variant` and the :class:`.SchemaType` objects to be compatible with each other. That is, a variant can be created against a type like :class:`.Enum`, and the instructions to create constraints and/or database-specific type objects will propagate correctly as per the variant's dialect mapping. Also added testing for some potential double-event scenarios on TypeDecorator but it seems usually this doesn't occur. Change-Id: I4a7e7c26b4133cd14e870f5bc34a1b2f0f19a14a Fixes: #2892 --- diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index 22f4e2bc4c..9d43d3ae16 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -21,6 +21,16 @@ .. changelog:: :version: 1.1.7 + .. change:: + :tags: bug, sql, postgresql + :tickets: 2892 + + Added support for the :class:`.Variant` and the :class:`.SchemaType` + objects to be compatible with each other. That is, a variant + can be created against a type like :class:`.Enum`, and the instructions + to create constraints and/or database-specific type objects will + propagate correctly as per the variant's dialect mapping. + .. change:: :tags: bug, sql :tickets: 3931 diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index bb39388ab4..8a114ece60 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -15,7 +15,7 @@ import collections import json from . import elements -from .type_api import TypeEngine, TypeDecorator, to_instance +from .type_api import TypeEngine, TypeDecorator, to_instance, Variant from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \ Slice, _literal_as_binds from .. import exc, util, processors @@ -1003,6 +1003,14 @@ class SchemaType(SchemaEventTarget): def _set_parent(self, column): column._on_table_attach(util.portable_instancemethod(self._set_table)) + def _variant_mapping_for_set_table(self, column): + if isinstance(column.type, Variant): + variant_mapping = column.type.mapping.copy() + variant_mapping['_default'] = column.type.impl + else: + variant_mapping = None + return variant_mapping + def _set_table(self, column, table): if self.inherit_schema: self.schema = table.schema @@ -1010,16 +1018,21 @@ class SchemaType(SchemaEventTarget): if not self._create_events: return + variant_mapping = self._variant_mapping_for_set_table(column) + event.listen( table, "before_create", util.portable_instancemethod( - self._on_table_create) + self._on_table_create, + {"variant_mapping": variant_mapping}) ) event.listen( table, "after_drop", - util.portable_instancemethod(self._on_table_drop) + util.portable_instancemethod( + self._on_table_drop, + {"variant_mapping": variant_mapping}) ) if self.metadata is None: # TODO: what's the difference between self.metadata @@ -1027,12 +1040,16 @@ class SchemaType(SchemaEventTarget): event.listen( table.metadata, "before_create", - util.portable_instancemethod(self._on_metadata_create) + util.portable_instancemethod( + self._on_metadata_create, + {"variant_mapping": variant_mapping}) ) event.listen( table.metadata, "after_drop", - util.portable_instancemethod(self._on_metadata_drop) + util.portable_instancemethod( + self._on_metadata_drop, + {"variant_mapping": variant_mapping}) ) def copy(self, **kw): @@ -1073,25 +1090,48 @@ class SchemaType(SchemaEventTarget): t.drop(bind=bind, checkfirst=checkfirst) def _on_table_create(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_table_create(target, bind, **kw) def _on_table_drop(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_table_drop(target, bind, **kw) def _on_metadata_create(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_metadata_create(target, bind, **kw) def _on_metadata_drop(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_metadata_drop(target, bind, **kw) + def _is_impl_for_variant(self, dialect, kw): + variant_mapping = kw.pop('variant_mapping', None) + if variant_mapping is None: + return True + + if dialect.name in variant_mapping and \ + variant_mapping[dialect.name] is self: + return True + elif dialect.name not in variant_mapping: + return variant_mapping['_default'] is self + class Enum(String, SchemaType): @@ -1339,7 +1379,9 @@ class Enum(String, SchemaType): to_inspect=[Enum, SchemaType], ) - def _should_create_constraint(self, compiler): + def _should_create_constraint(self, compiler, **kw): + if not self._is_impl_for_variant(compiler.dialect, kw): + return False return not self.native_enum or \ not compiler.dialect.supports_native_enum @@ -1351,11 +1393,14 @@ class Enum(String, SchemaType): if not self.create_constraint: return + variant_mapping = self._variant_mapping_for_set_table(column) + e = schema.CheckConstraint( type_coerce(column, self).in_(self.enums), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint), + self._should_create_constraint, + {"variant_mapping": variant_mapping}), _type_bound=True ) assert e.table is table @@ -1534,7 +1579,9 @@ class Boolean(TypeEngine, SchemaType): self.name = name self._create_events = _create_events - def _should_create_constraint(self, compiler): + def _should_create_constraint(self, compiler, **kw): + if not self._is_impl_for_variant(compiler.dialect, kw): + return False return not compiler.dialect.supports_native_boolean @util.dependencies("sqlalchemy.sql.schema") @@ -1542,11 +1589,14 @@ class Boolean(TypeEngine, SchemaType): if not self.create_constraint: return + variant_mapping = self._variant_mapping_for_set_table(column) + e = schema.CheckConstraint( type_coerce(column, self).in_([0, 1]), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint), + self._should_create_constraint, + {"variant_mapping": variant_mapping}), _type_bound=True ) assert e.table is table diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 2b697480da..d537e49f02 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -858,7 +858,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): return self.impl._type_affinity def _set_parent(self, column): - """Support SchemaEentTarget""" + """Support SchemaEventTarget""" super(TypeDecorator, self)._set_parent(column) @@ -866,7 +866,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): self.impl._set_parent(column) def _set_parent_with_dispatch(self, parent): - """Support SchemaEentTarget""" + """Support SchemaEventTarget""" super(TypeDecorator, self)._set_parent_with_dispatch(parent) @@ -1222,6 +1222,24 @@ class Variant(TypeDecorator): else: return self.impl + def _set_parent(self, column): + """Support SchemaEventTarget""" + + if isinstance(self.impl, SchemaEventTarget): + self.impl._set_parent(column) + for impl in self.mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent(column) + + def _set_parent_with_dispatch(self, parent): + """Support SchemaEventTarget""" + + if isinstance(self.impl, SchemaEventTarget): + self.impl._set_parent_with_dispatch(parent) + for impl in self.mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent_with_dispatch(parent) + def with_variant(self, type_, dialect_name): """Return a new :class:`.Variant` which adds the given type + dialect name to the mapping, in addition to the diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 973de426c3..68c0f885b5 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -529,20 +529,24 @@ class portable_instancemethod(object): """ - __slots__ = 'target', 'name', '__weakref__' + __slots__ = 'target', 'name', 'kwargs', '__weakref__' def __getstate__(self): - return {'target': self.target, 'name': self.name} + return {'target': self.target, 'name': self.name, + 'kwargs': self.kwargs} def __setstate__(self, state): self.target = state['target'] self.name = state['name'] + self.kwargs = state.get('kwargs', ()) - def __init__(self, meth): + def __init__(self, meth, kwargs=()): self.target = meth.__self__ self.name = meth.__name__ + self.kwargs = kwargs def __call__(self, *arg, **kw): + kw.update(self.kwargs) return getattr(self.target, self.name)(*arg, **kw) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 3f2f6db3f7..807eeb60c4 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -523,6 +523,77 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): "twoHITHERE" ) + @testing.provide_metadata + def test_generic_w_pg_variant(self): + some_table = Table( + 'some_table', self.metadata, + Column( + 'data', + Enum( + "one", "two", "three", + native_enum=True # make sure this is True because + # it should *not* take effect due to + # the variant + ).with_variant( + postgresql.ENUM("four", "five", "six", name="my_enum"), + "postgresql" + ) + ) + ) + + with testing.db.begin() as conn: + assert 'my_enum' not in [ + e['name'] for e in inspect(conn).get_enums()] + + self.metadata.create_all(conn) + + assert 'my_enum' in [ + e['name'] for e in inspect(conn).get_enums()] + + conn.execute( + some_table.insert(), {"data": "five"} + ) + + self.metadata.drop_all(conn) + + assert 'my_enum' not in [ + e['name'] for e in inspect(conn).get_enums()] + + @testing.provide_metadata + def test_generic_w_some_other_variant(self): + some_table = Table( + 'some_table', self.metadata, + Column( + 'data', + Enum( + "one", "two", "three", + name="my_enum", + native_enum=True + ).with_variant( + Enum("four", "five", "six"), + "mysql" + ) + ) + ) + + with testing.db.begin() as conn: + assert 'my_enum' not in [ + e['name'] for e in inspect(conn).get_enums()] + + self.metadata.create_all(conn) + + assert 'my_enum' in [ + e['name'] for e in inspect(conn).get_enums()] + + conn.execute( + some_table.insert(), {"data": "two"} + ) + + self.metadata.drop_all(conn) + + assert 'my_enum' not in [ + e['name'] for e in inspect(conn).get_enums()] + class OIDTest(fixtures.TestBase): __only_on__ = 'postgresql' diff --git a/test/requirements.py b/test/requirements.py index f91a426e77..3b9059bdaf 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -47,6 +47,14 @@ class DefaultRequirements(SuiteRequirements): return exclusions.open() + @property + def enforces_check_constraints(self): + """Target database must also enforce check constraints.""" + + return self.check_constraints + fails_on( + ['mysql'], "check constraints don't enforce" + ) + @property def named_constraints(self): """target database must support names for constraints.""" diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index bd67b6f69b..6d0df3b5f0 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -1607,7 +1607,7 @@ class SchemaTypeTest(fixtures.TestBase): impl = target_typ typ = MyType() - self._test_before_parent_attach(typ, target_typ) + self._test_before_parent_attach(typ, target_typ, double=True) def test_before_parent_attach_typedec_of_schematype(self): class MyType(TypeDecorator, sqltypes.SchemaType): @@ -1623,17 +1623,52 @@ class SchemaTypeTest(fixtures.TestBase): typ = MyType() self._test_before_parent_attach(typ) - def _test_before_parent_attach(self, typ, evt_target=None): + def _test_before_parent_attach(self, typ, evt_target=None, double=False): canary = mock.Mock() if evt_target is None: evt_target = typ - event.listen(evt_target, "before_parent_attach", canary.go) + orig_set_parent = evt_target._set_parent + orig_set_parent_w_dispatch = evt_target._set_parent_with_dispatch - c = Column('q', typ) + def _set_parent(parent): + orig_set_parent(parent) + canary._set_parent(parent) - eq_(canary.mock_calls, [mock.call.go(evt_target, c)]) + def _set_parent_w_dispatch(parent): + orig_set_parent_w_dispatch(parent) + canary._set_parent_with_dispatch(parent) + + with mock.patch.object(evt_target, '_set_parent', _set_parent): + with mock.patch.object( + evt_target, '_set_parent_with_dispatch', + _set_parent_w_dispatch): + event.listen(evt_target, "before_parent_attach", canary.go) + + c = Column('q', typ) + + if double: + # no clean way yet to fix this, inner schema type is called + # twice, but this is a very unusual use case. + eq_( + canary.mock_calls, + [ + mock.call._set_parent(c), + mock.call.go(evt_target, c), + mock.call._set_parent(c), + mock.call._set_parent_with_dispatch(c) + ] + ) + else: + eq_( + canary.mock_calls, + [ + mock.call.go(evt_target, c), + mock.call._set_parent(c), + mock.call._set_parent_with_dispatch(c) + ] + ) def test_independent_schema(self): m = MetaData() diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 8fbc65ef27..b417e69640 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1315,11 +1315,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): non_native_enum_table.insert(), {"id": 1, "someenum": None}) eq_(conn.scalar(select([non_native_enum_table.c.someenum])), None) - @testing.fails_on( - 'mysql', - "The CHECK clause is parsed but ignored by all storage engines.") - @testing.fails_on( - 'mssql', "FIXME: MS-SQL 2005 doesn't honor CHECK ?!?") + @testing.requires.enforces_check_constraints def test_check_constraint(self): assert_raises( (exc.IntegrityError, exc.ProgrammingError), @@ -1327,6 +1323,62 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): "insert into non_native_enum_table " "(id, someenum) values(1, 'four')") + @testing.requires.enforces_check_constraints + @testing.provide_metadata + def test_variant_we_are_default(self): + # test that the "variant" does not create a constraint + t = Table( + 'my_table', self.metadata, + Column( + 'data', Enum("one", "two", "three", name="e1").with_variant( + Enum("four", "five", "six", name="e2"), "some_other_db" + ) + ) + ) + + eq_( + len([c for c in t.constraints if isinstance(c, CheckConstraint)]), + 2 + ) + + with testing.db.connect() as conn: + self.metadata.create_all(conn) + assert_raises( + (exc.IntegrityError, exc.ProgrammingError, exc.DataError), + conn.execute, + "insert into my_table " + "(data) values('four')") + conn.execute("insert into my_table (data) values ('two')") + + @testing.requires.enforces_check_constraints + @testing.provide_metadata + def test_variant_we_are_not_default(self): + # test that the "variant" does not create a constraint + t = Table( + 'my_table', self.metadata, + Column( + 'data', Enum("one", "two", "three", name="e1").with_variant( + Enum("four", "five", "six", name="e2"), + testing.db.dialect.name + ) + ) + ) + + # ensure Variant isn't exploding the constraints + eq_( + len([c for c in t.constraints if isinstance(c, CheckConstraint)]), + 2 + ) + + with testing.db.connect() as conn: + self.metadata.create_all(conn) + assert_raises( + (exc.IntegrityError, exc.ProgrammingError, exc.DataError), + conn.execute, + "insert into my_table " + "(data) values('two')") + conn.execute("insert into my_table (data) values ('four')") + def test_skip_check_constraint(self): with testing.db.connect() as conn: conn.execute(