From: Mike Bayer Date: Sat, 23 Oct 2021 15:26:45 +0000 (-0400) Subject: replace Variant with direct feature inside of TypeEngine X-Git-Tag: rel_2_0_0b1~571 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9e3c8d0d71ae0aabe9f5abfae2db838cb80fe320;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git replace Variant with direct feature inside of TypeEngine The :meth:`_sqltypes.TypeEngine.with_variant` method now returns a copy of the original :class:`_sqltypes.TypeEngine` object, rather than wrapping it inside the ``Variant`` class, which is effectively removed (the import symbol remains for backwards compatibility with code that may be testing for this symbol). While the previous approach maintained in-Python behaviors, maintaining the original type allows for clearer type checking and debugging. Fixes: #6980 Change-Id: I158c7e56306b886b5b82b040205c428a5c4a242c --- diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index 72530142e8..4c2ef22bdd 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -159,6 +159,42 @@ of the Cython requirement. .. _Cython: https://cython.org/ +.. _change_6980: + +"with_variant()" clones the original TypeEngine rather than changing the type +----------------------------------------------------------------------------- + +The :meth:`_sqltypes.TypeEngine.with_variant` method, which is used to apply +alternate per-database behaviors to a particular type, now returns a copy of +the original :class:`_sqltypes.TypeEngine` object with the variant information +stored internally, rather than wrapping it inside the ``Variant`` class. + +While the previous ``Variant`` approach was able to maintain all the in-Python +behaviors of the original type using dynamic attribute getters, the improvement +here is that when calling upon a variant, the returned type remains an instance +of the original type, which works more smoothly with type checkers such as mypy +and pylance. Given a program as below:: + + import typing + + + from sqlalchemy import String + from sqlalchemy.dialects.mysql import VARCHAR + + + type_ = String(255).with_variant(VARCHAR(255, charset='utf8mb4'), "mysql") + + if typing.TYPE_CHECKING: + reveal_type(type_) + +A type checker like pyright will now report the type as:: + + info: Type of "type_" is "String" + + +:ticket:`6980` + + .. _change_4926: Python division operator performs true division for all backends; added floor division diff --git a/doc/build/changelog/unreleased_20/6980.rst b/doc/build/changelog/unreleased_20/6980.rst new file mode 100644 index 0000000000..d83599c48c --- /dev/null +++ b/doc/build/changelog/unreleased_20/6980.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: improvement, typing + :tickets: 6980 + + The :meth:`_sqltypes.TypeEngine.with_variant` method now returns a copy of + the original :class:`_sqltypes.TypeEngine` object, rather than wrapping it + inside the ``Variant`` class, which is effectively removed (the import + symbol remains for backwards compatibility with code that may be testing + for this symbol). While the previous approach maintained in-Python + behaviors, maintaining the original type allows for clearer type checking + and debugging. + + .. seealso:: + + :ref:`change_6980` + + + diff --git a/doc/build/conf.py b/doc/build/conf.py index 9e2d8e9de7..cba0fee6ab 100644 --- a/doc/build/conf.py +++ b/doc/build/conf.py @@ -63,6 +63,7 @@ changelog_sections = [ "sql", "schema", "extensions", + "typing", "mypy", "asyncio", "postgresql", @@ -76,6 +77,7 @@ changelog_sections = [ # tags to sort on inside of sections changelog_inner_tag_sort = [ "feature", + "improvement", "usecase", "change", "changed", diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a9dd6a23ae..90f32a4f7c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -512,6 +512,11 @@ class TypeCompiler(metaclass=util.EnsureKWArgType): self.dialect = dialect def process(self, type_, **kw): + if ( + type_._variant_mapping + and self.dialect.name in type_._variant_mapping + ): + type_ = type_._variant_mapping[self.dialect.name] return type_._compiler_dispatch(self, **kw) def visit_unsupported_compilation(self, element, err, **kw): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 9b5005b5d8..dd8238450c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1586,9 +1586,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): # check if this Column is proxying another column if "_proxies" in kwargs: self._proxies = kwargs.pop("_proxies") - # otherwise, add DDL-related events - elif isinstance(self.type, SchemaEventTarget): - self.type._set_parent_with_dispatch(self) + else: + # otherwise, add DDL-related events + if isinstance(self.type, SchemaEventTarget): + self.type._set_parent_with_dispatch(self) + for impl in self.type._variant_mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent_with_dispatch(self) if self.default is not None: if isinstance(self.default, (ColumnDefault, Sequence)): diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 574692fed7..cda7b35cda 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -33,7 +33,7 @@ from .type_api import NativeForEmulated # noqa from .type_api import to_instance from .type_api import TypeDecorator from .type_api import TypeEngine -from .type_api import Variant +from .type_api import Variant # noqa from .. import event from .. import exc from .. import inspection @@ -844,12 +844,25 @@ class SchemaType(SchemaEventTarget): ) def _set_parent(self, column, **kw): + # set parent hook is when this type is associated with a column. + # Column calls it for all SchemaEventTarget instances, either the + # base type and/or variants in _variant_mapping. + + # we want to register a second hook to trigger when that column is + # associated with a table. in that event, we and all of our variants + # may want to set up some state on the table such as a CheckConstraint + # that will conditionally render at DDL render time. + + # the base SchemaType also sets up events for + # on_table/metadata_create/drop in this method, which is used by + # "native" types with a separate CREATE/DROP e.g. Postgresql.ENUM + column._on_table_attach(util.portable_instancemethod(self._set_table)) def _variant_mapping_for_set_table(self, column): - if isinstance(column.type, Variant): - variant_mapping = column.type.mapping.copy() - variant_mapping["_default"] = column.type.impl + if column.type._variant_mapping: + variant_mapping = dict(column.type._variant_mapping) + variant_mapping["_default"] = column.type else: variant_mapping = None return variant_mapping @@ -880,8 +893,9 @@ class SchemaType(SchemaEventTarget): ), ) if self.metadata is None: - # TODO: what's the difference between self.metadata - # and table.metadata here ? + # if SchemaType were created w/ a metadata argument, these + # events would already have been associated with that metadata + # and would preclude an association with table.metadata event.listen( table.metadata, "before_create", @@ -963,9 +977,19 @@ class SchemaType(SchemaEventTarget): def _is_impl_for_variant(self, dialect, kw): variant_mapping = kw.pop("variant_mapping", None) - if variant_mapping is None: + + if not variant_mapping: return True + # for types that have _variant_mapping, all the impls in the map + # that are SchemaEventTarget subclasses get set up as event holders. + # this is so that constructs that need + # to be associated with the Table at dialect-agnostic time etc. like + # CheckConstraints can be set up with that table. they then add + # to these constraints a DDL check_rule that among other things + # will check this _is_impl_for_variant() method to determine when + # the dialect is known that we are part of the table's DDL sequence. + # since PostgreSQL is the only DB that has ARRAY this can only # be integration tested by PG-specific tests def _we_are_the_impl(typ): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index cc226d7e37..07cd4d95fb 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -9,6 +9,7 @@ """ +import typing from . import operators from .base import SchemaEventTarget @@ -29,6 +30,10 @@ TABLEVALUE = None _resolve_value_to_type = None +# replace with pep-673 when applicable +SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine") + + class TypeEngine(Traversible): """The ultimate base class for all SQL datatypes. @@ -192,6 +197,8 @@ class TypeEngine(Traversible): """ + _variant_mapping = util.EMPTY_DICT + def evaluates_none(self): """Return a copy of this type which has the :attr:`.should_evaluate_none` flag set to True. @@ -532,8 +539,10 @@ class TypeEngine(Traversible): """ raise NotImplementedError() - def with_variant(self, type_, dialect_name): - r"""Produce a new type object that will utilize the given + def with_variant( + self: SelfTypeEngine, type_: "TypeEngine", dialect_name: str + ) -> SelfTypeEngine: + r"""Produce a copy of this type object that will utilize the given type when applied to the dialect of the given name. e.g.:: @@ -541,15 +550,21 @@ class TypeEngine(Traversible): from sqlalchemy.types import String from sqlalchemy.dialects import mysql - s = String() + string_type = String() + + string_type = string_type.with_variant( + mysql.VARCHAR(collation='foo'), 'mysql' + ) - s = s.with_variant(mysql.VARCHAR(collation='foo'), 'mysql') + The variant mapping indicates that when this type is + interpreted by a specific dialect, it will instead be + transmuted into the given type, rather than using the + primary type. - The construction of :meth:`.TypeEngine.with_variant` is always - from the "fallback" type to that which is dialect specific. - The returned type is an instance of :class:`.Variant`, which - itself provides a :meth:`.Variant.with_variant` - that can be called repeatedly. + .. versionchanged:: 2.0 the :meth:`_types.TypeEngine.with_variant` + method now works with a :class:`_types.TypeEngine` object "in + place", returning a copy of the original type rather than returning + a wrapping object; the ``Variant`` class is no longer used. :param type\_: a :class:`.TypeEngine` that will be selected as a variant from the originating type, when a dialect @@ -558,7 +573,24 @@ class TypeEngine(Traversible): this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) """ - return Variant(self, {dialect_name: to_instance(type_)}) + + if dialect_name in self._variant_mapping: + raise exc.ArgumentError( + "Dialect '%s' is already present in " + "the mapping for this %r" % (dialect_name, self) + ) + new_type = self.copy() + if isinstance(type_, type): + type_ = type_() + elif type_._variant_mapping: + raise exc.ArgumentError( + "can't pass a type that already has variants as a " + "dialect-level type to with_variant()" + ) + new_type._variant_mapping = self._variant_mapping.union( + {dialect_name: type_} + ) + return new_type @util.memoized_property def _type_affinity(self): @@ -735,7 +767,12 @@ class TypeEngine(Traversible): return d def _gen_dialect_impl(self, dialect): - return dialect.type_descriptor(self) + if dialect.name in self._variant_mapping: + return self._variant_mapping[dialect.name]._gen_dialect_impl( + dialect + ) + else: + return dialect.type_descriptor(self) @util.memoized_property def _static_cache_key(self): @@ -1361,7 +1398,12 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine): """ #todo """ - adapted = dialect.type_descriptor(self) + if dialect.name in self._variant_mapping: + adapted = dialect.type_descriptor( + self._variant_mapping[dialect.name] + ) + else: + adapted = dialect.type_descriptor(self) if adapted is not self: return adapted @@ -1818,98 +1860,17 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine): class Variant(TypeDecorator): - """A wrapping type that selects among a variety of - implementations based on dialect in use. - - The :class:`.Variant` type is typically constructed - using the :meth:`.TypeEngine.with_variant` method. - - .. seealso:: :meth:`.TypeEngine.with_variant` for an example of use. + """deprecated. symbol is present for backwards-compatibility with + workaround recipes, however this actual type should not be used. """ - cache_ok = True - - def __init__(self, base, mapping): - """Construct a new :class:`.Variant`. - - :param base: the base 'fallback' type - :param mapping: dictionary of string dialect names to - :class:`.TypeEngine` instances. - - """ - self.impl = base - self.mapping = mapping - - @util.memoized_property - def _static_cache_key(self): - # TODO: needs tests in test/sql/test_compare.py - return (self.__class__,) + ( - self.impl._static_cache_key, - tuple( - (key, self.mapping[key]._static_cache_key) - for key in sorted(self.mapping) - ), + def __init__(self, *arg, **kw): + raise NotImplementedError( + "Variant is no longer used in SQLAlchemy; this is a " + "placeholder symbol for backwards compatibility." ) - def coerce_compared_value(self, operator, value): - result = self.impl.coerce_compared_value(operator, value) - if result is self.impl: - return self - else: - return result - - def load_dialect_impl(self, dialect): - if dialect.name in self.mapping: - return self.mapping[dialect.name] - else: - return self.impl - - def _set_parent(self, column, outer=False, **kw): - """Support SchemaEventTarget""" - - if isinstance(self.impl, SchemaEventTarget): - self.impl._set_parent(column, **kw) - for impl in self.mapping.values(): - if isinstance(impl, SchemaEventTarget): - impl._set_parent(column, **kw) - - def _set_parent_with_dispatch(self, parent): - """Support SchemaEventTarget""" - - if isinstance(self.impl, SchemaEventTarget): - self.impl._set_parent_with_dispatch(parent) - for impl in self.mapping.values(): - if isinstance(impl, SchemaEventTarget): - impl._set_parent_with_dispatch(parent) - - def with_variant(self, type_, dialect_name): - r"""Return a new :class:`.Variant` which adds the given - type + dialect name to the mapping, in addition to the - mapping present in this :class:`.Variant`. - - :param type\_: a :class:`.TypeEngine` that will be selected - as a variant from the originating type, when a dialect - of the given name is in use. - :param dialect_name: base name of the dialect which uses - this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) - - """ - - if dialect_name in self.mapping: - raise exc.ArgumentError( - "Dialect '%s' is already present in " - "the mapping for this Variant" % dialect_name - ) - mapping = self.mapping.copy() - mapping[dialect_name] = type_ - return Variant(self.impl, mapping) - - @property - def comparator_factory(self): - """express comparison behavior in terms of the base type""" - return self.impl.comparator_factory - def _reconstitute_comparator(expression): return expression.comparator diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 2acf151958..795b538047 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -515,6 +515,10 @@ class AssertsCompiledSQL: if hasattr(test_statement, "_return_defaults"): self._return_defaults = test_statement._return_defaults + @property + def _variant_mapping(self): + return self.test_statement._variant_mapping + def _default_dialect(self): return self.test_statement._default_dialect() diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 5f8a41d1f8..a5797dc2f0 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -50,7 +50,6 @@ from sqlalchemy.orm import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes -from sqlalchemy.sql.type_api import Variant from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -779,24 +778,33 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection.execute(t1.insert(), {"data": "two"}) eq_(connection.scalar(select(t1.c.data)), "twoHITHERE") - def test_generic_w_pg_variant(self, metadata, connection): + @testing.combinations( + ( + Enum( + "one", + "two", + "three", + native_enum=True # make sure this is True because + # it should *not* take effect due to + # the variant + ).with_variant( + postgresql.ENUM("four", "five", "six", name="my_enum"), + "postgresql", + ) + ), + ( + String(50).with_variant( + postgresql.ENUM("four", "five", "six", name="my_enum"), + "postgresql", + ) + ), + argnames="datatype", + ) + def test_generic_w_pg_variant(self, metadata, connection, datatype): some_table = Table( "some_table", self.metadata, - Column( - "data", - Enum( - "one", - "two", - "three", - native_enum=True # make sure this is True because - # it should *not* take effect due to - # the variant - ).with_variant( - postgresql.ENUM("four", "five", "six", name="my_enum"), - "postgresql", - ), - ), + Column("data", datatype), ) assert "my_enum" not in [ @@ -2134,7 +2142,7 @@ class ArrayRoundTripTest: new_gen = gen(3) - if isinstance(table.c.bar.type, Variant): + if not table.c.bar.type._variant_mapping: # this is not likely to occur to users but we need to just # exercise this as far as we can expr = type_coerce(table.c.bar, ARRAY(type_))[1:3] diff --git a/test/sql/test_types.py b/test/sql/test_types.py index dc47cca46d..d930464a61 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -86,6 +86,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock from sqlalchemy.testing import pickleable +from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.schema import Table @@ -1515,15 +1516,26 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): assert_raises_message( exc.ArgumentError, "Dialect 'postgresql' is already present " - "in the mapping for this Variant", + "in the mapping for this UTypeOne()", lambda: v.with_variant(self.UTypeThree(), "postgresql"), ) + def test_no_variants_of_variants(self): + t = Integer().with_variant(Float(), "postgresql") + + with expect_raises_message( + exc.ArgumentError, + r"can't pass a type that already has variants as a " + r"dialect-level type to with_variant\(\)", + ): + String().with_variant(t, "mysql") + def test_compile(self): self.assert_compile(self.variant, "UTYPEONE", use_default_dialect=True) self.assert_compile( self.variant, "UTYPEONE", dialect=dialects.mysql.dialect() ) + self.assert_compile( self.variant, "UTYPETWO", dialect=dialects.postgresql.dialect() ) @@ -1535,6 +1547,27 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialects.postgresql.dialect(), ) + def test_typedec_gen_dialect_impl(self): + """test that gen_dialect_impl passes onto a TypeDecorator, as + TypeDecorator._gen_dialect_impl() itself has special behaviors. + + """ + + class MyDialectString(String): + pass + + class MyString(TypeDecorator): + impl = String + cache_ok = True + + def load_dialect_impl(self, dialect): + return MyDialectString() + + variant = String().with_variant(MyString(), "mysql") + + dialect_impl = variant._gen_dialect_impl(mysql.dialect()) + is_(dialect_impl.impl.__class__, MyDialectString) + def test_compile_composite(self): self.assert_compile( self.composite, "UTYPEONE", use_default_dialect=True @@ -1984,12 +2017,60 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): ) @testing.requires.enforces_check_constraints - @testing.provide_metadata - def test_variant_we_are_default(self): + def test_variant_default_is_not_schematype(self, metadata): + t = Table( + "my_table", + metadata, + Column( + "data", + String(50).with_variant( + Enum( + "four", + "five", + "six", + native_enum=False, + name="e2", + create_constraint=True, + ), + testing.db.dialect.name, + ), + ), + ) + + # the base String() didnt create a constraint or even do any + # events. But Column looked for SchemaType in _variant_mapping + # and found our type anyway. + eq_( + len([c for c in t.constraints if isinstance(c, CheckConstraint)]), + 1, + ) + + metadata.create_all(testing.db) + + # not using the connection fixture because we need to rollback and + # start again in the middle + with testing.db.connect() as connection: + # postgresql needs this in order to continue after the exception + trans = connection.begin() + assert_raises( + (exc.DBAPIError,), + connection.exec_driver_sql, + "insert into my_table (data) values('two')", + ) + trans.rollback() + + with connection.begin(): + connection.exec_driver_sql( + "insert into my_table (data) values ('four')" + ) + eq_(connection.execute(select(t.c.data)).scalar(), "four") + + @testing.requires.enforces_check_constraints + def test_variant_we_are_default(self, metadata): # test that the "variant" does not create a constraint t = Table( "my_table", - self.metadata, + metadata, Column( "data", Enum( @@ -2019,7 +2100,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): 2, ) - self.metadata.create_all(testing.db) + metadata.create_all(testing.db) # not using the connection fixture because we need to rollback and # start again in the middle @@ -2040,12 +2121,11 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): eq_(connection.execute(select(t.c.data)).scalar(), "two") @testing.requires.enforces_check_constraints - @testing.provide_metadata - def test_variant_we_are_not_default(self): + def test_variant_we_are_not_default(self, metadata): # test that the "variant" does not create a constraint t = Table( "my_table", - self.metadata, + metadata, Column( "data", Enum( @@ -2075,7 +2155,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): 2, ) - self.metadata.create_all(testing.db) + metadata.create_all(testing.db) # not using the connection fixture because we need to rollback and # start again in the middle