from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.sqltypes import ARRAY
from sqlalchemy.sql.type_api import TypeEngine
- from sqlalchemy.sql.type_api import Variant
from alembic.autogenerate.api import AutogenContext
from alembic.config import Config
return kwargs
-def _repr_type(type_: "TypeEngine", autogen_context: "AutogenContext") -> str:
+def _repr_type(
+ type_: "TypeEngine",
+ autogen_context: "AutogenContext",
+ _skip_variants: bool = False,
+) -> str:
rendered = _user_defined_render("type", type_, autogen_context)
if rendered is not False:
return rendered
elif impl_rt:
return impl_rt
elif mod.startswith("sqlalchemy."):
- if type(type_) is sqltypes.Variant:
+ if not _skip_variants and sqla_compat._type_has_variants(type_):
return _render_Variant_type(type_, autogen_context)
if "_render_%s_type" % type_.__visit_name__ in globals():
fn = globals()["_render_%s_type" % type_.__visit_name__]
def _render_Variant_type(
- type_: "Variant", autogen_context: "AutogenContext"
+ type_: "TypeEngine", autogen_context: "AutogenContext"
) -> str:
- base = _repr_type(type_.impl, autogen_context)
+ base_type, variant_mapping = sqla_compat._get_variant_mapping(type_)
+ base = _repr_type(base_type, autogen_context, _skip_variants=True)
assert base is not None and base is not False
- for dialect in sorted(type_.mapping):
- typ = type_.mapping[dialect]
+ for dialect in sorted(variant_mapping):
+ typ = variant_mapping[dialect]
base += ".with_variant(%s, %r)" % (
- _repr_type(typ, autogen_context),
+ _repr_type(typ, autogen_context, _skip_variants=True),
dialect,
)
return base
return inspector.reflecttable(table, None)
+if hasattr(sqltypes.TypeEngine, "_variant_mapping"):
+
+ def _type_has_variants(type_):
+ return bool(type_._variant_mapping)
+
+ def _get_variant_mapping(type_):
+ return type_, type_._variant_mapping
+
+
+else:
+
+ def _type_has_variants(type_):
+ return type(type_) is sqltypes.Variant
+
+ def _get_variant_mapping(type_):
+ return type_.impl, type_.mapping
+
+
def _fk_spec(constraint):
source_columns = [
constraint.columns[key].name for key in constraint.column_keys