return impl_rt
else:
return "%s.%r" % (dname, type_)
+ elif impl_rt:
+ return impl_rt
elif mod.startswith("sqlalchemy."):
if '_render_%s_type' % type_.__visit_name__ in globals():
fn = globals()['_render_%s_type' % type_.__visit_name__]
)
-def _render_type_w_subtype(type_, autogen_context, attrname, regexp):
+def _render_type_w_subtype(
+ type_, autogen_context, attrname, regexp, prefix=None):
outer_repr = repr(type_)
inner_type = getattr(type_, attrname, None)
if inner_type is None:
regexp + inner_repr,
r"\1%s" % sub_type, outer_repr)
+ if prefix:
+ return "%s%s" % (prefix, outer_type)
+
mod = type(type_).__module__
if mod.startswith("sqlalchemy.dialects"):
dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
--- /dev/null
+from alembic.testing.fixtures import TestBase
+from alembic.ddl import impl
+from alembic.testing import eq_ignore_whitespace
+from alembic import autogenerate
+from alembic.autogenerate import render
+from alembic.migration import MigrationContext
+from alembic.autogenerate import api
+from alembic.testing import eq_
+
+from sqlalchemy import types as sqla_types
+from sqlalchemy import MetaData
+from sqlalchemy.engine import default
+
+
+class CustomDialect(default.DefaultDialect):
+ name = "custom_dialect"
+
+try:
+ from sqlalchemy.dialects import registry
+except ImportError:
+ pass
+else:
+ registry.register("custom_dialect", __name__, "CustomDialect")
+
+
+class CustomDialectImpl(impl.DefaultImpl):
+ __dialect__ = 'custom_dialect'
+ transactional_ddl = False
+
+ def render_type(self, type_, autogen_context):
+ if type_.__module__ == __name__:
+ autogen_context.imports.add(
+ "from %s import custom_dialect_types" % (__name__, ))
+ is_external = True
+ else:
+ is_external = False
+
+ if is_external and \
+ hasattr(self, '_render_%s_type' % type_.__visit_name__):
+ meth = getattr(self, '_render_%s_type' % type_.__visit_name__)
+ return meth(type_, autogen_context)
+
+ if is_external:
+ return "%s.%r" % ("custom_dialect_types", type_)
+ else:
+ return None
+
+ def _render_EXT_ARRAY_type(self, type_, autogen_context):
+ return render._render_type_w_subtype(
+ type_, autogen_context, 'item_type', r'(.+?\()',
+ prefix="custom_dialect_types."
+ )
+
+
+class EXT_ARRAY(sqla_types.TypeEngine):
+ __visit_name__ = 'EXT_ARRAY'
+
+ def __init__(self, item_type):
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+ super(EXT_ARRAY, self).__init__()
+
+
+class FOOBARTYPE(sqla_types.TypeEngine):
+ __visit_name__ = 'FOOBARTYPE'
+
+
+class ExternalDialectRenderTest(TestBase):
+ __requires__ = ('sqlalchemy_09', )
+
+ def setUp(self):
+ ctx_opts = {
+ 'sqlalchemy_module_prefix': 'sa.',
+ 'alembic_module_prefix': 'op.',
+ 'target_metadata': MetaData(),
+ 'user_module_prefix': None
+ }
+ context = MigrationContext.configure(
+ dialect_name="custom_dialect",
+ opts=ctx_opts
+ )
+
+ self.autogen_context = api.AutogenContext(context)
+
+ def test_render_type(self):
+ eq_ignore_whitespace(
+ autogenerate.render._repr_type(
+ FOOBARTYPE(), self.autogen_context),
+ "custom_dialect_types.FOOBARTYPE()"
+ )
+
+ eq_(
+ self.autogen_context.imports,
+ set([
+ 'from tests.test_external_dialect import custom_dialect_types'
+ ])
+ )
+
+ def test_external_nested_render_sqla_type(self):
+
+ eq_ignore_whitespace(
+ autogenerate.render._repr_type(
+ EXT_ARRAY(sqla_types.Integer), self.autogen_context),
+ "custom_dialect_types.EXT_ARRAY(sa.Integer())"
+ )
+
+ eq_ignore_whitespace(
+ autogenerate.render._repr_type(
+ EXT_ARRAY(
+ sqla_types.DateTime(timezone=True)
+ ),
+ self.autogen_context),
+ "custom_dialect_types.EXT_ARRAY(sa.DateTime(timezone=True))"
+ )
+
+ eq_(
+ self.autogen_context.imports,
+ set([
+ 'from tests.test_external_dialect import custom_dialect_types'
+ ])
+ )
+
+ def test_external_nested_render_external_type(self):
+
+ eq_ignore_whitespace(
+ autogenerate.render._repr_type(
+ EXT_ARRAY(FOOBARTYPE),
+ self.autogen_context),
+ "custom_dialect_types.EXT_ARRAY(custom_dialect_types.FOOBARTYPE())"
+ )
+
+ eq_(
+ self.autogen_context.imports,
+ set([
+ 'from tests.test_external_dialect import custom_dialect_types'
+ ])
+ )