.. _Cython: https://cython.org/
+.. _change_6980:
+
+"with_variant()" clones the original TypeEngine rather than changing the type
+-----------------------------------------------------------------------------
+
+The :meth:`_sqltypes.TypeEngine.with_variant` method, which is used to apply
+alternate per-database behaviors to a particular type, now returns a copy of
+the original :class:`_sqltypes.TypeEngine` object with the variant information
+stored internally, rather than wrapping it inside the ``Variant`` class.
+
+While the previous ``Variant`` approach was able to maintain all the in-Python
+behaviors of the original type using dynamic attribute getters, the improvement
+here is that the type returned by :meth:`_sqltypes.TypeEngine.with_variant`
+remains an instance of the original type, which works more smoothly with type
+checkers such as mypy and pylance. Given a program such as the following::
+
+ import typing
+
+
+ from sqlalchemy import String
+ from sqlalchemy.dialects.mysql import VARCHAR
+
+
+ type_ = String(255).with_variant(VARCHAR(255, charset='utf8mb4'), "mysql")
+
+ if typing.TYPE_CHECKING:
+ reveal_type(type_)
+
+A type checker like pyright will now report the type as::
+
+ info: Type of "type_" is "String"
+
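+A rough runtime sketch of the same idea follows; it is illustrative only, and
+the exact rendered strings may differ slightly across dialect versions::
+
+    from sqlalchemy.dialects import mysql
+
+    # the returned object is still an instance of the original type
+    assert isinstance(type_, String)
+
+    # compiling against the MySQL dialect selects the variant, rendering
+    # something like "VARCHAR(255) CHARACTER SET utf8mb4"
+    print(type_.compile(dialect=mysql.dialect()))
+
+    # the default compilation continues to use the base type,
+    # rendering "VARCHAR(255)"
+    print(type_.compile())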
+
+:ticket:`6980`
+
+
.. _change_4926:
Python division operator performs true division for all backends; added floor division
--- /dev/null
+.. change::
+ :tags: improvement, typing
+ :tickets: 6980
+
+ The :meth:`_sqltypes.TypeEngine.with_variant` method now returns a copy of
+ the original :class:`_sqltypes.TypeEngine` object, rather than wrapping it
+ inside the ``Variant`` class, which is effectively removed (the import
+ symbol remains for backwards compatibility with code that may be testing
+ for this symbol). While the previous approach maintained in-Python
+    behaviors, preserving the original type allows for clearer type checking
+ and debugging.
+
+ .. seealso::
+
+ :ref:`change_6980`
+
+
+
"sql",
"schema",
"extensions",
+ "typing",
"mypy",
"asyncio",
"postgresql",
# tags to sort on inside of sections
changelog_inner_tag_sort = [
"feature",
+ "improvement",
"usecase",
"change",
"changed",
self.dialect = dialect
def process(self, type_, **kw):
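+        # if the type has a variant registered for the dialect in use,
+        # compile that variant in place of the base type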
+ if (
+ type_._variant_mapping
+ and self.dialect.name in type_._variant_mapping
+ ):
+ type_ = type_._variant_mapping[self.dialect.name]
return type_._compiler_dispatch(self, **kw)
def visit_unsupported_compilation(self, element, err, **kw):
# check if this Column is proxying another column
if "_proxies" in kwargs:
self._proxies = kwargs.pop("_proxies")
- # otherwise, add DDL-related events
- elif isinstance(self.type, SchemaEventTarget):
- self.type._set_parent_with_dispatch(self)
+ else:
+ # otherwise, add DDL-related events
+ if isinstance(self.type, SchemaEventTarget):
+ self.type._set_parent_with_dispatch(self)
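+            # variants that are themselves SchemaEventTarget also need to
+            # receive the DDL-related events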
+ for impl in self.type._variant_mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent_with_dispatch(self)
if self.default is not None:
if isinstance(self.default, (ColumnDefault, Sequence)):
from .type_api import to_instance
from .type_api import TypeDecorator
from .type_api import TypeEngine
-from .type_api import Variant
+from .type_api import Variant # noqa
from .. import event
from .. import exc
from .. import inspection
)
def _set_parent(self, column, **kw):
+        # the "set parent" hook is invoked when this type is associated with
+        # a column.  Column calls it for all SchemaEventTarget instances,
+        # whether that's the base type and/or the variants in _variant_mapping.
+
+ # we want to register a second hook to trigger when that column is
+ # associated with a table. in that event, we and all of our variants
+ # may want to set up some state on the table such as a CheckConstraint
+ # that will conditionally render at DDL render time.
+
+ # the base SchemaType also sets up events for
+ # on_table/metadata_create/drop in this method, which is used by
+ # "native" types with a separate CREATE/DROP e.g. Postgresql.ENUM
+
column._on_table_attach(util.portable_instancemethod(self._set_table))
def _variant_mapping_for_set_table(self, column):
- if isinstance(column.type, Variant):
- variant_mapping = column.type.mapping.copy()
- variant_mapping["_default"] = column.type.impl
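+        # without the Variant wrapper, the type itself serves as the
+        # "_default" entry and its alternatives come from _variant_mapping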
+ if column.type._variant_mapping:
+ variant_mapping = dict(column.type._variant_mapping)
+ variant_mapping["_default"] = column.type
else:
variant_mapping = None
return variant_mapping
),
)
if self.metadata is None:
- # TODO: what's the difference between self.metadata
- # and table.metadata here ?
+ # if SchemaType were created w/ a metadata argument, these
+ # events would already have been associated with that metadata
+ # and would preclude an association with table.metadata
event.listen(
table.metadata,
"before_create",
def _is_impl_for_variant(self, dialect, kw):
variant_mapping = kw.pop("variant_mapping", None)
- if variant_mapping is None:
+
+ if not variant_mapping:
return True
+        # for types that have _variant_mapping, all the impls in the mapping
+        # that are SchemaEventTarget subclasses get set up as event holders.
+        # this is so that constructs that need to be associated with the
+        # Table while the dialect is not yet known, such as CheckConstraint,
+        # can be set up against that table.  those constraints are then
+        # given a DDL check rule which, among other things, calls this
+        # _is_impl_for_variant() method once the dialect is known, to
+        # determine whether we are part of the table's DDL sequence.
+
# since PostgreSQL is the only DB that has ARRAY this can only
# be integration tested by PG-specific tests
def _we_are_the_impl(typ):
"""
+import typing
from . import operators
from .base import SchemaEventTarget
_resolve_value_to_type = None
+# replace with pep-673 when applicable
+SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine")
+
+
class TypeEngine(Traversible):
"""The ultimate base class for all SQL datatypes.
"""
+ _variant_mapping = util.EMPTY_DICT
+
def evaluates_none(self):
"""Return a copy of this type which has the :attr:`.should_evaluate_none`
flag set to True.
"""
raise NotImplementedError()
- def with_variant(self, type_, dialect_name):
- r"""Produce a new type object that will utilize the given
+ def with_variant(
+ self: SelfTypeEngine, type_: "TypeEngine", dialect_name: str
+ ) -> SelfTypeEngine:
+ r"""Produce a copy of this type object that will utilize the given
type when applied to the dialect of the given name.
e.g.::
from sqlalchemy.types import String
from sqlalchemy.dialects import mysql
- s = String()
+ string_type = String()
+
+ string_type = string_type.with_variant(
+ mysql.VARCHAR(collation='foo'), 'mysql'
+ )
- s = s.with_variant(mysql.VARCHAR(collation='foo'), 'mysql')
+ The variant mapping indicates that when this type is
+ interpreted by a specific dialect, it will instead be
+ transmuted into the given type, rather than using the
+ primary type.
- The construction of :meth:`.TypeEngine.with_variant` is always
- from the "fallback" type to that which is dialect specific.
- The returned type is an instance of :class:`.Variant`, which
- itself provides a :meth:`.Variant.with_variant`
- that can be called repeatedly.
+ .. versionchanged:: 2.0 the :meth:`_types.TypeEngine.with_variant`
+ method now works with a :class:`_types.TypeEngine` object "in
+ place", returning a copy of the original type rather than returning
+ a wrapping object; the ``Variant`` class is no longer used.
:param type\_: a :class:`.TypeEngine` that will be selected
as a variant from the originating type, when a dialect
this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
"""
- return Variant(self, {dialect_name: to_instance(type_)})
+
+ if dialect_name in self._variant_mapping:
+ raise exc.ArgumentError(
+ "Dialect '%s' is already present in "
+ "the mapping for this %r" % (dialect_name, self)
+ )
+ new_type = self.copy()
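+        # accept an un-instantiated type class as the dialect-level type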
+ if isinstance(type_, type):
+ type_ = type_()
+ elif type_._variant_mapping:
+ raise exc.ArgumentError(
+ "can't pass a type that already has variants as a "
+ "dialect-level type to with_variant()"
+ )
+ new_type._variant_mapping = self._variant_mapping.union(
+ {dialect_name: type_}
+ )
+ return new_type
@util.memoized_property
def _type_affinity(self):
return d
def _gen_dialect_impl(self, dialect):
- return dialect.type_descriptor(self)
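+        # a variant registered for this dialect takes over generation of
+        # the dialect-level implementation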
+ if dialect.name in self._variant_mapping:
+ return self._variant_mapping[dialect.name]._gen_dialect_impl(
+ dialect
+ )
+ else:
+ return dialect.type_descriptor(self)
@util.memoized_property
def _static_cache_key(self):
"""
#todo
"""
- adapted = dialect.type_descriptor(self)
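+        # select the variant for this dialect, if one is present, before
+        # producing the dialect-level type descriptor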
+ if dialect.name in self._variant_mapping:
+ adapted = dialect.type_descriptor(
+ self._variant_mapping[dialect.name]
+ )
+ else:
+ adapted = dialect.type_descriptor(self)
if adapted is not self:
return adapted
class Variant(TypeDecorator):
- """A wrapping type that selects among a variety of
- implementations based on dialect in use.
-
- The :class:`.Variant` type is typically constructed
- using the :meth:`.TypeEngine.with_variant` method.
-
- .. seealso:: :meth:`.TypeEngine.with_variant` for an example of use.
+    """deprecated.  the symbol is present for backwards-compatibility with
+    workaround recipes; however, this type itself should no longer be used.
"""
- cache_ok = True
-
- def __init__(self, base, mapping):
- """Construct a new :class:`.Variant`.
-
- :param base: the base 'fallback' type
- :param mapping: dictionary of string dialect names to
- :class:`.TypeEngine` instances.
-
- """
- self.impl = base
- self.mapping = mapping
-
- @util.memoized_property
- def _static_cache_key(self):
- # TODO: needs tests in test/sql/test_compare.py
- return (self.__class__,) + (
- self.impl._static_cache_key,
- tuple(
- (key, self.mapping[key]._static_cache_key)
- for key in sorted(self.mapping)
- ),
+ def __init__(self, *arg, **kw):
+ raise NotImplementedError(
+ "Variant is no longer used in SQLAlchemy; this is a "
+ "placeholder symbol for backwards compatibility."
)
- def coerce_compared_value(self, operator, value):
- result = self.impl.coerce_compared_value(operator, value)
- if result is self.impl:
- return self
- else:
- return result
-
- def load_dialect_impl(self, dialect):
- if dialect.name in self.mapping:
- return self.mapping[dialect.name]
- else:
- return self.impl
-
- def _set_parent(self, column, outer=False, **kw):
- """Support SchemaEventTarget"""
-
- if isinstance(self.impl, SchemaEventTarget):
- self.impl._set_parent(column, **kw)
- for impl in self.mapping.values():
- if isinstance(impl, SchemaEventTarget):
- impl._set_parent(column, **kw)
-
- def _set_parent_with_dispatch(self, parent):
- """Support SchemaEventTarget"""
-
- if isinstance(self.impl, SchemaEventTarget):
- self.impl._set_parent_with_dispatch(parent)
- for impl in self.mapping.values():
- if isinstance(impl, SchemaEventTarget):
- impl._set_parent_with_dispatch(parent)
-
- def with_variant(self, type_, dialect_name):
- r"""Return a new :class:`.Variant` which adds the given
- type + dialect name to the mapping, in addition to the
- mapping present in this :class:`.Variant`.
-
- :param type\_: a :class:`.TypeEngine` that will be selected
- as a variant from the originating type, when a dialect
- of the given name is in use.
- :param dialect_name: base name of the dialect which uses
- this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
-
- """
-
- if dialect_name in self.mapping:
- raise exc.ArgumentError(
- "Dialect '%s' is already present in "
- "the mapping for this Variant" % dialect_name
- )
- mapping = self.mapping.copy()
- mapping[dialect_name] = type_
- return Variant(self.impl, mapping)
-
- @property
- def comparator_factory(self):
- """express comparison behavior in terms of the base type"""
- return self.impl.comparator_factory
-
def _reconstitute_comparator(expression):
return expression.comparator
if hasattr(test_statement, "_return_defaults"):
self._return_defaults = test_statement._return_defaults
+ @property
+ def _variant_mapping(self):
+ return self.test_statement._variant_mapping
+
def _default_dialect(self):
return self.test_statement._default_dialect()
from sqlalchemy.orm import Session
from sqlalchemy.sql import operators
from sqlalchemy.sql import sqltypes
-from sqlalchemy.sql.type_api import Variant
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
connection.execute(t1.insert(), {"data": "two"})
eq_(connection.scalar(select(t1.c.data)), "twoHITHERE")
- def test_generic_w_pg_variant(self, metadata, connection):
+ @testing.combinations(
+ (
+ Enum(
+ "one",
+ "two",
+ "three",
+ native_enum=True # make sure this is True because
+ # it should *not* take effect due to
+ # the variant
+ ).with_variant(
+ postgresql.ENUM("four", "five", "six", name="my_enum"),
+ "postgresql",
+ )
+ ),
+ (
+ String(50).with_variant(
+ postgresql.ENUM("four", "five", "six", name="my_enum"),
+ "postgresql",
+ )
+ ),
+ argnames="datatype",
+ )
+ def test_generic_w_pg_variant(self, metadata, connection, datatype):
some_table = Table(
"some_table",
self.metadata,
- Column(
- "data",
- Enum(
- "one",
- "two",
- "three",
- native_enum=True # make sure this is True because
- # it should *not* take effect due to
- # the variant
- ).with_variant(
- postgresql.ENUM("four", "five", "six", name="my_enum"),
- "postgresql",
- ),
- ),
+ Column("data", datatype),
)
assert "my_enum" not in [
new_gen = gen(3)
- if isinstance(table.c.bar.type, Variant):
+ if not table.c.bar.type._variant_mapping:
# this is not likely to occur to users but we need to just
# exercise this as far as we can
expr = type_coerce(table.c.bar, ARRAY(type_))[1:3]
from sqlalchemy.testing import is_not
from sqlalchemy.testing import mock
from sqlalchemy.testing import pickleable
+from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import pep435_enum
from sqlalchemy.testing.schema import Table
assert_raises_message(
exc.ArgumentError,
"Dialect 'postgresql' is already present "
- "in the mapping for this Variant",
+ "in the mapping for this UTypeOne()",
lambda: v.with_variant(self.UTypeThree(), "postgresql"),
)
+ def test_no_variants_of_variants(self):
+ t = Integer().with_variant(Float(), "postgresql")
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"can't pass a type that already has variants as a "
+ r"dialect-level type to with_variant\(\)",
+ ):
+ String().with_variant(t, "mysql")
+
def test_compile(self):
self.assert_compile(self.variant, "UTYPEONE", use_default_dialect=True)
self.assert_compile(
self.variant, "UTYPEONE", dialect=dialects.mysql.dialect()
)
+
self.assert_compile(
self.variant, "UTYPETWO", dialect=dialects.postgresql.dialect()
)
dialect=dialects.postgresql.dialect(),
)
+ def test_typedec_gen_dialect_impl(self):
+ """test that gen_dialect_impl passes onto a TypeDecorator, as
+ TypeDecorator._gen_dialect_impl() itself has special behaviors.
+
+ """
+
+ class MyDialectString(String):
+ pass
+
+ class MyString(TypeDecorator):
+ impl = String
+ cache_ok = True
+
+ def load_dialect_impl(self, dialect):
+ return MyDialectString()
+
+ variant = String().with_variant(MyString(), "mysql")
+
+ dialect_impl = variant._gen_dialect_impl(mysql.dialect())
+ is_(dialect_impl.impl.__class__, MyDialectString)
+
def test_compile_composite(self):
self.assert_compile(
self.composite, "UTYPEONE", use_default_dialect=True
)
@testing.requires.enforces_check_constraints
- @testing.provide_metadata
- def test_variant_we_are_default(self):
+ def test_variant_default_is_not_schematype(self, metadata):
+ t = Table(
+ "my_table",
+ metadata,
+ Column(
+ "data",
+ String(50).with_variant(
+ Enum(
+ "four",
+ "five",
+ "six",
+ native_enum=False,
+ name="e2",
+ create_constraint=True,
+ ),
+ testing.db.dialect.name,
+ ),
+ ),
+ )
+
+        # the base String() didn't create a constraint or even set up any
+        # events.  But Column looked for SchemaType in _variant_mapping
+        # and found our type anyway.
+ eq_(
+ len([c for c in t.constraints if isinstance(c, CheckConstraint)]),
+ 1,
+ )
+
+ metadata.create_all(testing.db)
+
+ # not using the connection fixture because we need to rollback and
+ # start again in the middle
+ with testing.db.connect() as connection:
+ # postgresql needs this in order to continue after the exception
+ trans = connection.begin()
+ assert_raises(
+ (exc.DBAPIError,),
+ connection.exec_driver_sql,
+ "insert into my_table (data) values('two')",
+ )
+ trans.rollback()
+
+ with connection.begin():
+ connection.exec_driver_sql(
+ "insert into my_table (data) values ('four')"
+ )
+ eq_(connection.execute(select(t.c.data)).scalar(), "four")
+
+ @testing.requires.enforces_check_constraints
+ def test_variant_we_are_default(self, metadata):
# test that the "variant" does not create a constraint
t = Table(
"my_table",
- self.metadata,
+ metadata,
Column(
"data",
Enum(
2,
)
- self.metadata.create_all(testing.db)
+ metadata.create_all(testing.db)
# not using the connection fixture because we need to rollback and
# start again in the middle
eq_(connection.execute(select(t.c.data)).scalar(), "two")
@testing.requires.enforces_check_constraints
- @testing.provide_metadata
- def test_variant_we_are_not_default(self):
+ def test_variant_we_are_not_default(self, metadata):
# test that the "variant" does not create a constraint
t = Table(
"my_table",
- self.metadata,
+ metadata,
Column(
"data",
Enum(
2,
)
- self.metadata.create_all(testing.db)
+ metadata.create_all(testing.db)
# not using the connection fixture because we need to rollback and
# start again in the middle