]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
replace Variant with direct feature inside of TypeEngine
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Oct 2021 15:26:45 +0000 (11:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Dec 2021 16:43:53 +0000 (11:43 -0500)
The :meth:`_sqltypes.TypeEngine.with_variant` method now returns a copy of
the original :class:`_sqltypes.TypeEngine` object, rather than wrapping it
inside the ``Variant`` class, which is effectively removed (the import
symbol remains for backwards compatibility with code that may be testing
for this symbol). While the previous approach maintained in-Python
behaviors, maintaining the original type allows for clearer type checking
and debugging.

Fixes: #6980
Change-Id: I158c7e56306b886b5b82b040205c428a5c4a242c

doc/build/changelog/migration_20.rst
doc/build/changelog/unreleased_20/6980.rst [new file with mode: 0644]
doc/build/conf.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/assertions.py
test/dialect/postgresql/test_types.py
test/sql/test_types.py

index 72530142e8901b82881e7a85c0e7a6b63e021ba1..4c2ef22bdd9f5362dc7caca863aafb2e466278b5 100644 (file)
@@ -159,6 +159,42 @@ of the Cython requirement.
 
 .. _Cython: https://cython.org/
 
+.. _change_6980:
+
+"with_variant()" clones the original TypeEngine rather than changing the type
+-----------------------------------------------------------------------------
+
+The :meth:`_sqltypes.TypeEngine.with_variant` method, which is used to apply
+alternate per-database behaviors to a particular type, now returns a copy of
+the original :class:`_sqltypes.TypeEngine` object with the variant information
+stored internally, rather than wrapping it inside the ``Variant`` class.
+
+While the previous ``Variant`` approach was able to maintain all the in-Python
+behaviors of the original type using dynamic attribute getters, the improvement
+here is that when calling upon a variant, the returned type remains an instance
+of the original type, which works more smoothly with type checkers such as mypy
+and pylance.  Given a program as below::
+
+    import typing
+
+
+    from sqlalchemy import String
+    from sqlalchemy.dialects.mysql import VARCHAR
+
+
+    type_ = String(255).with_variant(VARCHAR(255, charset='utf8mb4'), "mysql")
+
+    if typing.TYPE_CHECKING:
+        reveal_type(type_)
+
+A type checker like pyright will now report the type as::
+
+    info: Type of "type_" is "String"
+
+
+:ticket:`6980`
+
+
 .. _change_4926:
 
 Python division operator performs true division for all backends; added floor division
diff --git a/doc/build/changelog/unreleased_20/6980.rst b/doc/build/changelog/unreleased_20/6980.rst
new file mode 100644 (file)
index 0000000..d83599c
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: improvement, typing
+    :tickets: 6980
+
+    The :meth:`_sqltypes.TypeEngine.with_variant` method now returns a copy of
+    the original :class:`_sqltypes.TypeEngine` object, rather than wrapping it
+    inside the ``Variant`` class, which is effectively removed (the import
+    symbol remains for backwards compatibility with code that may be testing
+    for this symbol). While the previous approach maintained in-Python
+    behaviors, maintaining the original type allows for clearer type checking
+    and debugging.
+
+    .. seealso::
+
+        :ref:`change_6980`
+
+
+
index 9e2d8e9de7c1fe26936b596c6a2c92a68367d1c8..cba0fee6ab9f9e9a63661e51902b4d9ad4a1b663 100644 (file)
@@ -63,6 +63,7 @@ changelog_sections = [
     "sql",
     "schema",
     "extensions",
+    "typing",
     "mypy",
     "asyncio",
     "postgresql",
@@ -76,6 +77,7 @@ changelog_sections = [
 # tags to sort on inside of sections
 changelog_inner_tag_sort = [
     "feature",
+    "improvement",
     "usecase",
     "change",
     "changed",
index a9dd6a23ae1d865682190db8ee0e31b37a9dbb70..90f32a4f7cecfbc3a70072faf9fdfa6677d205d4 100644 (file)
@@ -512,6 +512,11 @@ class TypeCompiler(metaclass=util.EnsureKWArgType):
         self.dialect = dialect
 
     def process(self, type_, **kw):
+        if (
+            type_._variant_mapping
+            and self.dialect.name in type_._variant_mapping
+        ):
+            type_ = type_._variant_mapping[self.dialect.name]
         return type_._compiler_dispatch(self, **kw)
 
     def visit_unsupported_compilation(self, element, err, **kw):
index 9b5005b5d839fedea93f904133028ff90c198160..dd8238450ce81201f2260ffe1e00224e4fdd9fb0 100644 (file)
@@ -1586,9 +1586,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
         # check if this Column is proxying another column
         if "_proxies" in kwargs:
             self._proxies = kwargs.pop("_proxies")
-        # otherwise, add DDL-related events
-        elif isinstance(self.type, SchemaEventTarget):
-            self.type._set_parent_with_dispatch(self)
+        else:
+            # otherwise, add DDL-related events
+            if isinstance(self.type, SchemaEventTarget):
+                self.type._set_parent_with_dispatch(self)
+            for impl in self.type._variant_mapping.values():
+                if isinstance(impl, SchemaEventTarget):
+                    impl._set_parent_with_dispatch(self)
 
         if self.default is not None:
             if isinstance(self.default, (ColumnDefault, Sequence)):
index 574692fed70c2068f3436152a42901abf62580ff..cda7b35cda052e8ded24298f99f4c87ab74f51b1 100644 (file)
@@ -33,7 +33,7 @@ from .type_api import NativeForEmulated  # noqa
 from .type_api import to_instance
 from .type_api import TypeDecorator
 from .type_api import TypeEngine
-from .type_api import Variant
+from .type_api import Variant  # noqa
 from .. import event
 from .. import exc
 from .. import inspection
@@ -844,12 +844,25 @@ class SchemaType(SchemaEventTarget):
             )
 
     def _set_parent(self, column, **kw):
+        # set parent hook is when this type is associated with a column.
+        # Column calls it for all SchemaEventTarget instances, either the
+        # base type and/or variants in _variant_mapping.
+
+        # we want to register a second hook to trigger when that column is
+        # associated with a table.  in that event, we and all of our variants
+        # may want to set up some state on the table such as a CheckConstraint
+        # that will conditionally render at DDL render time.
+
+        # the base SchemaType also sets up events for
+        # on_table/metadata_create/drop in this method, which is used by
+        # "native" types with a separate CREATE/DROP e.g. Postgresql.ENUM
+
         column._on_table_attach(util.portable_instancemethod(self._set_table))
 
     def _variant_mapping_for_set_table(self, column):
-        if isinstance(column.type, Variant):
-            variant_mapping = column.type.mapping.copy()
-            variant_mapping["_default"] = column.type.impl
+        if column.type._variant_mapping:
+            variant_mapping = dict(column.type._variant_mapping)
+            variant_mapping["_default"] = column.type
         else:
             variant_mapping = None
         return variant_mapping
@@ -880,8 +893,9 @@ class SchemaType(SchemaEventTarget):
             ),
         )
         if self.metadata is None:
-            # TODO: what's the difference between self.metadata
-            # and table.metadata here ?
+            # if SchemaType were created w/ a metadata argument, these
+            # events would already have been associated with that metadata
+            # and would preclude an association with table.metadata
             event.listen(
                 table.metadata,
                 "before_create",
@@ -963,9 +977,19 @@ class SchemaType(SchemaEventTarget):
 
     def _is_impl_for_variant(self, dialect, kw):
         variant_mapping = kw.pop("variant_mapping", None)
-        if variant_mapping is None:
+
+        if not variant_mapping:
             return True
 
+        # for types that have _variant_mapping, all the impls in the map
+        # that are SchemaEventTarget subclasses get set up as event holders.
+        # this is so that constructs that need
+        # to be associated with the Table at dialect-agnostic time etc. like
+        # CheckConstraints can be set up with that table.  they then add
+        # to these constraints a DDL check_rule that among other things
+        # will check this _is_impl_for_variant() method to determine when
+        # the dialect is known that we are part of the table's DDL sequence.
+
         # since PostgreSQL is the only DB that has ARRAY this can only
         # be integration tested by PG-specific tests
         def _we_are_the_impl(typ):
index cc226d7e371a9627940c7465d2e2609107aae1d0..07cd4d95fb3687570b62cccfcd99ed493894c5fe 100644 (file)
@@ -9,6 +9,7 @@
 
 """
 
+import typing
 
 from . import operators
 from .base import SchemaEventTarget
@@ -29,6 +30,10 @@ TABLEVALUE = None
 _resolve_value_to_type = None
 
 
+# replace with pep-673 when applicable
+SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine")
+
+
 class TypeEngine(Traversible):
     """The ultimate base class for all SQL datatypes.
 
@@ -192,6 +197,8 @@ class TypeEngine(Traversible):
 
     """
 
+    _variant_mapping = util.EMPTY_DICT
+
     def evaluates_none(self):
         """Return a copy of this type which has the :attr:`.should_evaluate_none`
         flag set to True.
@@ -532,8 +539,10 @@ class TypeEngine(Traversible):
         """
         raise NotImplementedError()
 
-    def with_variant(self, type_, dialect_name):
-        r"""Produce a new type object that will utilize the given
+    def with_variant(
+        self: SelfTypeEngine, type_: "TypeEngine", dialect_name: str
+    ) -> SelfTypeEngine:
+        r"""Produce a copy of this type object that will utilize the given
         type when applied to the dialect of the given name.
 
         e.g.::
@@ -541,15 +550,21 @@ class TypeEngine(Traversible):
             from sqlalchemy.types import String
             from sqlalchemy.dialects import mysql
 
-            s = String()
+            string_type = String()
+
+            string_type = string_type.with_variant(
+                mysql.VARCHAR(collation='foo'), 'mysql'
+            )
 
-            s = s.with_variant(mysql.VARCHAR(collation='foo'), 'mysql')
+        The variant mapping indicates that when this type is
+        interpreted by a specific dialect, it will instead be
+        transmuted into the given type, rather than using the
+        primary type.
 
-        The construction of :meth:`.TypeEngine.with_variant` is always
-        from the "fallback" type to that which is dialect specific.
-        The returned type is an instance of :class:`.Variant`, which
-        itself provides a :meth:`.Variant.with_variant`
-        that can be called repeatedly.
+        .. versionchanged:: 2.0 the :meth:`_types.TypeEngine.with_variant`
+           method now works with a :class:`_types.TypeEngine` object "in
+           place", returning a copy of the original type rather than returning
+           a wrapping object; the ``Variant`` class is no longer used.
 
         :param type\_: a :class:`.TypeEngine` that will be selected
          as a variant from the originating type, when a dialect
@@ -558,7 +573,24 @@ class TypeEngine(Traversible):
          this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
 
         """
-        return Variant(self, {dialect_name: to_instance(type_)})
+
+        if dialect_name in self._variant_mapping:
+            raise exc.ArgumentError(
+                "Dialect '%s' is already present in "
+                "the mapping for this %r" % (dialect_name, self)
+            )
+        new_type = self.copy()
+        if isinstance(type_, type):
+            type_ = type_()
+        elif type_._variant_mapping:
+            raise exc.ArgumentError(
+                "can't pass a type that already has variants as a "
+                "dialect-level type to with_variant()"
+            )
+        new_type._variant_mapping = self._variant_mapping.union(
+            {dialect_name: type_}
+        )
+        return new_type
 
     @util.memoized_property
     def _type_affinity(self):
@@ -735,7 +767,12 @@ class TypeEngine(Traversible):
             return d
 
     def _gen_dialect_impl(self, dialect):
-        return dialect.type_descriptor(self)
+        if dialect.name in self._variant_mapping:
+            return self._variant_mapping[dialect.name]._gen_dialect_impl(
+                dialect
+            )
+        else:
+            return dialect.type_descriptor(self)
 
     @util.memoized_property
     def _static_cache_key(self):
@@ -1361,7 +1398,12 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine):
         """
         #todo
         """
-        adapted = dialect.type_descriptor(self)
+        if dialect.name in self._variant_mapping:
+            adapted = dialect.type_descriptor(
+                self._variant_mapping[dialect.name]
+            )
+        else:
+            adapted = dialect.type_descriptor(self)
         if adapted is not self:
             return adapted
 
@@ -1818,98 +1860,17 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine):
 
 
 class Variant(TypeDecorator):
-    """A wrapping type that selects among a variety of
-    implementations based on dialect in use.
-
-    The :class:`.Variant` type is typically constructed
-    using the :meth:`.TypeEngine.with_variant` method.
-
-    .. seealso:: :meth:`.TypeEngine.with_variant` for an example of use.
+    """deprecated.  symbol is present for backwards-compatibility with
+    workaround recipes, however this actual type should not be used.
 
     """
 
-    cache_ok = True
-
-    def __init__(self, base, mapping):
-        """Construct a new :class:`.Variant`.
-
-        :param base: the base 'fallback' type
-        :param mapping: dictionary of string dialect names to
-          :class:`.TypeEngine` instances.
-
-        """
-        self.impl = base
-        self.mapping = mapping
-
-    @util.memoized_property
-    def _static_cache_key(self):
-        # TODO: needs tests in test/sql/test_compare.py
-        return (self.__class__,) + (
-            self.impl._static_cache_key,
-            tuple(
-                (key, self.mapping[key]._static_cache_key)
-                for key in sorted(self.mapping)
-            ),
+    def __init__(self, *arg, **kw):
+        raise NotImplementedError(
+            "Variant is no longer used in SQLAlchemy; this is a "
+            "placeholder symbol for backwards compatibility."
         )
 
-    def coerce_compared_value(self, operator, value):
-        result = self.impl.coerce_compared_value(operator, value)
-        if result is self.impl:
-            return self
-        else:
-            return result
-
-    def load_dialect_impl(self, dialect):
-        if dialect.name in self.mapping:
-            return self.mapping[dialect.name]
-        else:
-            return self.impl
-
-    def _set_parent(self, column, outer=False, **kw):
-        """Support SchemaEventTarget"""
-
-        if isinstance(self.impl, SchemaEventTarget):
-            self.impl._set_parent(column, **kw)
-        for impl in self.mapping.values():
-            if isinstance(impl, SchemaEventTarget):
-                impl._set_parent(column, **kw)
-
-    def _set_parent_with_dispatch(self, parent):
-        """Support SchemaEventTarget"""
-
-        if isinstance(self.impl, SchemaEventTarget):
-            self.impl._set_parent_with_dispatch(parent)
-        for impl in self.mapping.values():
-            if isinstance(impl, SchemaEventTarget):
-                impl._set_parent_with_dispatch(parent)
-
-    def with_variant(self, type_, dialect_name):
-        r"""Return a new :class:`.Variant` which adds the given
-        type + dialect name to the mapping, in addition to the
-        mapping present in this :class:`.Variant`.
-
-        :param type\_: a :class:`.TypeEngine` that will be selected
-         as a variant from the originating type, when a dialect
-         of the given name is in use.
-        :param dialect_name: base name of the dialect which uses
-         this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
-
-        """
-
-        if dialect_name in self.mapping:
-            raise exc.ArgumentError(
-                "Dialect '%s' is already present in "
-                "the mapping for this Variant" % dialect_name
-            )
-        mapping = self.mapping.copy()
-        mapping[dialect_name] = type_
-        return Variant(self.impl, mapping)
-
-    @property
-    def comparator_factory(self):
-        """express comparison behavior in terms of the base type"""
-        return self.impl.comparator_factory
-
 
 def _reconstitute_comparator(expression):
     return expression.comparator
index 2acf151958e6b857e808ef9725ce7e3c67b0ae13..795b5380472dd5d180ef97d5cb9cfc6952e5fec9 100644 (file)
@@ -515,6 +515,10 @@ class AssertsCompiledSQL:
                     if hasattr(test_statement, "_return_defaults"):
                         self._return_defaults = test_statement._return_defaults
 
+            @property
+            def _variant_mapping(self):
+                return self.test_statement._variant_mapping
+
             def _default_dialect(self):
                 return self.test_statement._default_dialect()
 
index 5f8a41d1f831ff8178dabd3de966f102f647803e..a5797dc2f03e4e160808fd5eb2df56a90af03de4 100644 (file)
@@ -50,7 +50,6 @@ from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
-from sqlalchemy.sql.type_api import Variant
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
@@ -779,24 +778,33 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         connection.execute(t1.insert(), {"data": "two"})
         eq_(connection.scalar(select(t1.c.data)), "twoHITHERE")
 
-    def test_generic_w_pg_variant(self, metadata, connection):
+    @testing.combinations(
+        (
+            Enum(
+                "one",
+                "two",
+                "three",
+                native_enum=True  # make sure this is True because
+                # it should *not* take effect due to
+                # the variant
+            ).with_variant(
+                postgresql.ENUM("four", "five", "six", name="my_enum"),
+                "postgresql",
+            )
+        ),
+        (
+            String(50).with_variant(
+                postgresql.ENUM("four", "five", "six", name="my_enum"),
+                "postgresql",
+            )
+        ),
+        argnames="datatype",
+    )
+    def test_generic_w_pg_variant(self, metadata, connection, datatype):
         some_table = Table(
             "some_table",
             self.metadata,
-            Column(
-                "data",
-                Enum(
-                    "one",
-                    "two",
-                    "three",
-                    native_enum=True  # make sure this is True because
-                    # it should *not* take effect due to
-                    # the variant
-                ).with_variant(
-                    postgresql.ENUM("four", "five", "six", name="my_enum"),
-                    "postgresql",
-                ),
-            ),
+            Column("data", datatype),
         )
 
         assert "my_enum" not in [
@@ -2134,7 +2142,7 @@ class ArrayRoundTripTest:
 
         new_gen = gen(3)
 
-        if isinstance(table.c.bar.type, Variant):
+        if not table.c.bar.type._variant_mapping:
             # this is not likely to occur to users but we need to just
             # exercise this as far as we can
             expr = type_coerce(table.c.bar, ARRAY(type_))[1:3]
index dc47cca46d2d52885c33112d906f6659dbe5ca67..d930464a61bbdf1528ee87d6400c193e93c906ec 100644 (file)
@@ -86,6 +86,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import pickleable
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.schema import Table
@@ -1515,15 +1516,26 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
         assert_raises_message(
             exc.ArgumentError,
             "Dialect 'postgresql' is already present "
-            "in the mapping for this Variant",
+            "in the mapping for this UTypeOne()",
             lambda: v.with_variant(self.UTypeThree(), "postgresql"),
         )
 
+    def test_no_variants_of_variants(self):
+        t = Integer().with_variant(Float(), "postgresql")
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"can't pass a type that already has variants as a "
+            r"dialect-level type to with_variant\(\)",
+        ):
+            String().with_variant(t, "mysql")
+
     def test_compile(self):
         self.assert_compile(self.variant, "UTYPEONE", use_default_dialect=True)
         self.assert_compile(
             self.variant, "UTYPEONE", dialect=dialects.mysql.dialect()
         )
+
         self.assert_compile(
             self.variant, "UTYPETWO", dialect=dialects.postgresql.dialect()
         )
@@ -1535,6 +1547,27 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialects.postgresql.dialect(),
         )
 
+    def test_typedec_gen_dialect_impl(self):
+        """test that gen_dialect_impl passes onto a TypeDecorator, as
+        TypeDecorator._gen_dialect_impl() itself has special behaviors.
+
+        """
+
+        class MyDialectString(String):
+            pass
+
+        class MyString(TypeDecorator):
+            impl = String
+            cache_ok = True
+
+            def load_dialect_impl(self, dialect):
+                return MyDialectString()
+
+        variant = String().with_variant(MyString(), "mysql")
+
+        dialect_impl = variant._gen_dialect_impl(mysql.dialect())
+        is_(dialect_impl.impl.__class__, MyDialectString)
+
     def test_compile_composite(self):
         self.assert_compile(
             self.composite, "UTYPEONE", use_default_dialect=True
@@ -1984,12 +2017,60 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         )
 
     @testing.requires.enforces_check_constraints
-    @testing.provide_metadata
-    def test_variant_we_are_default(self):
+    def test_variant_default_is_not_schematype(self, metadata):
+        t = Table(
+            "my_table",
+            metadata,
+            Column(
+                "data",
+                String(50).with_variant(
+                    Enum(
+                        "four",
+                        "five",
+                        "six",
+                        native_enum=False,
+                        name="e2",
+                        create_constraint=True,
+                    ),
+                    testing.db.dialect.name,
+                ),
+            ),
+        )
+
+        # the base String() didnt create a constraint or even do any
+        # events.  But Column looked for SchemaType in _variant_mapping
+        # and found our type anyway.
+        eq_(
+            len([c for c in t.constraints if isinstance(c, CheckConstraint)]),
+            1,
+        )
+
+        metadata.create_all(testing.db)
+
+        # not using the connection fixture because we need to rollback and
+        # start again in the middle
+        with testing.db.connect() as connection:
+            # postgresql needs this in order to continue after the exception
+            trans = connection.begin()
+            assert_raises(
+                (exc.DBAPIError,),
+                connection.exec_driver_sql,
+                "insert into my_table (data) values('two')",
+            )
+            trans.rollback()
+
+            with connection.begin():
+                connection.exec_driver_sql(
+                    "insert into my_table (data) values ('four')"
+                )
+                eq_(connection.execute(select(t.c.data)).scalar(), "four")
+
+    @testing.requires.enforces_check_constraints
+    def test_variant_we_are_default(self, metadata):
         # test that the "variant" does not create a constraint
         t = Table(
             "my_table",
-            self.metadata,
+            metadata,
             Column(
                 "data",
                 Enum(
@@ -2019,7 +2100,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             2,
         )
 
-        self.metadata.create_all(testing.db)
+        metadata.create_all(testing.db)
 
         # not using the connection fixture because we need to rollback and
         # start again in the middle
@@ -2040,12 +2121,11 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
                 eq_(connection.execute(select(t.c.data)).scalar(), "two")
 
     @testing.requires.enforces_check_constraints
-    @testing.provide_metadata
-    def test_variant_we_are_not_default(self):
+    def test_variant_we_are_not_default(self, metadata):
         # test that the "variant" does not create a constraint
         t = Table(
             "my_table",
-            self.metadata,
+            metadata,
             Column(
                 "data",
                 Enum(
@@ -2075,7 +2155,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             2,
         )
 
-        self.metadata.create_all(testing.db)
+        metadata.create_all(testing.db)
 
         # not using the connection fixture because we need to rollback and
         # start again in the middle