]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- rework all of autogenerate to build directly on alembic.operations.ops
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Jul 2015 23:00:55 +0000 (19:00 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Jul 2015 23:00:55 +0000 (19:00 -0400)
objects; the "diffs" is now a legacy system that is exported from
the ops.  A new model of comparison/rendering/ upgrade/downgrade
composition that is cleaner and much more extensible is introduced.
- autogenerate is now extensible as far as database objects compared
and rendered into scripts; any new operation directive can also be
registered into a series of hooks that allow custom database/model
comparison functions to run as well as to render new operation
directives into autogenerate scripts.
- write all new docs for the new system
fixes #306

19 files changed:
alembic/autogenerate/__init__.py
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/compose.py [deleted file]
alembic/autogenerate/generate.py [deleted file]
alembic/autogenerate/render.py
alembic/operations/ops.py
alembic/runtime/migration.py
alembic/util/langhelpers.py
docs/build/api/autogenerate.rst
docs/build/api/operations.rst
docs/build/changelog.rst
tests/_autogen_fixtures.py
tests/test_autogen_composition.py
tests/test_autogen_diffs.py
tests/test_autogen_fks.py
tests/test_autogen_indexes.py
tests/test_autogen_render.py
tests/test_postgresql.py

index 4272a7ed15f7845584a9f088645eed39ca4884aa..78520a8567c7a0f5444d549df6a14b7a45271830 100644 (file)
@@ -1,7 +1,7 @@
 from .api import ( # noqa
     compare_metadata, _render_migration_diffs,
-    produce_migrations, render_python_code
+    produce_migrations, render_python_code,
+    RevisionContext
     )
-from .compare import _produce_net_changes  # noqa
-from .generate import RevisionContext  # noqa
+from .compare import _produce_net_changes, comparators  # noqa
 from .render import render_op_text, renderers  # noqa
\ No newline at end of file
index cff977bbcf9530df4ff8c935239c7c169e9c6ccf..e9af4cfdae67bf0551f55f42152fee3fcffb4a35 100644 (file)
@@ -4,8 +4,9 @@ automatically."""
 from ..operations import ops
 from . import render
 from . import compare
-from . import compose
 from .. import util
+from sqlalchemy.engine.reflection import Inspector
+import contextlib
 
 
 def compare_metadata(context, metadata):
@@ -98,20 +99,8 @@ def compare_metadata(context, metadata):
 
     """
 
-    autogen_context = _autogen_context(context, metadata=metadata)
-
-    # as_sql=True is nonsensical here. autogenerate requires a connection
-    # it can use to run queries against to get the database schema.
-    if context.as_sql:
-        raise util.CommandError(
-            "autogenerate can't use as_sql=True as it prevents querying "
-            "the database for schema information")
-
-    diffs = []
-
-    compare._produce_net_changes(autogen_context, diffs)
-
-    return diffs
+    migration_script = produce_migrations(context, metadata)
+    return migration_script.upgrade_ops.as_diffs()
 
 
 def produce_migrations(context, metadata):
@@ -132,10 +121,7 @@ def produce_migrations(context, metadata):
 
     """
 
-    autogen_context = _autogen_context(context, metadata=metadata)
-    diffs = []
-
-    compare._produce_net_changes(autogen_context, diffs)
+    autogen_context = AutogenContext(context, metadata=metadata)
 
     migration_script = ops.MigrationScript(
         rev_id=None,
@@ -143,7 +129,7 @@ def produce_migrations(context, metadata):
         downgrade_ops=ops.DowngradeOps([]),
     )
 
-    compose._to_migration_script(autogen_context, migration_script, diffs)
+    compare._populate_migration_script(autogen_context, migration_script)
 
     return migration_script
 
@@ -152,6 +138,7 @@ def render_python_code(
     up_or_down_op,
     sqlalchemy_module_prefix='sa.',
     alembic_module_prefix='op.',
+    render_as_batch=False,
     imports=(),
     render_item=None,
 ):
@@ -162,84 +149,239 @@ def render_python_code(
     autogenerate output of a user-defined :class:`.MigrationScript` structure.
 
     """
-    autogen_context = {
-        'opts': {
-            'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
-            'alembic_module_prefix': alembic_module_prefix,
-            'render_item': render_item,
-        },
-        'imports': set(imports)
+    opts = {
+        'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
+        'alembic_module_prefix': alembic_module_prefix,
+        'render_item': render_item,
+        'render_as_batch': render_as_batch,
     }
+
+    autogen_context = AutogenContext(None, opts=opts)
+    autogen_context.imports = set(imports)
     return render._indent(render._render_cmd_body(
         up_or_down_op, autogen_context))
 
 
-
-
-def _render_migration_diffs(context, template_args, imports):
+def _render_migration_diffs(context, template_args):
     """legacy, used by test_autogen_composition at the moment"""
 
-    migration_script = produce_migrations(context, None)
-
-    autogen_context = _autogen_context(context, imports=imports)
-    diffs = []
+    autogen_context = AutogenContext(context)
 
-    compare._produce_net_changes(autogen_context, diffs)
+    upgrade_ops = ops.UpgradeOps([])
+    compare._produce_net_changes(autogen_context, upgrade_ops)
 
     migration_script = ops.MigrationScript(
         rev_id=None,
-        imports=imports,
-        upgrade_ops=ops.UpgradeOps([]),
-        downgrade_ops=ops.DowngradeOps([]),
+        upgrade_ops=upgrade_ops,
+        downgrade_ops=upgrade_ops.reverse(),
     )
 
-    compose._to_migration_script(autogen_context, migration_script, diffs)
-
     render._render_migration_script(
         autogen_context, migration_script, template_args
     )
 
 
-def _autogen_context(
-    context, imports=None, metadata=None, include_symbol=None,
-        include_object=None, include_schemas=False):
-
-    opts = context.opts
-    metadata = opts['target_metadata'] if metadata is None else metadata
-    include_schemas = opts.get('include_schemas', include_schemas)
-
-    include_symbol = opts.get('include_symbol', include_symbol)
-    include_object = opts.get('include_object', include_object)
-
-    object_filters = []
-    if include_symbol:
-        def include_symbol_filter(object, name, type_, reflected, compare_to):
-            if type_ == "table":
-                return include_symbol(name, object.schema)
-            else:
-                return True
-        object_filters.append(include_symbol_filter)
-    if include_object:
-        object_filters.append(include_object)
-
-    if metadata is None:
-        raise util.CommandError(
-            "Can't proceed with --autogenerate option; environment "
-            "script %s does not provide "
-            "a MetaData object to the context." % (
-                context.script.env_py_location
-            ))
-
-    opts = context.opts
-    connection = context.bind
-    return {
-        'imports': imports if imports is not None else set(),
-        'connection': connection,
-        'dialect': connection.dialect,
-        'context': context,
-        'opts': opts,
-        'metadata': metadata,
-        'object_filters': object_filters,
-        'include_schemas': include_schemas
-    }
+class AutogenContext(object):
+    """Maintains configuration and state that's specific to an
+    autogenerate operation."""
+
+    metadata = None
+    """The :class:`~sqlalchemy.schema.MetaData` object
+    representing the destination.
+
+    This object is the one that is passed within ``env.py``
+    to the :paramref:`.EnvironmentContext.configure.target_metadata`
+    parameter.  It represents the structure of :class:`.Table` and other
+    objects as stated in the current database model, and represents the
+    destination structure for the database being examined.
+
+    While the :class:`~sqlalchemy.schema.MetaData` object is primarily
+    known as a collection of :class:`~sqlalchemy.schema.Table` objects,
+    it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
+    that may be used by end-user schemes to store additional schema-level
+    objects that are to be compared in custom autogeneration schemes.
+
+    """
+
+    connection = None
+    """The :class:`~sqlalchemy.engine.base.Connection` object currently
+    connected to the database backend being compared.
+
+    This is obtained from the :attr:`.MigrationContext.bind` and is
+    utimately set up in the ``env.py`` script.
+
+    """
+
+    dialect = None
+    """The :class:`~sqlalchemy.engine.Dialect` object currently in use.
+
+    This is normally obtained from the
+    :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.
+
+    """
+
+    migration_context = None
+    """The :class:`.MigrationContext` established by the ``env.py`` script."""
+
+    def __init__(self, migration_context, metadata=None, opts=None):
+
+        if migration_context is not None and migration_context.as_sql:
+            raise util.CommandError(
+                "autogenerate can't use as_sql=True as it prevents querying "
+                "the database for schema information")
+
+        if opts is None:
+            opts = migration_context.opts
+        self.metadata = metadata = opts.get('target_metadata', None) \
+            if metadata is None else metadata
+
+        if metadata is None and \
+                migration_context is not None and \
+                migration_context.script is not None:
+            raise util.CommandError(
+                "Can't proceed with --autogenerate option; environment "
+                "script %s does not provide "
+                "a MetaData object to the context." % (
+                    migration_context.script.env_py_location
+                ))
+
+        include_symbol = opts.get('include_symbol', None)
+        include_object = opts.get('include_object', None)
+
+        object_filters = []
+        if include_symbol:
+            def include_symbol_filter(
+                    object, name, type_, reflected, compare_to):
+                if type_ == "table":
+                    return include_symbol(name, object.schema)
+                else:
+                    return True
+            object_filters.append(include_symbol_filter)
+        if include_object:
+            object_filters.append(include_object)
+
+        self._object_filters = object_filters
+
+        self.migration_context = migration_context
+        if self.migration_context is not None:
+            self.connection = self.migration_context.bind
+            self.dialect = self.migration_context.dialect
+
+        self._imports = set()
+        self.opts = opts
+        self._has_batch = False
+
+    @util.memoized_property
+    def inspector(self):
+        return Inspector.from_engine(self.connection)
+
+    @contextlib.contextmanager
+    def _within_batch(self):
+        self._has_batch = True
+        yield
+        self._has_batch = False
+
+    def run_filters(self, object_, name, type_, reflected, compare_to):
+        """Run the context's object filters and return True if the targets
+        should be part of the autogenerate operation.
+
+        This method should be run for every kind of object encountered within
+        an autogenerate operation, giving the environment the chance
+        to filter what objects should be included in the comparison.
+        The filters here are produced directly via the
+        :paramref:`.EnvironmentContext.configure.include_object`
+        and :paramref:`.EnvironmentContext.configure.include_symbol`
+        functions, if present.
+
+        """
+        for fn in self._object_filters:
+            if not fn(object_, name, type_, reflected, compare_to):
+                return False
+        else:
+            return True
+
+
+class RevisionContext(object):
+    """Maintains configuration and state that's specific to a revision
+    file generation operation."""
+
+    def __init__(self, config, script_directory, command_args):
+        self.config = config
+        self.script_directory = script_directory
+        self.command_args = command_args
+        self.template_args = {
+            'config': config  # Let templates use config for
+                              # e.g. multiple databases
+        }
+        self.generated_revisions = [
+            self._default_revision()
+        ]
+
+    def _to_script(self, migration_script):
+        template_args = {}
+        for k, v in self.template_args.items():
+            template_args.setdefault(k, v)
+
+        if migration_script._autogen_context is not None:
+            render._render_migration_script(
+                migration_script._autogen_context, migration_script,
+                template_args
+            )
+
+        return self.script_directory.generate_revision(
+            migration_script.rev_id,
+            migration_script.message,
+            refresh=True,
+            head=migration_script.head,
+            splice=migration_script.splice,
+            branch_labels=migration_script.branch_label,
+            version_path=migration_script.version_path,
+            **template_args)
+
+    def run_autogenerate(self, rev, context):
+        if self.command_args['sql']:
+            raise util.CommandError(
+                "Using --sql with --autogenerate does not make any sense")
+        if set(self.script_directory.get_revisions(rev)) != \
+                set(self.script_directory.get_revisions("heads")):
+            raise util.CommandError("Target database is not up to date.")
+
+        autogen_context = AutogenContext(context)
+
+        migration_script = self.generated_revisions[0]
+
+        compare._populate_migration_script(autogen_context, migration_script)
+
+        hook = context.opts.get('process_revision_directives', None)
+        if hook:
+            hook(context, rev, self.generated_revisions)
+
+        for migration_script in self.generated_revisions:
+            migration_script._autogen_context = autogen_context
+
+    def run_no_autogenerate(self, rev, context):
+        hook = context.opts.get('process_revision_directives', None)
+        if hook:
+            hook(context, rev, self.generated_revisions)
+
+        for migration_script in self.generated_revisions:
+            migration_script._autogen_context = None
+
+    def _default_revision(self):
+        op = ops.MigrationScript(
+            rev_id=self.command_args['rev_id'] or util.rev_id(),
+            message=self.command_args['message'],
+            imports=set(),
+            upgrade_ops=ops.UpgradeOps([]),
+            downgrade_ops=ops.DowngradeOps([]),
+            head=self.command_args['head'],
+            splice=self.command_args['splice'],
+            branch_label=self.command_args['branch_label'],
+            version_path=self.command_args['version_path']
+        )
+        op._autogen_context = None
+        return op
 
+    def generate_scripts(self):
+        for generated_revision in self.generated_revisions:
+            yield self._to_script(generated_revision)
index cd6b6965a11ace7de95255a7933f789eda440c52..fdc3cae32d078b8b785b48be8c8bff6e34036ac7 100644 (file)
@@ -1,7 +1,9 @@
 from sqlalchemy import schema as sa_schema, types as sqltypes
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy import event
+from ..operations import ops
 import logging
+from .. import util
 from ..util import compat
 from ..util import sqla_compat
 from sqlalchemy.util import OrderedSet
@@ -13,15 +15,20 @@ from alembic.ddl.base import _fk_spec
 log = logging.getLogger(__name__)
 
 
-def _produce_net_changes(autogen_context, diffs):
+def _populate_migration_script(autogen_context, migration_script):
+    _produce_net_changes(autogen_context, migration_script.upgrade_ops)
+    migration_script.upgrade_ops.reverse_into(migration_script.downgrade_ops)
 
-    metadata = autogen_context['metadata']
-    connection = autogen_context['connection']
-    object_filters = autogen_context.get('object_filters', ())
-    include_schemas = autogen_context.get('include_schemas', False)
+
+comparators = util.Dispatcher(uselist=True)
+
+
+def _produce_net_changes(autogen_context, upgrade_ops):
+
+    connection = autogen_context.connection
+    include_schemas = autogen_context.opts.get('include_schemas', False)
 
     inspector = Inspector.from_engine(connection)
-    conn_table_names = set()
 
     default_schema = connection.dialect.default_schema_name
     if include_schemas:
@@ -34,14 +41,28 @@ def _produce_net_changes(autogen_context, diffs):
     else:
         schemas = [None]
 
-    version_table_schema = autogen_context['context'].version_table_schema
-    version_table = autogen_context['context'].version_table
+    comparators.dispatch("schema", autogen_context.dialect.name)(
+        autogen_context, upgrade_ops, schemas
+    )
+
+
+@comparators.dispatch_for("schema")
+def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
+    inspector = autogen_context.inspector
+
+    metadata = autogen_context.metadata
+
+    conn_table_names = set()
+
+    version_table_schema = \
+        autogen_context.migration_context.version_table_schema
+    version_table = autogen_context.migration_context.version_table
 
     for s in schemas:
         tables = set(inspector.get_table_names(schema=s))
         if s == version_table_schema:
             tables = tables.difference(
-                [autogen_context['context'].version_table]
+                [autogen_context.migration_context.version_table]
             )
         conn_table_names.update(zip([s] * len(tables), tables))
 
@@ -50,21 +71,11 @@ def _produce_net_changes(autogen_context, diffs):
     ).difference([(version_table_schema, version_table)])
 
     _compare_tables(conn_table_names, metadata_table_names,
-                    object_filters,
-                    inspector, metadata, diffs, autogen_context)
-
-
-def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
-    for fn in object_filters:
-        if not fn(object_, name, type_, reflected, compare_to):
-            return False
-    else:
-        return True
+                    inspector, metadata, upgrade_ops, autogen_context)
 
 
 def _compare_tables(conn_table_names, metadata_table_names,
-                    object_filters,
-                    inspector, metadata, diffs, autogen_context):
+                    inspector, metadata, upgrade_ops, autogen_context):
 
     default_schema = inspector.bind.dialect.default_schema_name
 
@@ -95,14 +106,19 @@ def _compare_tables(conn_table_names, metadata_table_names,
     for s, tname in metadata_table_names.difference(conn_table_names):
         name = '%s.%s' % (s, tname) if s else tname
         metadata_table = tname_to_table[(s, tname)]
-        if _run_filters(
-                metadata_table, tname, "table", False, None, object_filters):
-            diffs.append(("add_table", metadata_table))
+        if autogen_context.run_filters(
+                metadata_table, tname, "table", False, None):
+            upgrade_ops.ops.append(
+                ops.CreateTableOp.from_table(metadata_table))
             log.info("Detected added table %r", name)
-            _compare_indexes_and_uniques(s, tname, object_filters,
-                                         None,
-                                         metadata_table,
-                                         diffs, autogen_context, inspector)
+            modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
+
+            comparators.dispatch("table")(
+                autogen_context, modify_table_ops,
+                s, tname, None, metadata_table
+            )
+            if not modify_table_ops.is_empty():
+                upgrade_ops.ops.append(modify_table_ops)
 
     removal_metadata = sa_schema.MetaData()
     for s, tname in conn_table_names.difference(metadata_table_names):
@@ -114,11 +130,13 @@ def _compare_tables(conn_table_names, metadata_table_names,
             event.listen(
                 t,
                 "column_reflect",
-                autogen_context['context'].impl.
+                autogen_context.migration_context.impl.
                 _compat_autogen_column_reflect(inspector))
             inspector.reflecttable(t, None)
-        if _run_filters(t, tname, "table", True, None, object_filters):
-            diffs.append(("remove_table", t))
+        if autogen_context.run_filters(t, tname, "table", True, None):
+            upgrade_ops.ops.append(
+                ops.DropTableOp.from_table(t)
+            )
             log.info("Detected removed table %r", name)
 
     existing_tables = conn_table_names.intersection(metadata_table_names)
@@ -133,7 +151,7 @@ def _compare_tables(conn_table_names, metadata_table_names,
             event.listen(
                 t,
                 "column_reflect",
-                autogen_context['context'].impl.
+                autogen_context.migration_context.impl.
                 _compat_autogen_column_reflect(inspector))
             inspector.reflecttable(t, None)
         conn_column_info[(s, tname)] = t
@@ -144,25 +162,24 @@ def _compare_tables(conn_table_names, metadata_table_names,
         metadata_table = tname_to_table[(s, tname)]
         conn_table = existing_metadata.tables[name]
 
-        if _run_filters(
+        if autogen_context.run_filters(
                 metadata_table, tname, "table", False,
-                conn_table, object_filters):
+                conn_table):
+
+            modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
             with _compare_columns(
-                s, tname, object_filters,
+                s, tname,
                 conn_table,
                 metadata_table,
-                    diffs, autogen_context, inspector):
-                _compare_indexes_and_uniques(s, tname, object_filters,
-                                             conn_table,
-                                             metadata_table,
-                                             diffs, autogen_context, inspector)
-                _compare_foreign_keys(s, tname, object_filters, conn_table,
-                                      metadata_table, diffs, autogen_context,
-                                      inspector)
+                    modify_table_ops, autogen_context, inspector):
 
-    # TODO:
-    # table constraints
-    # sequences
+                comparators.dispatch("table")(
+                    autogen_context, modify_table_ops,
+                    s, tname, conn_table, metadata_table
+                )
+
+            if not modify_table_ops.is_empty():
+                upgrade_ops.ops.append(modify_table_ops)
 
 
 def _make_index(params, conn_table):
@@ -202,56 +219,51 @@ def _make_foreign_key(params, conn_table):
 
 
 @contextlib.contextmanager
-def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
-                     diffs, autogen_context, inspector):
+def _compare_columns(schema, tname, conn_table, metadata_table,
+                     modify_table_ops, autogen_context, inspector):
     name = '%s.%s' % (schema, tname) if schema else tname
     metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
     conn_col_names = dict((c.name, c) for c in conn_table.c)
     metadata_col_names = OrderedSet(sorted(metadata_cols_by_name))
 
     for cname in metadata_col_names.difference(conn_col_names):
-        if _run_filters(metadata_cols_by_name[cname], cname,
-                        "column", False, None, object_filters):
-            diffs.append(
-                ("add_column", schema, tname, metadata_cols_by_name[cname])
+        if autogen_context.run_filters(
+                metadata_cols_by_name[cname], cname,
+                "column", False, None):
+            modify_table_ops.ops.append(
+                ops.AddColumnOp.from_column_and_tablename(
+                    schema, tname, metadata_cols_by_name[cname])
             )
             log.info("Detected added column '%s.%s'", name, cname)
 
     for colname in metadata_col_names.intersection(conn_col_names):
         metadata_col = metadata_cols_by_name[colname]
         conn_col = conn_table.c[colname]
-        if not _run_filters(
+        if not autogen_context.run_filters(
                 metadata_col, colname, "column", False,
-                conn_col, object_filters):
+                conn_col):
             continue
-        col_diff = []
-        _compare_type(schema, tname, colname,
-                      conn_col,
-                      metadata_col,
-                      col_diff, autogen_context
-                      )
-        # work around SQLAlchemy issue #3023
-        if not metadata_col.primary_key:
-            _compare_nullable(schema, tname, colname,
-                              conn_col,
-                              metadata_col.nullable,
-                              col_diff, autogen_context
-                              )
-        _compare_server_default(schema, tname, colname,
-                                conn_col,
-                                metadata_col,
-                                col_diff, autogen_context
-                                )
-        if col_diff:
-            diffs.append(col_diff)
+        alter_column_op = ops.AlterColumnOp(
+            tname, colname, schema=schema)
+
+        comparators.dispatch("column")(
+            autogen_context, alter_column_op,
+            schema, tname, colname, conn_col, metadata_col
+        )
+
+        if alter_column_op.has_changes():
+            modify_table_ops.ops.append(alter_column_op)
 
     yield
 
     for cname in set(conn_col_names).difference(metadata_col_names):
-        if _run_filters(conn_table.c[cname], cname,
-                        "column", True, None, object_filters):
-            diffs.append(
-                ("remove_column", schema, tname, conn_table.c[cname])
+        if autogen_context.run_filters(
+                conn_table.c[cname], cname,
+                "column", True, None):
+            modify_table_ops.ops.append(
+                ops.DropColumnOp.from_column_and_tablename(
+                    schema, tname, conn_table.c[cname]
+                )
             )
             log.info("Detected removed column '%s.%s'", name, cname)
 
@@ -310,10 +322,12 @@ class _fk_constraint_sig(_constraint_sig):
         )
 
 
-def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
-                                 metadata_table, diffs,
-                                 autogen_context, inspector):
+@comparators.dispatch_for("table")
+def _compare_indexes_and_uniques(
+        autogen_context, modify_ops, schema, tname, conn_table,
+        metadata_table):
 
+    inspector = autogen_context.inspector
     is_create_table = conn_table is None
 
     # 1a. get raw indexes and unique constraints from metadata ...
@@ -350,7 +364,7 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
     # 3. give the dialect a chance to omit indexes and constraints that
     # we know are either added implicitly by the DB or that the DB
     # can't accurately report on
-    autogen_context['context'].impl.\
+    autogen_context.migration_context.impl.\
         correct_for_autogen_constraints(
             conn_uniques, conn_indexes,
             metadata_unique_constraints,
@@ -411,9 +425,11 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
 
     def obj_added(obj):
         if obj.is_index:
-            if _run_filters(
-                    obj.const, obj.name, "index", False, None, object_filters):
-                diffs.append(("add_index", obj.const))
+            if autogen_context.run_filters(
+                    obj.const, obj.name, "index", False, None):
+                modify_ops.ops.append(
+                    ops.CreateIndexOp.from_index(obj.const)
+                )
                 log.info("Detected added index '%s' on %s",
                          obj.name, ', '.join([
                              "'%s'" % obj.column_names
@@ -426,10 +442,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
             if is_create_table:
                 # unique constraints are created inline with table defs
                 return
-            if _run_filters(
+            if autogen_context.run_filters(
                     obj.const, obj.name,
-                    "unique_constraint", False, None, object_filters):
-                diffs.append(("add_constraint", obj.const))
+                    "unique_constraint", False, None):
+                modify_ops.ops.append(
+                    ops.AddConstraintOp.from_constraint(obj.const)
+                )
                 log.info("Detected added unique constraint '%s' on %s",
                          obj.name, ', '.join([
                              "'%s'" % obj.column_names
@@ -443,39 +461,51 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
                 # be sure what we're doing here
                 return
 
-            if _run_filters(
-                    obj.const, obj.name, "index", True, None, object_filters):
-                diffs.append(("remove_index", obj.const))
+            if autogen_context.run_filters(
+                    obj.const, obj.name, "index", True, None):
+                modify_ops.ops.append(
+                    ops.DropIndexOp.from_index(obj.const)
+                )
                 log.info(
                     "Detected removed index '%s' on '%s'", obj.name, tname)
         else:
-            if _run_filters(
+            if autogen_context.run_filters(
                     obj.const, obj.name,
-                    "unique_constraint", True, None, object_filters):
-                diffs.append(("remove_constraint", obj.const))
+                    "unique_constraint", True, None):
+                modify_ops.ops.append(
+                    ops.DropConstraintOp.from_constraint(obj.const)
+                )
                 log.info("Detected removed unique constraint '%s' on '%s'",
                          obj.name, tname
                          )
 
     def obj_changed(old, new, msg):
         if old.is_index:
-            if _run_filters(
+            if autogen_context.run_filters(
                     new.const, new.name, "index",
-                    False, old.const, object_filters):
+                    False, old.const):
                 log.info("Detected changed index '%s' on '%s':%s",
                          old.name, tname, ', '.join(msg)
                          )
-                diffs.append(("remove_index", old.const))
-                diffs.append(("add_index", new.const))
+                modify_ops.ops.append(
+                    ops.DropIndexOp.from_index(old.const)
+                )
+                modify_ops.ops.append(
+                    ops.CreateIndexOp.from_index(new.const)
+                )
         else:
-            if _run_filters(
+            if autogen_context.run_filters(
                     new.const, new.name,
-                    "unique_constraint", False, old.const, object_filters):
+                    "unique_constraint", False, old.const):
                 log.info("Detected changed unique constraint '%s' on '%s':%s",
                          old.name, tname, ', '.join(msg)
                          )
-                diffs.append(("remove_constraint", old.const))
-                diffs.append(("add_constraint", new.const))
+                modify_ops.ops.append(
+                    ops.DropConstraintOp.from_constraint(old.const)
+                )
+                modify_ops.ops.append(
+                    ops.AddConstraintOp.from_constraint(new.const)
+                )
 
     for added_name in sorted(set(metadata_names).difference(conn_names)):
         obj = metadata_names[added_name]
@@ -528,20 +558,21 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
             obj_added(unnamed_metadata_uniques[uq_sig])
 
 
-def _compare_nullable(schema, tname, cname, conn_col,
-                      metadata_col_nullable, diffs,
-                      autogen_context):
+@comparators.dispatch_for("column")
+def _compare_nullable(
+    autogen_context, alter_column_op, schema, tname, cname, conn_col,
+        metadata_col):
+
+    # work around SQLAlchemy issue #3023
+    if metadata_col.primary_key:
+        return
+
+    metadata_col_nullable = metadata_col.nullable
     conn_col_nullable = conn_col.nullable
+    alter_column_op.existing_nullable = conn_col_nullable
+
     if conn_col_nullable is not metadata_col_nullable:
-        diffs.append(
-            ("modify_nullable", schema, tname, cname,
-                {
-                    "existing_type": conn_col.type,
-                    "existing_server_default": conn_col.server_default,
-                },
-                conn_col_nullable,
-                metadata_col_nullable),
-        )
+        alter_column_op.modify_nullable = metadata_col_nullable
         log.info("Detected %s on column '%s.%s'",
                  "NULL" if metadata_col_nullable else "NOT NULL",
                  tname,
@@ -549,11 +580,13 @@ def _compare_nullable(schema, tname, cname, conn_col,
                  )
 
 
-def _compare_type(schema, tname, cname, conn_col,
-                  metadata_col, diffs,
-                  autogen_context):
+@comparators.dispatch_for("column")
+def _compare_type(
+    autogen_context, alter_column_op, schema, tname, cname, conn_col,
+        metadata_col):
 
     conn_type = conn_col.type
+    alter_column_op.existing_type = conn_type
     metadata_type = metadata_col.type
     if conn_type._type_affinity is sqltypes.NullType:
         log.info("Couldn't determine database type "
@@ -564,19 +597,11 @@ def _compare_type(schema, tname, cname, conn_col,
                  "the model; can't compare", tname, cname)
         return
 
-    isdiff = autogen_context['context']._compare_type(conn_col, metadata_col)
+    isdiff = autogen_context.migration_context._compare_type(
+        conn_col, metadata_col)
 
     if isdiff:
-
-        diffs.append(
-            ("modify_type", schema, tname, cname,
-             {
-                 "existing_nullable": conn_col.nullable,
-                 "existing_server_default": conn_col.server_default,
-             },
-             conn_type,
-             metadata_type),
-        )
+        alter_column_op.modify_type = metadata_type
         log.info("Detected type change from %r to %r on '%s.%s'",
                  conn_type, metadata_type, tname, cname
                  )
@@ -594,7 +619,7 @@ def _render_server_default_for_compare(metadata_default,
             metadata_default = metadata_default.arg
         else:
             metadata_default = str(metadata_default.arg.compile(
-                dialect=autogen_context['dialect']))
+                dialect=autogen_context.dialect))
     if isinstance(metadata_default, compat.string_types):
         if metadata_col.type._type_affinity is sqltypes.String:
             metadata_default = re.sub(r"^'|'$", "", metadata_default)
@@ -605,8 +630,10 @@ def _render_server_default_for_compare(metadata_default,
         return None
 
 
-def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
-                            diffs, autogen_context):
+@comparators.dispatch_for("column")
+def _compare_server_default(
+    autogen_context, alter_column_op, schema, tname, cname,
+        conn_col, metadata_col):
 
     metadata_default = metadata_col.server_default
     conn_col_default = conn_col.server_default
@@ -618,36 +645,31 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
     rendered_conn_default = conn_col.server_default.arg.text \
         if conn_col.server_default else None
 
-    isdiff = autogen_context['context']._compare_server_default(
+    alter_column_op.existing_server_default = conn_col_default
+
+    isdiff = autogen_context.migration_context._compare_server_default(
         conn_col, metadata_col,
         rendered_metadata_default,
         rendered_conn_default
     )
     if isdiff:
-        conn_col_default = rendered_conn_default
-        diffs.append(
-            ("modify_default", schema, tname, cname,
-                {
-                    "existing_nullable": conn_col.nullable,
-                    "existing_type": conn_col.type,
-                },
-                conn_col_default,
-                metadata_default),
-        )
-        log.info("Detected server default on column '%s.%s'",
-                 tname,
-                 cname
-                 )
+        alter_column_op.modify_server_default = metadata_default
+        log.info(
+            "Detected server default on column '%s.%s'",
+            tname, cname)
 
 
-def _compare_foreign_keys(schema, tname, object_filters, conn_table,
-                          metadata_table, diffs, autogen_context, inspector):
+@comparators.dispatch_for("table")
+def _compare_foreign_keys(
+    autogen_context, modify_table_ops, schema, tname, conn_table,
+        metadata_table):
 
     # if we're doing CREATE TABLE, all FKs are created
     # inline within the table def
     if conn_table is None:
         return
 
+    inspector = autogen_context.inspector
     metadata_fks = set(
         fk for fk in metadata_table.constraints
         if isinstance(fk, sa_schema.ForeignKeyConstraint)
@@ -673,10 +695,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
     )
 
     def _add_fk(obj, compare_to):
-        if _run_filters(
+        if autogen_context.run_filters(
                 obj.const, obj.name, "foreign_key_constraint", False,
-                compare_to, object_filters):
-            diffs.append(('add_fk', const.const))
+                compare_to):
+            modify_table_ops.ops.append(
+                ops.CreateForeignKeyOp.from_constraint(const.const)
+            )
 
             log.info(
                 "Detected added foreign key (%s)(%s) on table %s%s",
@@ -686,10 +710,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
                 obj.source_table)
 
     def _remove_fk(obj, compare_to):
-        if _run_filters(
+        if autogen_context.run_filters(
                 obj.const, obj.name, "foreign_key_constraint", True,
-                compare_to, object_filters):
-            diffs.append(('remove_fk', obj.const))
+                compare_to):
+            modify_table_ops.ops.append(
+                ops.DropConstraintOp.from_constraint(obj.const)
+            )
             log.info(
                 "Detected removed foreign key (%s)(%s) on table %s%s",
                 ", ".join(obj.source_columns),
@@ -713,5 +739,3 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
             compare_to = conn_fks_by_name[const.name].const \
                 if const.name in conn_fks_by_name else None
             _add_fk(const, compare_to)
-
-    return diffs
diff --git a/alembic/autogenerate/compose.py b/alembic/autogenerate/compose.py
deleted file mode 100644 (file)
index b42b505..0000000
+++ /dev/null
@@ -1,144 +0,0 @@
-import itertools
-from ..operations import ops
-
-
-def _to_migration_script(autogen_context, migration_script, diffs):
-    _to_upgrade_op(
-        autogen_context,
-        diffs,
-        migration_script.upgrade_ops,
-    )
-
-    _to_downgrade_op(
-        autogen_context,
-        diffs,
-        migration_script.downgrade_ops,
-    )
-
-
-def _to_upgrade_op(autogen_context, diffs, upgrade_ops):
-    return _to_updown_op(autogen_context, diffs, upgrade_ops, "upgrade")
-
-
-def _to_downgrade_op(autogen_context, diffs, downgrade_ops):
-    return _to_updown_op(autogen_context, diffs, downgrade_ops, "downgrade")
-
-
-def _to_updown_op(autogen_context, diffs, op_container, type_):
-    if not diffs:
-        return
-
-    if type_ == 'downgrade':
-        diffs = reversed(diffs)
-
-    dest = [op_container.ops]
-
-    for (schema, tablename), subdiffs in _group_diffs_by_table(diffs):
-        subdiffs = list(subdiffs)
-        if tablename is not None:
-            table_ops = []
-            op = ops.ModifyTableOps(tablename, table_ops, schema=schema)
-            dest[-1].append(op)
-            dest.append(table_ops)
-        for diff in subdiffs:
-            _produce_command(autogen_context, diff, dest[-1], type_)
-        if tablename is not None:
-            dest.pop(-1)
-
-
-def _produce_command(autogen_context, diff, op_list, updown):
-    if isinstance(diff, tuple):
-        _produce_adddrop_command(updown, diff, op_list, autogen_context)
-    else:
-        _produce_modify_command(updown, diff, op_list, autogen_context)
-
-
-def _produce_adddrop_command(updown, diff, op_list, autogen_context):
-    cmd_type = diff[0]
-    adddrop, cmd_type = cmd_type.split("_")
-
-    cmd_args = diff[1:]
-
-    _commands = {
-        "table": (ops.DropTableOp.from_table, ops.CreateTableOp.from_table),
-        "column": (
-            ops.DropColumnOp.from_column_and_tablename,
-            ops.AddColumnOp.from_column_and_tablename),
-        "index": (ops.DropIndexOp.from_index, ops.CreateIndexOp.from_index),
-        "constraint": (
-            ops.DropConstraintOp.from_constraint,
-            ops.AddConstraintOp.from_constraint),
-        "fk": (
-            ops.DropConstraintOp.from_constraint,
-            ops.CreateForeignKeyOp.from_constraint)
-    }
-
-    cmd_callables = _commands[cmd_type]
-
-    if (
-        updown == "upgrade" and adddrop == "add"
-    ) or (
-        updown == "downgrade" and adddrop == "remove"
-    ):
-        op_list.append(cmd_callables[1](*cmd_args))
-    else:
-        op_list.append(cmd_callables[0](*cmd_args))
-
-
-def _produce_modify_command(updown, diffs, op_list, autogen_context):
-    sname, tname, cname = diffs[0][1:4]
-    kw = {}
-
-    _arg_struct = {
-        "modify_type": ("existing_type", "modify_type"),
-        "modify_nullable": ("existing_nullable", "modify_nullable"),
-        "modify_default": ("existing_server_default", "modify_server_default"),
-    }
-    for diff in diffs:
-        diff_kw = diff[4]
-        for arg in ("existing_type",
-                    "existing_nullable",
-                    "existing_server_default"):
-            if arg in diff_kw:
-                kw.setdefault(arg, diff_kw[arg])
-        old_kw, new_kw = _arg_struct[diff[0]]
-        if updown == "upgrade":
-            kw[new_kw] = diff[-1]
-            kw[old_kw] = diff[-2]
-        else:
-            kw[new_kw] = diff[-2]
-            kw[old_kw] = diff[-1]
-
-    if "modify_nullable" in kw:
-        kw.pop("existing_nullable", None)
-    if "modify_server_default" in kw:
-        kw.pop("existing_server_default", None)
-
-    op_list.append(
-        ops.AlterColumnOp(
-            tname, cname, schema=sname,
-            **kw
-        )
-    )
-
-
-def _group_diffs_by_table(diffs):
-    _adddrop = {
-        "table": lambda diff: (None, None),
-        "column": lambda diff: (diff[0], diff[1]),
-        "index": lambda diff: (diff[0].table.schema, diff[0].table.name),
-        "constraint": lambda diff: (diff[0].table.schema, diff[0].table.name),
-        "fk": lambda diff: (diff[0].parent.schema, diff[0].parent.name)
-    }
-
-    def _derive_table(diff):
-        if isinstance(diff, tuple):
-            cmd_type = diff[0]
-            adddrop, cmd_type = cmd_type.split("_")
-            return _adddrop[cmd_type](diff[1:])
-        else:
-            sname, tname = diff[0][1:3]
-            return sname, tname
-
-    return itertools.groupby(diffs, _derive_table)
-
diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py
deleted file mode 100644 (file)
index c686156..0000000
+++ /dev/null
@@ -1,92 +0,0 @@
-from .. import util
-from . import api
-from . import compose
-from . import compare
-from . import render
-from ..operations import ops
-
-
-class RevisionContext(object):
-    def __init__(self, config, script_directory, command_args):
-        self.config = config
-        self.script_directory = script_directory
-        self.command_args = command_args
-        self.template_args = {
-            'config': config  # Let templates use config for
-                              # e.g. multiple databases
-        }
-        self.generated_revisions = [
-            self._default_revision()
-        ]
-
-    def _to_script(self, migration_script):
-        template_args = {}
-        for k, v in self.template_args.items():
-            template_args.setdefault(k, v)
-
-        if migration_script._autogen_context is not None:
-            render._render_migration_script(
-                migration_script._autogen_context, migration_script,
-                template_args
-            )
-
-        return self.script_directory.generate_revision(
-            migration_script.rev_id,
-            migration_script.message,
-            refresh=True,
-            head=migration_script.head,
-            splice=migration_script.splice,
-            branch_labels=migration_script.branch_label,
-            version_path=migration_script.version_path,
-            **template_args)
-
-    def run_autogenerate(self, rev, context):
-        if self.command_args['sql']:
-            raise util.CommandError(
-                "Using --sql with --autogenerate does not make any sense")
-        if set(self.script_directory.get_revisions(rev)) != \
-                set(self.script_directory.get_revisions("heads")):
-            raise util.CommandError("Target database is not up to date.")
-
-        autogen_context = api._autogen_context(context)
-
-        diffs = []
-        compare._produce_net_changes(autogen_context, diffs)
-
-        migration_script = self.generated_revisions[0]
-
-        compose._to_migration_script(autogen_context, migration_script, diffs)
-
-        hook = context.opts.get('process_revision_directives', None)
-        if hook:
-            hook(context, rev, self.generated_revisions)
-
-        for migration_script in self.generated_revisions:
-            migration_script._autogen_context = autogen_context
-
-    def run_no_autogenerate(self, rev, context):
-        hook = context.opts.get('process_revision_directives', None)
-        if hook:
-            hook(context, rev, self.generated_revisions)
-
-        for migration_script in self.generated_revisions:
-            migration_script._autogen_context = None
-
-    def _default_revision(self):
-        op = ops.MigrationScript(
-            rev_id=self.command_args['rev_id'] or util.rev_id(),
-            message=self.command_args['message'],
-            imports=set(),
-            upgrade_ops=ops.UpgradeOps([]),
-            downgrade_ops=ops.DowngradeOps([]),
-            head=self.command_args['head'],
-            splice=self.command_args['splice'],
-            branch_label=self.command_args['branch_label'],
-            version_path=self.command_args['version_path']
-        )
-        op._autogen_context = None
-        return op
-
-    def generate_scripts(self):
-        for generated_revision in self.generated_revisions:
-            yield self._to_script(generated_revision)
index c3f3df1e841c47b354405f68f8b9d3d738e08efd..6f5f96c1c591a4baa7f8aa0db92f3ee0ede36c89 100644 (file)
@@ -30,8 +30,8 @@ def _indent(text):
 
 
 def _render_migration_script(autogen_context, migration_script, template_args):
-    opts = autogen_context['opts']
-    imports = autogen_context['imports']
+    opts = autogen_context.opts
+    imports = autogen_context._imports
     template_args[opts['upgrade_token']] = _indent(_render_cmd_body(
         migration_script.upgrade_ops, autogen_context))
     template_args[opts['downgrade_token']] = _indent(_render_cmd_body(
@@ -78,23 +78,26 @@ def render_op_text(autogen_context, op):
 
 @renderers.dispatch_for(ops.ModifyTableOps)
 def _render_modify_table(autogen_context, op):
-    opts = autogen_context['opts']
+    opts = autogen_context.opts
     render_as_batch = opts.get('render_as_batch', False)
 
     if op.ops:
         lines = []
         if render_as_batch:
-            lines.append(
-                "with op.batch_alter_table(%r, schema=%r) as batch_op:"
-                % (op.table_name, op.schema)
-            )
-            autogen_context['batch_prefix'] = 'batch_op.'
-        for t_op in op.ops:
-            t_lines = render_op(autogen_context, t_op)
-            lines.extend(t_lines)
-        if render_as_batch:
-            del autogen_context['batch_prefix']
-            lines.append("")
+            with autogen_context._within_batch():
+                lines.append(
+                    "with op.batch_alter_table(%r, schema=%r) as batch_op:"
+                    % (op.table_name, op.schema)
+                )
+                for t_op in op.ops:
+                    t_lines = render_op(autogen_context, t_op)
+                    lines.extend(t_lines)
+                lines.append("")
+        else:
+            for t_op in op.ops:
+                t_lines = render_op(autogen_context, t_op)
+                lines.extend(t_lines)
+
         return lines
     else:
         return [
@@ -149,7 +152,7 @@ def _drop_table(autogen_context, op):
 def _add_index(autogen_context, op):
     index = op.to_index()
 
-    has_batch = 'batch_prefix' in autogen_context
+    has_batch = autogen_context._has_batch
 
     if has_batch:
         tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\
@@ -180,7 +183,7 @@ def _add_index(autogen_context, op):
 
 @renderers.dispatch_for(ops.DropIndexOp)
 def _drop_index(autogen_context, op):
-    has_batch = 'batch_prefix' in autogen_context
+    has_batch = autogen_context._has_batch
 
     if has_batch:
         tmpl = "%(prefix)sdrop_index(%(name)r)"
@@ -243,7 +246,7 @@ def _add_check_constraint(constraint, autogen_context):
 @renderers.dispatch_for(ops.DropConstraintOp)
 def _drop_constraint(autogen_context, op):
 
-    if 'batch_prefix' in autogen_context:
+    if autogen_context._has_batch:
         template = "%(prefix)sdrop_constraint"\
             "(%(name)r, type_=%(type)r)"
     else:
@@ -266,7 +269,7 @@ def _drop_constraint(autogen_context, op):
 def _add_column(autogen_context, op):
 
     schema, tname, column = op.schema, op.table_name, op.column
-    if 'batch_prefix' in autogen_context:
+    if autogen_context._has_batch:
         template = "%(prefix)sadd_column(%(column)s)"
     else:
         template = "%(prefix)sadd_column(%(tname)r, %(column)s"
@@ -287,7 +290,7 @@ def _drop_column(autogen_context, op):
 
     schema, tname, column_name = op.schema, op.table_name, op.column_name
 
-    if 'batch_prefix' in autogen_context:
+    if autogen_context._has_batch:
         template = "%(prefix)sdrop_column(%(cname)r)"
     else:
         template = "%(prefix)sdrop_column(%(tname)r, %(cname)r"
@@ -319,7 +322,7 @@ def _alter_column(autogen_context, op):
 
     indent = " " * 11
 
-    if 'batch_prefix' in autogen_context:
+    if autogen_context._has_batch:
         template = "%(prefix)salter_column(%(cname)r"
     else:
         template = "%(prefix)salter_column(%(tname)r, %(cname)r"
@@ -343,16 +346,16 @@ def _alter_column(autogen_context, op):
     if nullable is not None:
         text += ",\n%snullable=%r" % (
             indent, nullable,)
-    if existing_nullable is not None:
+    if nullable is None and existing_nullable is not None:
         text += ",\n%sexisting_nullable=%r" % (
             indent, existing_nullable)
-    if existing_server_default:
+    if server_default is False and existing_server_default:
         rendered = _render_server_default(
             existing_server_default,
             autogen_context)
         text += ",\n%sexisting_server_default=%s" % (
             indent, rendered)
-    if schema and "batch_prefix" not in autogen_context:
+    if schema and not autogen_context._has_batch:
         text += ",\n%sschema=%r" % (indent, schema)
     text += ")"
     return text
@@ -409,7 +412,7 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
         return template % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
             "sql": compat.text_type(
-                value.compile(dialect=autogen_context['dialect'],
+                value.compile(dialect=autogen_context.dialect,
                               **compile_kw)
             )
         }
@@ -432,7 +435,7 @@ def _get_index_rendered_expressions(idx, autogen_context):
 def _uq_constraint(constraint, autogen_context, alter):
     opts = []
 
-    has_batch = 'batch_prefix' in autogen_context
+    has_batch = autogen_context._has_batch
 
     if constraint.deferrable:
         opts.append(("deferrable", str(constraint.deferrable)))
@@ -467,7 +470,7 @@ def _uq_constraint(constraint, autogen_context, alter):
 
 
 def _user_autogenerate_prefix(autogen_context, target):
-    prefix = autogen_context['opts']['user_module_prefix']
+    prefix = autogen_context.opts['user_module_prefix']
     if prefix is None:
         return "%s." % target.__module__
     else:
@@ -475,20 +478,19 @@ def _user_autogenerate_prefix(autogen_context, target):
 
 
 def _sqlalchemy_autogenerate_prefix(autogen_context):
-    return autogen_context['opts']['sqlalchemy_module_prefix'] or ''
+    return autogen_context.opts['sqlalchemy_module_prefix'] or ''
 
 
 def _alembic_autogenerate_prefix(autogen_context):
-    if 'batch_prefix' in autogen_context:
-        return autogen_context['batch_prefix']
+    if autogen_context._has_batch:
+        return 'batch_op.'
     else:
-        return autogen_context['opts']['alembic_module_prefix'] or ''
+        return autogen_context.opts['alembic_module_prefix'] or ''
 
 
 def _user_defined_render(type_, object_, autogen_context):
-    if 'opts' in autogen_context and \
-            'render_item' in autogen_context['opts']:
-        render = autogen_context['opts']['render_item']
+    if 'render_item' in autogen_context.opts:
+        render = autogen_context.opts['render_item']
         if render:
             rendered = render(type_, object_, autogen_context)
             if rendered is not False:
@@ -547,7 +549,7 @@ def _repr_type(type_, autogen_context):
         return rendered
 
     mod = type(type_).__module__
-    imports = autogen_context.get('imports', None)
+    imports = autogen_context._imports
     if mod.startswith("sqlalchemy.dialects"):
         dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
         if imports is not None:
index 08a05513e68f79474f2d45ee26af2916e75e48ab..71e85159c9f53bfa21fed7cba85454f76929cc7e 100644 (file)
@@ -3,6 +3,7 @@ from ..util import sqla_compat
 from . import schemaobj
 from sqlalchemy.types import NULLTYPE
 from .base import Operations, BatchOperations
+import re
 
 
 class MigrateOperation(object):
@@ -34,6 +35,10 @@ class MigrateOperation(object):
 class AddConstraintOp(MigrateOperation):
     """Represent an add constraint operation."""
 
+    @property
+    def constraint_type(self):
+        raise NotImplementedError()
+
     @classmethod
     def from_constraint(cls, constraint):
         funcs = {
@@ -45,17 +50,40 @@ class AddConstraintOp(MigrateOperation):
         }
         return funcs[constraint.__visit_name__](constraint)
 
+    def reverse(self):
+        return DropConstraintOp.from_constraint(self.to_constraint())
+
+    def to_diff_tuple(self):
+        return ("add_constraint", self.to_constraint())
+
 
 @Operations.register_operation("drop_constraint")
 @BatchOperations.register_operation("drop_constraint", "batch_drop_constraint")
 class DropConstraintOp(MigrateOperation):
     """Represent a drop constraint operation."""
 
-    def __init__(self, constraint_name, table_name, type_=None, schema=None):
+    def __init__(
+            self,
+            constraint_name, table_name, type_=None, schema=None,
+            _orig_constraint=None):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.constraint_type = type_
         self.schema = schema
+        self._orig_constraint = _orig_constraint
+
+    def reverse(self):
+        if self._orig_constraint is None:
+            raise ValueError(
+                "operation is not reversible; "
+                "original constraint is not present")
+        return AddConstraintOp.from_constraint(self._orig_constraint)
+
+    def to_diff_tuple(self):
+        if self.constraint_type == "foreignkey":
+            return ("remove_fk", self.to_constraint())
+        else:
+            return ("remove_constraint", self.to_constraint())
 
     @classmethod
     def from_constraint(cls, constraint):
@@ -72,9 +100,18 @@ class DropConstraintOp(MigrateOperation):
             constraint.name,
             constraint_table.name,
             schema=constraint_table.schema,
-            type_=types[constraint.__visit_name__]
+            type_=types[constraint.__visit_name__],
+            _orig_constraint=constraint
         )
 
+    def to_constraint(self):
+        if self._orig_constraint is not None:
+            return self._orig_constraint
+        else:
+            raise ValueError(
+                "constraint cannot be produced; "
+                "original constraint is not present")
+
     @classmethod
     @util._with_legacy_names([("type", "type_")])
     def drop_constraint(
@@ -124,8 +161,11 @@ class DropConstraintOp(MigrateOperation):
 class CreatePrimaryKeyOp(AddConstraintOp):
     """Represent a create primary key operation."""
 
+    constraint_type = "primarykey"
+
     def __init__(
-            self, constraint_name, table_name, columns, schema=None, **kw):
+            self, constraint_name, table_name, columns,
+            schema=None, **kw):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
@@ -225,8 +265,11 @@ class CreatePrimaryKeyOp(AddConstraintOp):
 class CreateUniqueConstraintOp(AddConstraintOp):
     """Represent a create unique constraint operation."""
 
+    constraint_type = "unique"
+
     def __init__(
-            self, constraint_name, table_name, columns, schema=None, **kw):
+            self, constraint_name, table_name,
+            columns, schema=None, **kw):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
@@ -342,6 +385,8 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 class CreateForeignKeyOp(AddConstraintOp):
     """Represent a create foreign key constraint operation."""
 
+    constraint_type = "foreignkey"
+
     def __init__(
             self, constraint_name, source_table, referent_table, local_cols,
             remote_cols, **kw):
@@ -352,6 +397,9 @@ class CreateForeignKeyOp(AddConstraintOp):
         self.remote_cols = remote_cols
         self.kw = kw
 
+    def to_diff_tuple(self):
+        return ("add_fk", self.to_constraint())
+
     @classmethod
     def from_constraint(cls, constraint):
         kw = {}
@@ -507,6 +555,8 @@ class CreateForeignKeyOp(AddConstraintOp):
 class CreateCheckConstraintOp(AddConstraintOp):
     """Represent a create check constraint operation."""
 
+    constraint_type = "check"
+
     def __init__(
             self, constraint_name, table_name, condition, schema=None, **kw):
         self.constraint_name = constraint_name
@@ -523,7 +573,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
             constraint.name,
             constraint_table.name,
             constraint.condition,
-            schema=constraint_table.schema
+            schema=constraint_table.schema,
         )
 
     def to_constraint(self, migration_context=None):
@@ -624,6 +674,12 @@ class CreateIndexOp(MigrateOperation):
         self.kw = kw
         self._orig_index = _orig_index
 
+    def reverse(self):
+        return DropIndexOp.from_index(self.to_index())
+
+    def to_diff_tuple(self):
+        return ("add_index", self.to_index())
+
     @classmethod
     def from_index(cls, index):
         return cls(
@@ -729,10 +785,22 @@ class CreateIndexOp(MigrateOperation):
 class DropIndexOp(MigrateOperation):
     """Represent a drop index operation."""
 
-    def __init__(self, index_name, table_name=None, schema=None):
+    def __init__(
+            self, index_name, table_name=None, schema=None, _orig_index=None):
         self.index_name = index_name
         self.table_name = table_name
         self.schema = schema
+        self._orig_index = _orig_index
+
+    def to_diff_tuple(self):
+        return ("remove_index", self.to_index())
+
+    def reverse(self):
+        if self._orig_index is None:
+            raise ValueError(
+                "operation is not reversible; "
+                "original index is not present")
+        return CreateIndexOp.from_index(self._orig_index)
 
     @classmethod
     def from_index(cls, index):
@@ -740,6 +808,7 @@ class DropIndexOp(MigrateOperation):
             index.name,
             index.table.name,
             schema=index.table.schema,
+            _orig_index=index
         )
 
     def to_index(self, migration_context=None):
@@ -807,6 +876,12 @@ class CreateTableOp(MigrateOperation):
         self.kw = kw
         self._orig_table = _orig_table
 
+    def reverse(self):
+        return DropTableOp.from_table(self.to_table())
+
+    def to_diff_tuple(self):
+        return ("add_table", self.to_table())
+
     @classmethod
     def from_table(cls, table):
         return cls(
@@ -921,16 +996,30 @@ class CreateTableOp(MigrateOperation):
 class DropTableOp(MigrateOperation):
     """Represent a drop table operation."""
 
-    def __init__(self, table_name, schema=None, table_kw=None):
+    def __init__(
+            self, table_name, schema=None, table_kw=None, _orig_table=None):
         self.table_name = table_name
         self.schema = schema
         self.table_kw = table_kw or {}
+        self._orig_table = _orig_table
+
+    def to_diff_tuple(self):
+        return ("remove_table", self.to_table())
+
+    def reverse(self):
+        if self._orig_table is None:
+            raise ValueError(
+                "operation is not reversible; "
+                "original table is not present")
+        return CreateTableOp.from_table(self._orig_table)
 
     @classmethod
     def from_table(cls, table):
-        return cls(table.name, schema=table.schema)
+        return cls(table.name, schema=table.schema, _orig_table=table)
 
-    def to_table(self, migration_context):
+    def to_table(self, migration_context=None):
+        if self._orig_table is not None:
+            return self._orig_table
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.table(
             self.table_name,
@@ -1029,6 +1118,87 @@ class AlterColumnOp(AlterTableOp):
         self.modify_type = modify_type
         self.kw = kw
 
+    def to_diff_tuple(self):
+        col_diff = []
+        schema, tname, cname = self.schema, self.table_name, self.column_name
+
+        if self.modify_type is not None:
+            col_diff.append(
+                ("modify_type", schema, tname, cname,
+                 {
+                     "existing_nullable": self.existing_nullable,
+                     "existing_server_default": self.existing_server_default,
+                 },
+                 self.existing_type,
+                 self.modify_type)
+            )
+
+        if self.modify_nullable is not None:
+            col_diff.append(
+                ("modify_nullable", schema, tname, cname,
+                    {
+                        "existing_type": self.existing_type,
+                        "existing_server_default": self.existing_server_default
+                    },
+                    self.existing_nullable,
+                    self.modify_nullable)
+            )
+
+        if self.modify_server_default is not False:
+            col_diff.append(
+                ("modify_default", schema, tname, cname,
+                 {
+                     "existing_nullable": self.existing_nullable,
+                     "existing_type": self.existing_type
+                 },
+                 self.existing_server_default,
+                 self.modify_server_default)
+            )
+
+        return col_diff
+
+    def has_changes(self):
+        hc1 = self.modify_nullable is not None or \
+            self.modify_server_default is not False or \
+            self.modify_type is not None
+        if hc1:
+            return True
+        for kw in self.kw:
+            if kw.startswith('modify_'):
+                return True
+        else:
+            return False
+
+    def reverse(self):
+
+        kw = self.kw.copy()
+        kw['existing_type'] = self.existing_type
+        kw['existing_nullable'] = self.existing_nullable
+        kw['existing_server_default'] = self.existing_server_default
+        if self.modify_type is not None:
+            kw['modify_type'] = self.modify_type
+        if self.modify_nullable is not None:
+            kw['modify_nullable'] = self.modify_nullable
+        if self.modify_server_default is not False:
+            kw['modify_server_default'] = self.modify_server_default
+
+        # TODO: make this a little simpler
+        all_keys = set(m.group(1) for m in [
+            re.match(r'^(?:existing_|modify_)(.+)$', k)
+            for k in kw
+        ] if m)
+
+        for k in all_keys:
+            if 'modify_%s' % k in kw:
+                swap = kw['existing_%s' % k]
+                kw['existing_%s' % k] = kw['modify_%s' % k]
+                kw['modify_%s' % k] = swap
+
+        return self.__class__(
+            self.table_name, self.column_name, schema=self.schema,
+            **kw
+        )
+
     @classmethod
     @util._with_legacy_names([('name', 'new_column_name')])
     def alter_column(
@@ -1177,6 +1347,13 @@ class AddColumnOp(AlterTableOp):
         super(AddColumnOp, self).__init__(table_name, schema=schema)
         self.column = column
 
+    def reverse(self):
+        return DropColumnOp.from_column_and_tablename(
+            self.schema, self.table_name, self.column)
+
+    def to_diff_tuple(self):
+        return ("add_column", self.schema, self.table_name, self.column)
+
     @classmethod
     def from_column(cls, col):
         return cls(col.table.name, col, schema=col.table.schema)
@@ -1265,14 +1442,30 @@ class AddColumnOp(AlterTableOp):
 class DropColumnOp(AlterTableOp):
     """Represent a drop column operation."""
 
-    def __init__(self, table_name, column_name, schema=None, **kw):
+    def __init__(
+            self, table_name, column_name, schema=None,
+            _orig_column=None, **kw):
         super(DropColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
         self.kw = kw
+        self._orig_column = _orig_column
+
+    def to_diff_tuple(self):
+        return (
+            "remove_column", self.schema, self.table_name, self.to_column())
+
+    def reverse(self):
+        if self._orig_column is None:
+            raise ValueError(
+                "operation is not reversible; "
+                "original column is not present")
+
+        return AddColumnOp.from_column_and_tablename(
+            self.schema, self.table_name, self._orig_column)
 
     @classmethod
     def from_column_and_tablename(cls, schema, tname, col):
-        return cls(tname, col.name, schema=schema)
+        return cls(tname, col.name, schema=schema, _orig_column=col)
 
     def to_column(self, migration_context=None):
         schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -1522,6 +1715,21 @@ class OpContainer(MigrateOperation):
     def __init__(self, ops=()):
         self.ops = ops
 
+    def is_empty(self):
+        return not self.ops
+
+    def as_diffs(self):
+        return list(OpContainer._ops_as_diffs(self))
+
+    @classmethod
+    def _ops_as_diffs(cls, migrations):
+        for op in migrations.ops:
+            if hasattr(op, 'ops'):
+                for sub_op in cls._ops_as_diffs(op):
+                    yield sub_op
+            else:
+                yield op.to_diff_tuple()
+
 
 class ModifyTableOps(OpContainer):
     """Contains a sequence of operations that all apply to a single Table."""
@@ -1531,6 +1739,15 @@ class ModifyTableOps(OpContainer):
         self.table_name = table_name
         self.schema = schema
 
+    def reverse(self):
+        return ModifyTableOps(
+            self.table_name,
+            ops=list(reversed(
+                [op.reverse() for op in self.ops]
+            )),
+            schema=self.schema
+        )
+
 
 class UpgradeOps(OpContainer):
     """contains a sequence of operations that would apply to the
@@ -1542,6 +1759,15 @@ class UpgradeOps(OpContainer):
 
     """
 
+    def reverse_into(self, downgrade_ops):
+        downgrade_ops.ops[:] = list(reversed(
+            [op.reverse() for op in self.ops]
+        ))
+        return downgrade_ops
+
+    def reverse(self):
+        return self.reverse_into(DowngradeOps(ops=[]))
+
 
 class DowngradeOps(OpContainer):
     """contains a sequence of operations that would apply to the
@@ -1553,6 +1779,13 @@ class DowngradeOps(OpContainer):
 
     """
 
+    def reverse(self):
+        return UpgradeOps(
+            ops=list(reversed(
+                [op.reverse() for op in self.ops]
+            ))
+        )
+
 
 class MigrationScript(MigrateOperation):
     """represents a migration script.
@@ -1583,4 +1816,3 @@ class MigrationScript(MigrateOperation):
         self.version_path = version_path
         self.upgrade_ops = upgrade_ops
         self.downgrade_ops = downgrade_ops
-
index 84a3c7fd6380f4e8de7021e04e7bc35e34960d47..e811a36ca26f00ea021435cc829e54f983e1581d 100644 (file)
@@ -118,6 +118,7 @@ class MigrationContext(object):
                   connection=None,
                   url=None,
                   dialect_name=None,
+                  dialect=None,
                   environment_context=None,
                   opts=None,
                   ):
@@ -152,7 +153,7 @@ class MigrationContext(object):
         elif dialect_name:
             url = sqla_url.make_url("%s://" % dialect_name)
             dialect = url.get_dialect()()
-        else:
+        elif not dialect:
             raise Exception("Connection, url, or dialect_name is required.")
 
         return MigrationContext(dialect, connection, opts, environment_context)
index 1fb09428ccafff905cccc7fad1e333e8b6aefd2e..6c92e3c695ced79895727ce5186cff296a02c2ce 100644 (file)
@@ -257,30 +257,57 @@ class immutabledict(dict):
 
 
 class Dispatcher(object):
-    def __init__(self):
+    def __init__(self, uselist=False):
         self._registry = {}
+        self.uselist = uselist
 
     def dispatch_for(self, target, qualifier='default'):
         def decorate(fn):
-            assert isinstance(target, type)
-            assert target not in self._registry
-            self._registry[(target, qualifier)] = fn
+            if self.uselist:
+                assert target not in self._registry
+                self._registry.setdefault((target, qualifier), []).append(fn)
+            else:
+                assert target not in self._registry
+                self._registry[(target, qualifier)] = fn
             return fn
         return decorate
 
     def dispatch(self, obj, qualifier='default'):
-        for spcls in type(obj).__mro__:
+
+        if isinstance(obj, string_types):
+            targets = [obj]
+        elif isinstance(obj, type):
+            targets = obj.__mro__
+        else:
+            targets = type(obj).__mro__
+
+        for spcls in targets:
             if qualifier != 'default' and (spcls, qualifier) in self._registry:
-                return self._registry[(spcls, qualifier)]
+                return self._fn_or_list(self._registry[(spcls, qualifier)])
             elif (spcls, 'default') in self._registry:
-                return self._registry[(spcls, 'default')]
+                return self._fn_or_list(self._registry[(spcls, 'default')])
         else:
             raise ValueError("no dispatch function for object: %s" % obj)
 
+    def _fn_or_list(self, fn_or_list):
+        if self.uselist:
+            def go(*arg, **kw):
+                for fn in fn_or_list:
+                    fn(*arg, **kw)
+            return go
+        else:
+            return fn_or_list
+
     def branch(self):
         """Return a copy of this dispatcher that is independently
         writable."""
 
         d = Dispatcher()
-        d._registry.update(self._registry)
+        if self.uselist:
+            d._registry.update(
+                (k, [fn for fn in self._registry[k]])
+                for k in self._registry
+            )
+        else:
+            d._registry.update(self._registry)
         return d
index b024ab137cc9234127558523c9409e9dd88bc9c4..8b026e81cadc4b7e1dbd1591e5b4f9cec45a5fa9 100644 (file)
@@ -4,7 +4,8 @@
 Autogeneration
 ==============
 
-The autogenerate system has two areas of API that are public:
+The autogeneration system has a wide degree of public API, including
+the following areas:
 
 1. The ability to do a "diff" of a :class:`~sqlalchemy.schema.MetaData` object against
    a database, and receive a data structure back.  This structure
@@ -15,9 +16,22 @@ The autogenerate system has two areas of API that are public:
    revision scripts, including support for multiple revision scripts
    generated in one pass.
 
+3. The ability to add new operation directives to autogeneration, including
+   custom schema/model comparison functions and revision script rendering.
+
 Getting Diffs
 ==============
 
+The simplest API autogenerate provides is the "schema comparison" API;
+these are simple functions that will run all registered "comparison" functions
+between a :class:`~sqlalchemy.schema.MetaData` object and a database
+backend to produce a structure showing how they differ.   The two
+functions provided are :func:`.compare_metadata`, which is more of the
+"legacy" function that produces diff tuples, and :func:`.produce_migrations`,
+which produces a structure consisting of operation directives detailed in
+:ref:`alembic.operations.toplevel`.
+
+
 .. autofunction:: alembic.autogenerate.compare_metadata
 
 .. autofunction:: alembic.autogenerate.produce_migrations
@@ -184,6 +198,8 @@ to whatever is in this list.
 
 .. autofunction:: alembic.autogenerate.render_python_code
 
+.. _autogen_custom_ops:
+
 Autogenerating Custom Operation Directives
 ==========================================
 
@@ -192,16 +208,180 @@ subclasses of :class:`.MigrateOperation` in order to add new ``op.``
 directives.  In the preceding section :ref:`customizing_revision`, we
 also learned that these same :class:`.MigrateOperation` structures are at
 the base of how the autogenerate system knows what Python code to render.
-How to connect these two systems, so that our own custom operation
-directives can be used?  First off, we'd probably be implementing
-a :paramref:`.EnvironmentContext.configure.process_revision_directives`
-plugin as described previously, so that we can add our own directives
-to the autogenerate stream.  What if we wanted to add our ``CreateSequenceOp``
-to the autogenerate structure?  We basically need to define an autogenerate
-renderer for it, as follows::
+Using this knowledge, we can create additional functions that plug into
+the autogenerate system so that our new operations can be generated
+into migration scripts when ``alembic revision --autogenerate`` is run.
+
+The following sections will detail an example of this using the
+the ``CreateSequenceOp`` and ``DropSequenceOp`` directives
+we created in :ref:`operation_plugins`, which correspond to the
+SQLAlchemy :class:`~sqlalchemy.schema.Sequence` construct.
+
+.. versionadded:: 0.8.0 - custom operations can be added to the
+   autogenerate system to support new kinds of database objects.
+
+Tracking our Object with the Model
+----------------------------------
+
+The basic job of an autogenerate comparison function is to inspect
+a series of objects in the database and compare them against a series
+of objects defined in our model.  By "in our model", we mean anything
+defined in Python code that we want to track, however most commonly
+we're talking about a series of :class:`~sqlalchemy.schema.Table`
+objects present in a :class:`~sqlalchemy.schema.MetaData` collection.
+
+Let's propose a simple way of seeing what :class:`~sqlalchemy.schema.Sequence`
+objects we want to ensure exist in the database when autogenerate
+runs.  While these objects do have some integrations with
+:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.MetaData`
+already, let's assume they don't, as the example here intends to illustrate
+how we would do this for most any kind of custom construct.   We
+associate the object with the :attr:`~sqlalchemy.schema.MetaData.info`
+collection of :class:`~sqlalchemy.schema.MetaData`, which is a dictionary
+we can use for anything, which we also know will be passed to the autogenerate
+process::
+
+    from sqlalchemy.schema import Sequence
+
+    def add_sequence_to_model(sequence, metadata):
+        metadata.info.setdefault("sequences", set()).add(
+            (sequence.schema, sequence.name)
+        )
+
+    my_seq = Sequence("my_sequence")
+    add_sequence_to_model(my_seq, model_metadata)
+
+The :attr:`~sqlalchemy.schema.MetaData.info`
+dictionary is a good place to put things that we want our autogeneration
+routines to be able to locate, which can include any object such as
+custom DDL objects representing views, triggers, special constraints,
+or anything else we want to support.
+
 
-    # note: this is a continuation of the example from the
-    # "Operation Plugins" section
+Registering a Comparison Function
+---------------------------------
+
+We now need to register a comparison hook, which will be used
+to compare the database to our model and produce ``CreateSequenceOp``
+and ``DropSequenceOp`` directives to be included in our migration
+script.  Note that we are assuming a
+Postgresql backend::
+
+    from alembic.autogenerate import comparators
+
+    @comparators.dispatch_for("schema")
+    def compare_sequences(autogen_context, upgrade_ops, schemas):
+        all_conn_sequences = set()
+
+        for sch in schemas:
+
+            all_conn_sequences.update([
+                (sch, row[0]) for row in
+                autogen_context.connection.execute(
+                    "SELECT relname FROM pg_class c join "
+                    "pg_namespace n on n.oid=c.relnamespace where "
+                    "relkind='S' and n.nspname=%(nspname)s",
+
+                    # note that we consider a schema of 'None' in our
+                    # model to be the "default" name in the PG database;
+                    # this usually is the name 'public'
+                    nspname=autogen_context.dialect.default_schema_name
+                    if sch is None else sch
+                )
+            ])
+
+        # get the collection of Sequence objects we're storing with
+        # our MetaData
+        metadata_sequences = autogen_context.metadata.info.setdefault(
+            "sequences", set())
+
+        # for new names, produce CreateSequenceOp directives
+        for sch, name in metadata_sequences.difference(all_conn_sequences):
+            upgrade_ops.ops.append(
+                CreateSequenceOp(name, schema=sch)
+            )
+
+        # for names that are going away, produce DropSequenceOp
+        # directives
+        for sch, name in all_conn_sequences.difference(metadata_sequences):
+            upgrade_ops.ops.append(
+                DropSequenceOp(name, schema=sch)
+            )
+
+Above, we've built a new function ``compare_sequences()`` and registered
+it as a "schema" level comparison function with autogenerate.   The
+job that it performs is that it compares the list of sequence names
+present in each database schema with that of a list of sequence names
+that we are maintaining in our :class:`~sqlalchemy.schema.MetaData` object.
+
+When autogenerate completes, it will have a series of
+``CreateSequenceOp`` and ``DropSequenceOp`` directives in the list of
+"upgrade" operations;  the list of "downgrade" operations is generated
+directly from these using the
+``CreateSequenceOp.reverse()`` and ``DropSequenceOp.reverse()`` methods
+that we've implemented on these objects.
+
+The registration of our function at the scope of "schema" means our
+autogenerate comparison function is called outside of the context
+of any specific table or column.  The three available scopes
+are "schema", "table", and "column", summarized as follows:
+
+* **Schema level** - these hooks are passed a :class:`.AutogenContext`,
+  an :class:`.UpgradeOps` collection, and a collection of string schema
+  names to be operated upon. If the
+  :class:`.UpgradeOps` collection contains changes after all
+  hooks are run, it is included in the migration script:
+
+  ::
+
+        @comparators.dispatch_for("schema")
+        def compare_schema_level(autogen_context, upgrade_ops, schemas):
+            pass
+
+* **Table level** - these hooks are passed a :class:`.AutogenContext`,
+  a :class:`.ModifyTableOps` collection, a schema name, table name,
+  a :class:`~sqlalchemy.schema.Table` reflected from the database if any
+  or ``None``, and a :class:`~sqlalchemy.schema.Table` present in the
+  local :class:`~sqlalchemy.schema.MetaData`.  If the
+  :class:`.ModifyTableOps` collection contains changes after all
+  hooks are run, it is included in the migration script:
+
+  ::
+
+        @comparators.dispatch_for("table")
+        def compare_table_level(autogen_context, modify_ops,
+            schemaname, tablename, conn_table, metadata_table):
+            pass
+
+* **Column level** - these hooks are passed a :class:`.AutogenContext`,
+  an :class:`.AlterColumnOp` object, a schema name, table name,
+  column name, a :class:`~sqlalchemy.schema.Column` reflected from the
+  database and a :class:`~sqlalchemy.schema.Column` present in the
+  local table.  If the :class:`.AlterColumnOp` contains changes after
+  all hooks are run, it is included in the migration script;
+  a "change" is considered to be present if any of the ``modify_`` attributes
+  are set to a non-default value, or there are any keys
+  in the ``.kw`` collection with the prefix ``"modify_"``:
+
+  ::
+
+        @comparators.dispatch_for("column")
+        def compare_column_level(autogen_context, alter_column_op,
+            schemaname, tname, cname, conn_col, metadata_col):
+            pass
+
+The :class:`.AutogenContext` passed to these hooks is documented below.
+
+.. autoclass:: alembic.autogenerate.api.AutogenContext
+    :members:
+
+Creating a Render Function
+--------------------------
+
+The second autogenerate integration hook is to provide a "render" function;
+since the autogenerate
+system renders Python code, we need to build a function that renders
+the correct "op" instructions for our directive::
 
     from alembic.autogenerate import renderers
 
@@ -209,27 +389,52 @@ renderer for it, as follows::
     def render_create_sequence(autogen_context, op):
         return "op.create_sequence(%r, **%r)" % (
             op.sequence_name,
-            op.kw
+            {"schema": op.schema}
         )
 
-With our render function established, we can our ``CreateSequenceOp``
-generated in an autogenerate context using the :func:`.render_python_code`
-debugging function in conjunction with an :class:`.UpgradeOps` structure::
 
-    from alembic.operations import ops
-    from alembic.autogenerate import render_python_code
+    @renderers.dispatch_for(DropSequenceOp)
+    def render_drop_sequence(autogen_context, op):
+        return "op.drop_sequence(%r, **%r)" % (
+            op.sequence_name,
+            {"schema": op.schema}
+        )
 
-    upgrade_ops = ops.UpgradeOps(
-        ops=[
-            CreateSequenceOp("my_seq")
-        ]
-    )
+The above functions will render Python code corresponding to the
+presence of ``CreateSequenceOp`` and ``DropSequenceOp`` instructions
+in the list that our comparison function generates.
 
-    print(render_python_code(upgrade_ops))
+Running It
+----------
 
-Which produces::
+All the above code can be organized however the developer sees fit;
+the only thing that needs to make it work is that when the
+Alembic environment ``env.py`` is invoked, it either imports modules
+which contain all the above routines, or they are locally present,
+or some combination thereof.
 
-    ### commands auto generated by Alembic - please adjust! ###
-        op.create_sequence('my_seq', **{})
+If we then have code in our model (which of course also needs to be invoked
+when ``env.py`` runs!) like this::
+
+    from sqlalchemy.schema import Sequence
+
+    my_seq_1 = Sequence("my_sequence_1")
+    add_sequence_to_model(my_seq_1, target_metadata)
+
+When we first run ``alembic revision --autogenerate``, we'll see this
+in our migration file::
+
+    def upgrade():
+        ### commands auto generated by Alembic - please adjust! ###
+        op.create_sequence('my_sequence_1', **{'schema': None})
         ### end Alembic commands ###
 
+
+    def downgrade():
+        ### commands auto generated by Alembic - please adjust! ###
+        op.drop_sequence('my_sequence_1', **{'schema': None})
+        ### end Alembic commands ###
+
+These are our custom directives that will invoke when ``alembic upgrade``
+or ``alembic downgrade`` is run.
+
index d9ff238f08758012fc06acc24cfe3a5a0753890b..2eb8358e734fcf593ec21161f7c0c74797c93903 100644 (file)
@@ -1,7 +1,7 @@
 .. _alembic.operations.toplevel:
 
 =====================
-The Operations Object
+Operation Directives
 =====================
 
 Within migration scripts, actual database migration operations are handled
@@ -48,9 +48,9 @@ migration scripts::
     class CreateSequenceOp(MigrateOperation):
         """Create a SEQUENCE."""
 
-        def __init__(self, sequence_name, **kw):
+        def __init__(self, sequence_name, schema=None):
             self.sequence_name = sequence_name
-            self.kw = kw
+            self.schema = schema
 
         @classmethod
         def create_sequence(cls, operations, sequence_name, **kw):
@@ -59,20 +59,58 @@ migration scripts::
             op = CreateSequenceOp(sequence_name, **kw)
             return operations.invoke(op)
 
-Above, the ``CreateSequenceOp`` class represents a new operation that will
-be available as ``op.create_sequence()``.   The reason the operation
-is represented as a stateful class is so that an operation and a specific
+        def reverse(self):
+            # only needed to support autogenerate
+            return DropSequenceOp(self.sequence_name, schema=self.schema)
+
+    @Operations.register_operation("drop_sequence")
+    class DropSequenceOp(MigrateOperation):
+        """Drop a SEQUENCE."""
+
+        def __init__(self, sequence_name, schema=None):
+            self.sequence_name = sequence_name
+            self.schema = schema
+
+        @classmethod
+        def drop_sequence(cls, operations, sequence_name, **kw):
+            """Issue a "DROP SEQUENCE" instruction."""
+
+            op = DropSequenceOp(sequence_name, **kw)
+            return operations.invoke(op)
+
+        def reverse(self):
+            # only needed to support autogenerate
+            return CreateSequenceOp(self.sequence_name, schema=self.schema)
+
+Above, the ``CreateSequenceOp`` and ``DropSequenceOp`` classes represent
+new operations that will
+be available as ``op.create_sequence()`` and ``op.drop_sequence()``.
+The reason the operations
+are represented as stateful classes is so that an operation and a specific
 set of arguments can be represented generically; the state can then correspond
 to different kinds of operations, such as invoking the instruction against
 a database, or autogenerating Python code for the operation into a
 script.
 
-In order to establish the migrate-script behavior of the new operation,
+In order to establish the migrate-script behavior of the new operations,
 we use the :meth:`.Operations.implementation_for` decorator::
 
     @Operations.implementation_for(CreateSequenceOp)
     def create_sequence(operations, operation):
-        operations.execute("CREATE SEQUENCE %s" % operation.sequence_name)
+        if operation.schema is not None:
+            name = "%s.%s" % (operation.schema, operation.sequence_name)
+        else:
+            name = operation.sequence_name
+        operations.execute("CREATE SEQUENCE %s" % name)
+
+
+    @Operations.implementation_for(DropSequenceOp)
+    def drop_sequence(operations, operation):
+        if operation.schema is not None:
+            name = "%s.%s" % (operation.schema, operation.sequence_name)
+        else:
+            name = operation.sequence_name
+        operations.execute("DROP SEQUENCE %s" % name)
 
 Above, we use the simplest possible technique of invoking our DDL, which
 is just to call :meth:`.Operations.execute` with literal SQL.  If this is
@@ -80,16 +118,24 @@ all a custom operation needs, then this is fine.  However, options for
 more comprehensive support include building out a custom SQL construct,
 as documented at :ref:`sqlalchemy.ext.compiler_toplevel`.
 
-With the above two steps, a migration script can now use a new method
-``op.create_sequence()`` that will proxy to our object as a classmethod::
+With the above two steps, a migration script can now use new methods
+``op.create_sequence()`` and ``op.drop_sequence()`` that will proxy to
+our object as a classmethod::
 
     def upgrade():
         op.create_sequence("my_sequence")
 
+    def downgrade():
+        op.drop_sequence("my_sequence")
+
 The registration of new operations only needs to occur in time for the
 ``env.py`` script to invoke :meth:`.MigrationContext.run_migrations`;
 within the module level of the ``env.py`` script is sufficient.
 
+.. seealso::
+
+    :ref:`autogen_custom_ops` - how to add autogenerate support to
+    custom operations.
 
 .. versionadded:: 0.8 - the migration operations available via the
    :class:`.Operations` class as well as the ``alembic.op`` namespace
index 691402d404c5c7c8ec5509e94eac870f34f682be..424bf8f1da7e2eb2e39d99adab4eb834e2196a79 100644 (file)
@@ -27,7 +27,7 @@ Changelog
 
     .. change::
       :tags: feature, autogenerate
-      :tickets: 301
+      :tickets: 301, 306
 
       The internal system for autogenerate been reworked to build upon
       the extensible system of operation objects present in
@@ -38,9 +38,12 @@ Changelog
       :paramref:`.EnvironmentContext.configure.process_revision_directives`
       allows end-user code to fully customize what autogenerate will do,
       including not just full manipulation of the Python steps to take
-      but also what file or files will be written and where.  It is also
-      possible to write a system that reads an autogenerate stream and
-      invokes it directly against a database without writing any files.
+      but also what file or files will be written and where.  Additionally,
+      autogenerate is now extensible as far as database objects compared
+      and rendered into scripts; any new operation directive can also be
+      registered into a series of hooks that allow custom database/model
+      comparison functions to run as well as to render new operation
+      directives into autogenerate scripts.
 
       .. seealso::
 
index 7ef6cbf706d68708804442415fb4e3d7a72c641b..e66888522c4a63eb3edc776ebac24e50495d6524 100644 (file)
@@ -2,12 +2,14 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
     Numeric, CHAR, ForeignKey, Index, UniqueConstraint, CheckConstraint, text
 from sqlalchemy.engine.reflection import Inspector
 
+from alembic.operations import ops
 from alembic import autogenerate
 from alembic.migration import MigrationContext
 from alembic.testing import config
 from alembic.testing.env import staging_env, clear_staging_env
 from alembic.testing import eq_
 from alembic.ddl.base import _fk_spec
+from alembic.autogenerate import api
 
 names_in_this_test = set()
 
@@ -25,9 +27,7 @@ def _default_include_object(obj, name, type_, reflected, compare_to):
     else:
         return True
 
-_default_object_filters = [
-    _default_include_object
-]
+_default_object_filters = _default_include_object
 
 
 class ModelOne(object):
@@ -177,6 +177,7 @@ class AutogenTest(_ComparesFKs):
             'downgrade_token': "downgrades",
             'alembic_module_prefix': 'op.',
             'sqlalchemy_module_prefix': 'sa.',
+            'include_object': _default_object_filters
         }
         if self.configure_opts:
             ctx_opts.update(self.configure_opts)
@@ -185,17 +186,18 @@ class AutogenTest(_ComparesFKs):
             opts=ctx_opts
         )
 
-        connection = context.bind
-        self.autogen_context = {
-            'imports': set(),
-            'connection': connection,
-            'dialect': connection.dialect,
-            'context': context
-        }
+        self.autogen_context = api.AutogenContext(context, self.m2)
 
     def tearDown(self):
         self.conn.close()
 
+    def _update_context(self, object_filters=None, include_schemas=None):
+        if include_schemas is not None:
+            self.autogen_context.opts['include_schemas'] = include_schemas
+        if object_filters is not None:
+            self.autogen_context._object_filters = [object_filters]
+        return self.autogen_context
+
 
 class AutogenFixtureTest(_ComparesFKs):
 
@@ -214,6 +216,8 @@ class AutogenFixtureTest(_ComparesFKs):
                 'downgrade_token': "downgrades",
                 'alembic_module_prefix': 'op.',
                 'sqlalchemy_module_prefix': 'sa.',
+                'include_object': object_filters,
+                'include_schemas': include_schemas
             }
             if opts:
                 ctx_opts.update(opts)
@@ -222,21 +226,12 @@ class AutogenFixtureTest(_ComparesFKs):
                 opts=ctx_opts
             )
 
-            connection = context.bind
-            autogen_context = {
-                'imports': set(),
-                'connection': connection,
-                'dialect': connection.dialect,
-                'context': context,
-                'metadata': model_metadata,
-                'object_filters': object_filters,
-                'include_schemas': include_schemas
-            }
-            diffs = []
+            autogen_context = api.AutogenContext(context, model_metadata)
+            uo = ops.UpgradeOps(ops=[])
             autogenerate._produce_net_changes(
-                autogen_context, diffs
+                autogen_context, uo
             )
-            return diffs
+            return uo.as_diffs()
 
     reports_unnamed_constraints = False
 
index b1717ab94c67b86205d2f31ef23d41748f62e313..6d1f55b057d12fe81d2fbb1865e22981a45576a8 100644 (file)
@@ -23,7 +23,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
             }
         )
         template_args = {}
-        autogenerate._render_migration_diffs(context, template_args, set())
+        autogenerate._render_migration_diffs(context, template_args)
 
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
             """### commands auto generated by Alembic - please adjust! ###
@@ -50,10 +50,8 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
             }
         )
         template_args = {}
-        autogenerate._render_migration_diffs(
-            context, template_args, set(),
+        autogenerate._render_migration_diffs(context, template_args)
 
-        )
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
             """### commands auto generated by Alembic - please adjust! ###
     pass
@@ -67,8 +65,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         """test a full render including indentation"""
 
         template_args = {}
-        autogenerate._render_migration_diffs(
-            self.context, template_args, set())
+        autogenerate._render_migration_diffs(self.context, template_args)
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
             """### commands auto generated by Alembic - please adjust! ###
     op.create_table('item',
@@ -135,8 +132,7 @@ nullable=True))
 
         template_args = {}
         self.context.opts['render_as_batch'] = True
-        autogenerate._render_migration_diffs(
-            self.context, template_args, set())
+        autogenerate._render_migration_diffs(self.context, template_args)
 
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
             """### commands auto generated by Alembic - please adjust! ###
@@ -229,10 +225,8 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
             }
         )
         template_args = {}
-        autogenerate._render_migration_diffs(
-            context, template_args, set(),
+        autogenerate._render_migration_diffs(context, template_args)
 
-        )
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
             """### commands auto generated by Alembic - please adjust! ###
     pass
@@ -250,9 +244,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
             'include_object': _default_include_object,
             'include_schemas': True
         })
-        autogenerate._render_migration_diffs(
-            self.context, template_args, set()
-        )
+        autogenerate._render_migration_diffs(self.context, template_args)
 
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
             """### commands auto generated by Alembic - please adjust! ###
@@ -326,3 +318,4 @@ name='extra_uid_fkey'),
     )
     op.drop_table('item', schema='%(schema)s')
     ### end Alembic commands ###""" % {"schema": self.schema})
+
index f32fd8492b853f63e0f8283f791671d33a4fd618..d176b913a2c1203a531b85e7cab4764e31696065 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
 from sqlalchemy.types import NULLTYPE
 from sqlalchemy.engine.reflection import Inspector
 
+from alembic.operations import ops
 from alembic import autogenerate
 from alembic.migration import MigrationContext
 from alembic.testing import TestBase
@@ -14,8 +15,7 @@ from alembic.testing import assert_raises_message
 from alembic.testing.mock import Mock
 from alembic.testing import eq_
 from alembic.util import CommandError
-from ._autogen_fixtures import \
-    AutogenTest, AutogenFixtureTest, _default_object_filters
+from ._autogen_fixtures import AutogenTest, AutogenFixtureTest
 
 py3k = sys.version_info >= (3, )
 
@@ -63,25 +63,24 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
         return m
 
     def test_default_schema_omitted_upgrade(self):
-        diffs = []
 
         def include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
                 return name == "t3"
             else:
                 return True
-        self.autogen_context.update({
-            'object_filters': [include_object],
-            'include_schemas': True,
-            'metadata': self.m2
-        })
-        autogenerate._produce_net_changes(self.autogen_context, diffs)
+        self._update_context(
+            object_filters=include_object,
+            include_schemas=True,
+        )
+        uo = ops.UpgradeOps(ops=[])
+        autogenerate._produce_net_changes(self.autogen_context, uo)
 
+        diffs = uo.as_diffs()
         eq_(diffs[0][0], "add_table")
         eq_(diffs[0][1].schema, None)
 
     def test_alt_schema_included_upgrade(self):
-        diffs = []
 
         def include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
@@ -89,48 +88,48 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
             else:
                 return True
 
-        self.autogen_context.update({
-            'object_filters': [include_object],
-            'include_schemas': True,
-            'metadata': self.m2
-        })
-        autogenerate._produce_net_changes(self.autogen_context, diffs)
+        self._update_context(
+            object_filters=include_object,
+            include_schemas=True,
+        )
+        uo = ops.UpgradeOps(ops=[])
+        autogenerate._produce_net_changes(self.autogen_context, uo)
 
+        diffs = uo.as_diffs()
         eq_(diffs[0][0], "add_table")
         eq_(diffs[0][1].schema, config.test_schema)
 
     def test_default_schema_omitted_downgrade(self):
-        diffs = []
-
         def include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
                 return name == "t1"
             else:
                 return True
-        self.autogen_context.update({
-            'object_filters': [include_object],
-            'include_schemas': True,
-            'metadata': self.m2
-        })
-        autogenerate._produce_net_changes(self.autogen_context, diffs)
+        self._update_context(
+            object_filters=include_object,
+            include_schemas=True,
+        )
+        uo = ops.UpgradeOps(ops=[])
+        autogenerate._produce_net_changes(self.autogen_context, uo)
 
+        diffs = uo.as_diffs()
         eq_(diffs[0][0], "remove_table")
         eq_(diffs[0][1].schema, None)
 
     def test_alt_schema_included_downgrade(self):
-        diffs = []
 
         def include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
                 return name == "t2"
             else:
                 return True
-        self.autogen_context.update({
-            'object_filters': [include_object],
-            'include_schemas': True,
-            'metadata': self.m2
-        })
-        autogenerate._produce_net_changes(self.autogen_context, diffs)
+        self._update_context(
+            object_filters=include_object,
+            include_schemas=True,
+        )
+        uo = ops.UpgradeOps(ops=[])
+        autogenerate._produce_net_changes(self.autogen_context, uo)
+        diffs = uo.as_diffs()
         eq_(diffs[0][0], "remove_table")
         eq_(diffs[0][1].schema, config.test_schema)
 
@@ -268,14 +267,14 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         """test generation of diff rules"""
 
         metadata = self.m2
-        diffs = []
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
-        ctx['object_filters'] = _default_object_filters
+        uo = ops.UpgradeOps(ops=[])
+        ctx = self.autogen_context
+
         autogenerate._produce_net_changes(
-            ctx, diffs
+            ctx, uo
         )
 
+        diffs = uo.as_diffs()
         eq_(
             diffs[0],
             ('add_table', metadata.tables['item'])
@@ -396,21 +395,25 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         eq_(alter_cols, set(['user_id', 'order', 'user']))
 
     def test_skip_null_type_comparison_reflected(self):
-        diff = []
-        autogenerate.compare._compare_type(None, "sometable", "somecol",
-                                           Column("somecol", NULLTYPE),
-                                           Column("somecol", Integer()),
-                                           diff, self.autogen_context
-                                           )
+        ac = ops.AlterColumnOp("sometable", "somecol")
+        autogenerate.compare._compare_type(
+            self.autogen_context, ac,
+            None, "sometable", "somecol",
+            Column("somecol", NULLTYPE),
+            Column("somecol", Integer()),
+        )
+        diff = ac.to_diff_tuple()
         assert not diff
 
     def test_skip_null_type_comparison_local(self):
-        diff = []
-        autogenerate.compare._compare_type(None, "sometable", "somecol",
-                                           Column("somecol", Integer()),
-                                           Column("somecol", NULLTYPE),
-                                           diff, self.autogen_context
-                                           )
+        ac = ops.AlterColumnOp("sometable", "somecol")
+        autogenerate.compare._compare_type(
+            self.autogen_context, ac,
+            None, "sometable", "somecol",
+            Column("somecol", Integer()),
+            Column("somecol", NULLTYPE),
+        )
+        diff = ac.to_diff_tuple()
         assert not diff
 
     def test_custom_type_compare(self):
@@ -420,20 +423,24 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
             def compare_against_backend(self, dialect, conn_type):
                 return isinstance(conn_type, Integer)
 
-        diff = []
-        autogenerate.compare._compare_type(None, "sometable", "somecol",
-                                           Column("somecol", INTEGER()),
-                                           Column("somecol", MyType()),
-                                           diff, self.autogen_context
-                                           )
-        assert not diff
+        ac = ops.AlterColumnOp("sometable", "somecol")
+        autogenerate.compare._compare_type(
+            self.autogen_context, ac,
+            None, "sometable", "somecol",
+            Column("somecol", INTEGER()),
+            Column("somecol", MyType()),
+        )
+
+        assert not ac.has_changes()
 
-        diff = []
-        autogenerate.compare._compare_type(None, "sometable", "somecol",
-                                           Column("somecol", String()),
-                                           Column("somecol", MyType()),
-                                           diff, self.autogen_context
-                                           )
+        ac = ops.AlterColumnOp("sometable", "somecol")
+        autogenerate.compare._compare_type(
+            self.autogen_context, ac,
+            None, "sometable", "somecol",
+            Column("somecol", String()),
+            Column("somecol", MyType()),
+        )
+        diff = ac.to_diff_tuple()
         eq_(
             diff[0][0:4],
             ('modify_type', None, 'sometable', 'somecol')
@@ -449,26 +456,26 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
                 else:
                     return dialect.type_descriptor(CHAR(32))
 
-        diff = []
+        uo = ops.AlterColumnOp('sometable', 'somecol')
         autogenerate.compare._compare_type(
+            self.autogen_context, uo,
             None, "sometable", "somecol",
             Column("somecol", Integer, nullable=True),
-            Column("somecol", MyType()),
-            diff, self.autogen_context
+            Column("somecol", MyType())
         )
-        assert not diff
+        assert not uo.has_changes()
 
     def test_dont_barf_on_already_reflected(self):
-        diffs = []
         from sqlalchemy.util import OrderedSet
         inspector = Inspector.from_engine(self.bind)
+        uo = ops.UpgradeOps(ops=[])
         autogenerate.compare._compare_tables(
             OrderedSet([(None, 'extra'), (None, 'user')]),
-            OrderedSet(), [], inspector,
-            MetaData(), diffs, self.autogen_context
+            OrderedSet(), inspector,
+            MetaData(), uo, self.autogen_context
         )
         eq_(
-            [(rec[0], rec[1].name) for rec in diffs],
+            [(rec[0], rec[1].name) for rec in uo.as_diffs()],
             [('remove_table', 'extra'), ('remove_table', 'user')]
         )
 
@@ -481,14 +488,14 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
         """test generation of diff rules"""
 
         metadata = self.m2
-        diffs = []
 
-        self.autogen_context.update({
-            'object_filters': _default_object_filters,
-            'include_schemas': True,
-            'metadata': self.m2
-        })
-        autogenerate._produce_net_changes(self.autogen_context, diffs)
+        self._update_context(
+            include_schemas=True,
+        )
+        uo = ops.UpgradeOps(ops=[])
+        autogenerate._produce_net_changes(self.autogen_context, uo)
+
+        diffs = uo.as_diffs()
 
         eq_(
             diffs[0],
@@ -567,10 +574,10 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
         my_compare_type = Mock()
         self.context._user_compare_type = my_compare_type
 
-        diffs = []
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
-        autogenerate._produce_net_changes(ctx, diffs)
+        uo = ops.UpgradeOps(ops=[])
+
+        ctx = self.autogen_context
+        autogenerate._produce_net_changes(ctx, uo)
 
         first_table = self.m2.tables['sometable']
         first_column = first_table.columns['id']
@@ -593,8 +600,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
         self.context._user_compare_type = my_compare_type
 
         diffs = []
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
+        ctx = self.autogen_context
         diffs = []
         autogenerate._produce_net_changes(ctx, diffs)
 
@@ -605,10 +611,10 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
         my_compare_type.return_value = True
         self.context._user_compare_type = my_compare_type
 
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
-        diffs = []
-        autogenerate._produce_net_changes(ctx, diffs)
+        ctx = self.autogen_context
+        uo = ops.UpgradeOps(ops=[])
+        autogenerate._produce_net_changes(ctx, uo)
+        diffs = uo.as_diffs()
 
         eq_(diffs[0][0][0], 'modify_type')
         eq_(diffs[1][0][0], 'modify_type')
@@ -636,8 +642,7 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
 
     def test_no_change(self):
         diffs = []
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
+        ctx = self.autogen_context
         autogenerate._produce_net_changes(ctx, diffs)
         eq_(diffs, [])
 
@@ -674,11 +679,11 @@ class AutogenKeyTest(AutogenTest, TestBase):
 
     def test_autogen(self):
 
-        diffs = []
+        uo = ops.UpgradeOps(ops=[])
 
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
-        autogenerate._produce_net_changes(ctx, diffs)
+        ctx = self.autogen_context
+        autogenerate._produce_net_changes(ctx, uo)
+        diffs = uo.as_diffs()
         eq_(diffs[0][0], "add_table")
         eq_(diffs[0][1].name, "sometable")
         eq_(diffs[1][0], "add_column")
@@ -705,8 +710,7 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
 
     def test_no_version_table(self):
         diffs = []
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
+        ctx = self.autogen_context
 
         autogenerate._produce_net_changes(ctx, diffs)
         eq_(diffs, [])
@@ -717,8 +721,7 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
             self.version_table_name,
             self.m2, Column('x', Integer), schema=self.version_table_schema)
 
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
+        ctx = self.autogen_context
         autogenerate._produce_net_changes(ctx, diffs)
         eq_(diffs, [])
 
@@ -769,10 +772,10 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase):
         before their parent tables
         """
 
-        ctx = self.autogen_context.copy()
-        ctx['metadata'] = self.m2
-        diffs = []
-        autogenerate._produce_net_changes(ctx, diffs)
+        ctx = self.autogen_context
+        uo = ops.UpgradeOps(ops=[])
+        autogenerate._produce_net_changes(ctx, uo)
+        diffs = uo.as_diffs()
 
         eq_(diffs[0][0], 'add_table')
         eq_(diffs[0][1].name, "parent")
index 525bed588552b58d86323f05e67916411d180f7c..174a53895fc537cecd03de261dd5b80dd89802f1 100644 (file)
@@ -351,7 +351,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 type_ == 'foreign_key_constraint'
                 and reflected and name == 'fk1')
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         self._assert_fk_diff(
             diffs[0], "remove_fk",
@@ -390,7 +390,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 type_ == 'foreign_key_constraint'
                 and not reflected and name == 'fk1')
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         self._assert_fk_diff(
             diffs[0], "add_fk",
@@ -456,7 +456,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 and name == 'fk1'
             )
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         self._assert_fk_diff(
             diffs[0], "remove_fk",
index 8ee33bccd2c87ca7e8067fbca0b8b8039e01c973..9b6cd444a1cc3b1ad411abbeba1184b340615458 100644 (file)
@@ -798,7 +798,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 isinstance(object_, Index) and
                 type_ == 'index' and reflected and name == 'ix1')
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         eq_(diffs[0][0], 'remove_index')
         eq_(diffs[0][1].name, 'ix2')
@@ -825,7 +825,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 isinstance(object_, UniqueConstraint) and
                 type_ == 'unique_constraint' and reflected and name == 'uq1')
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         eq_(diffs[0][0], 'remove_constraint')
         eq_(diffs[0][1].name, 'uq2')
@@ -846,7 +846,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 isinstance(object_, Index) and
                 type_ == 'index' and not reflected and name == 'ix1')
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         eq_(diffs[0][0], 'add_index')
         eq_(diffs[0][1].name, 'ix2')
@@ -871,7 +871,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 type_ == 'unique_constraint' and
                 not reflected and name == 'uq1')
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         eq_(diffs[0][0], 'add_constraint')
         eq_(diffs[0][1].name, 'uq2')
@@ -899,7 +899,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 type_ == 'index' and not reflected and name == 'ix1'
                 and isinstance(compare_to, Index))
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         eq_(diffs[0][0], 'remove_index')
         eq_(diffs[0][1].name, 'ix2')
@@ -935,7 +935,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
                 not reflected and name == 'uq1'
                 and isinstance(compare_to, UniqueConstraint))
 
-        diffs = self._fixture(m1, m2, object_filters=[include_object])
+        diffs = self._fixture(m1, m2, object_filters=include_object)
 
         eq_(diffs[0][0], 'remove_constraint')
         eq_(diffs[0][1].name, 'uq2')
index 4a49d5c5732dc14160b6864998a72679c6e38064..a73cff51f36c88cdea548a01d6f882400d4da542 100644 (file)
@@ -14,6 +14,8 @@ from sqlalchemy.types import UserDefinedType
 from sqlalchemy.dialects import mysql, postgresql
 from sqlalchemy.engine.default import DefaultDialect
 from sqlalchemy.sql import and_, column, literal_column, false
+from alembic.migration import MigrationContext
+from alembic.autogenerate import api
 
 from alembic.testing.mock import patch
 
@@ -32,22 +34,30 @@ class AutogenRenderTest(TestBase):
 
     """test individual directives"""
 
-    @classmethod
-    def setup_class(cls):
-        cls.autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-            },
-            'dialect': mysql.dialect()
-        }
-        cls.pg_autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-            },
-            'dialect': postgresql.dialect()
+    def setUp(self):
+        ctx_opts = {
+            'sqlalchemy_module_prefix': 'sa.',
+            'alembic_module_prefix': 'op.',
+            'target_metadata': MetaData()
         }
+        context = MigrationContext.configure(
+            dialect_name="mysql",
+            opts=ctx_opts
+        )
+
+        self.autogen_context = api.AutogenContext(context)
+
+        context = MigrationContext.configure(
+            dialect_name="postgresql",
+            opts=ctx_opts
+        )
+        self.pg_autogen_context = api.AutogenContext(context)
+
+        context = MigrationContext.configure(
+            dialect=DefaultDialect(),
+            opts=ctx_opts
+        )
+        self.default_autogen_context = api.AutogenContext(context)
 
     def test_render_add_index(self):
         """
@@ -812,10 +822,10 @@ unique=False, """
                     return "col(%s)" % obj.name
             return "render:%s" % type_
 
-        autogen_context = {"opts": {
-            'render_item': render,
-            'alembic_module_prefix': 'sa.'
-        }}
+        self.autogen_context.opts.update(
+            render_item=render,
+            alembic_module_prefix='sa.'
+        )
 
         t = Table('t', MetaData(),
                   Column('x', Integer),
@@ -824,7 +834,7 @@ unique=False, """
                   ForeignKeyConstraint(['x'], ['y'])
                   )
         op_obj = ops.CreateTableOp.from_table(t)
-        result = autogenerate.render_op_text(autogen_context, op_obj)
+        result = autogenerate.render_op_text(self.autogen_context, op_obj)
         eq_ignore_whitespace(
             result,
             "sa.create_table('t',"
@@ -1087,28 +1097,13 @@ unique=False, """
 
     def test_repr_plain_sqla_type(self):
         type_ = Integer()
-        autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-            },
-            'dialect': mysql.dialect()
-        }
-
         eq_ignore_whitespace(
-            autogenerate.render._repr_type(type_, autogen_context),
+            autogenerate.render._repr_type(type_, self.autogen_context),
             "sa.Integer()"
         )
 
     def test_repr_custom_type_w_sqla_prefix(self):
-        autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-                'user_module_prefix': None
-            },
-            'dialect': mysql.dialect()
-        }
+        self.autogen_context.opts['user_module_prefix'] = None
 
         class MyType(UserDefinedType):
             pass
@@ -1118,7 +1113,7 @@ unique=False, """
         type_ = MyType()
 
         eq_ignore_whitespace(
-            autogenerate.render._repr_type(type_, autogen_context),
+            autogenerate.render._repr_type(type_, self.autogen_context),
             "sqlalchemy_util.types.MyType()"
         )
 
@@ -1129,17 +1124,10 @@ unique=False, """
                 return "MYTYPE"
 
         type_ = MyType()
-        autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-                'user_module_prefix': None
-            },
-            'dialect': mysql.dialect()
-        }
+        self.autogen_context.opts['user_module_prefix'] = None
 
         eq_ignore_whitespace(
-            autogenerate.render._repr_type(type_, autogen_context),
+            autogenerate.render._repr_type(type_, self.autogen_context),
             "tests.test_autogen_render.MyType()"
         )
 
@@ -1152,17 +1140,11 @@ unique=False, """
                 return "MYTYPE"
 
         type_ = MyType()
-        autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-                'user_module_prefix': 'user.',
-            },
-            'dialect': mysql.dialect()
-        }
+
+        self.autogen_context.opts['user_module_prefix'] = 'user.'
 
         eq_ignore_whitespace(
-            autogenerate.render._repr_type(type_, autogen_context),
+            autogenerate.render._repr_type(type_, self.autogen_context),
             "user.MyType()"
         )
 
@@ -1171,20 +1153,14 @@ unique=False, """
         from sqlalchemy.dialects.mysql import VARCHAR
 
         type_ = VARCHAR(20, charset='utf8', national=True)
-        autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-                'user_module_prefix': None,
-            },
-            'imports': set(),
-            'dialect': mysql.dialect()
-        }
+
+        self.autogen_context.opts['user_module_prefix'] = None
+
         eq_ignore_whitespace(
-            autogenerate.render._repr_type(type_, autogen_context),
+            autogenerate.render._repr_type(type_, self.autogen_context),
             "mysql.VARCHAR(charset='utf8', national=True, length=20)"
         )
-        eq_(autogen_context['imports'],
+        eq_(self.autogen_context._imports,
             set(['from sqlalchemy.dialects import mysql'])
             )
 
@@ -1204,19 +1180,12 @@ unique=False, """
         )
 
     def test_render_server_default_native_boolean(self):
-        autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-            },
-            'dialect': postgresql.dialect()
-        }
         c = Column(
             'updated_at', Boolean(),
             server_default=false(),
             nullable=False)
         result = autogenerate.render._render_column(
-            c, autogen_context,
+            c, self.autogen_context,
         )
         eq_ignore_whitespace(
             result,
@@ -1231,17 +1200,10 @@ unique=False, """
             'updated_at', Boolean(),
             server_default=false(),
             nullable=False)
-        dialect = DefaultDialect()
-        autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-            },
-            'dialect': dialect
-        }
+# MARKMARK
 
         result = autogenerate.render._render_column(
-            c, autogen_context
+            c, self.default_autogen_context
         )
         eq_ignore_whitespace(
             result,
@@ -1296,16 +1258,6 @@ unique=False, """
 class RenderNamingConventionTest(TestBase):
     __requires__ = ('sqlalchemy_094',)
 
-    @classmethod
-    def setup_class(cls):
-        cls.autogen_context = {
-            'opts': {
-                'sqlalchemy_module_prefix': 'sa.',
-                'alembic_module_prefix': 'op.',
-            },
-            'dialect': postgresql.dialect()
-        }
-
     def setUp(self):
 
         convention = {
@@ -1322,6 +1274,17 @@ class RenderNamingConventionTest(TestBase):
             naming_convention=convention
         )
 
+        ctx_opts = {
+            'sqlalchemy_module_prefix': 'sa.',
+            'alembic_module_prefix': 'op.',
+            'target_metadata': MetaData()
+        }
+        context = MigrationContext.configure(
+            dialect_name="postgresql",
+            opts=ctx_opts
+        )
+        self.autogen_context = api.AutogenContext(context)
+
     def test_schema_type_boolean(self):
         t = Table('t', self.metadata, Column('c', Boolean(name='xyz')))
         op_obj = ops.AddColumnOp.from_column(t.c.c)
@@ -1457,3 +1420,29 @@ class RenderNamingConventionTest(TestBase):
             "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))"
         )
 
+    def test_create_table_plus_add_index_in_modify(self):
+        uo = ops.UpgradeOps(ops=[
+            ops.CreateTableOp(
+                "sometable",
+                [Column('x', Integer), Column('y', Integer)]
+            ),
+            ops.ModifyTableOps(
+                "sometable", ops=[
+                    ops.CreateIndexOp('ix1', 'sometable', ['x', 'y'])
+                ]
+            )
+        ])
+
+        eq_(
+            autogenerate.render_python_code(uo, render_as_batch=True),
+            "### commands auto generated by Alembic - please adjust! ###\n"
+            "    op.create_table('sometable',\n"
+            "    sa.Column('x', sa.Integer(), nullable=True),\n"
+            "    sa.Column('y', sa.Integer(), nullable=True)\n"
+            "    )\n"
+            "    with op.batch_alter_table('sometable', schema=None) "
+            "as batch_op:\n"
+            "        batch_op.create_index("
+            "'ix1', ['x', 'y'], unique=False)\n\n"
+            "    ### end Alembic commands ###"
+        )
index e70d05a3b0efc79f12416d90ffc7ac273d1b70ac..576d957aecdb6de05027f4fa1c3ff2a37ad68e41 100644 (file)
@@ -8,9 +8,11 @@ from sqlalchemy.sql import table, column
 from alembic.autogenerate.compare import \
     _compare_server_default, _compare_tables, _render_server_default_for_compare
 
+from alembic.operations import ops
 from alembic import command, util
 from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
+from alembic.autogenerate import api
 
 from alembic.testing import eq_, provide_metadata
 from alembic.testing.env import staging_env, clear_staging_env, \
@@ -162,34 +164,22 @@ class PostgresqlDefaultCompareTest(TestBase):
     def setup_class(cls):
         cls.bind = config.db
         staging_env()
-        context = MigrationContext.configure(
+        cls.migration_context = MigrationContext.configure(
             connection=cls.bind.connect(),
             opts={
                 'compare_type': True,
                 'compare_server_default': True
             }
         )
-        connection = context.bind
-        cls.autogen_context = {
-            'imports': set(),
-            'connection': connection,
-            'dialect': connection.dialect,
-            'context': context,
-            'opts': {
-                'compare_type': True,
-                'compare_server_default': True,
-                'alembic_module_prefix': 'op.',
-                'sqlalchemy_module_prefix': 'sa.',
-            }
-        }
+
+    def setUp(self):
+        self.metadata = MetaData(self.bind)
+        self.autogen_context = api.AutogenContext(self.migration_context)
 
     @classmethod
     def teardown_class(cls):
         clear_staging_env()
 
-    def setUp(self):
-        self.metadata = MetaData(self.bind)
-
     def tearDown(self):
         self.metadata.drop_all()
 
@@ -212,9 +202,12 @@ class PostgresqlDefaultCompareTest(TestBase):
         cols = insp.get_columns(t1.name)
         insp_col = Column("somecol", cols[0]['type'],
                           server_default=text(cols[0]['default']))
-        diffs = []
-        _compare_server_default(None, "test", "somecol", insp_col,
-                                t2.c.somecol, diffs, self.autogen_context)
+        op = ops.AlterColumnOp("test", "somecol")
+        _compare_server_default(
+            self.autogen_context, op,
+            None, "test", "somecol", insp_col, t2.c.somecol)
+
+        diffs = op.to_diff_tuple()
         eq_(bool(diffs), diff_expected)
 
     def _compare_default(
@@ -225,7 +218,7 @@ class PostgresqlDefaultCompareTest(TestBase):
         t1.create(self.bind, checkfirst=True)
         insp = Inspector.from_engine(self.bind)
         cols = insp.get_columns(t1.name)
-        ctx = self.autogen_context['context']
+        ctx = self.autogen_context.migration_context
 
         return ctx.impl.compare_server_default(
             None,
@@ -385,26 +378,16 @@ class PostgresqlDetectSerialTest(TestBase):
         cls.bind = config.db
         cls.conn = cls.bind.connect()
         staging_env()
-        context = MigrationContext.configure(
+        cls.migration_context = MigrationContext.configure(
             connection=cls.conn,
             opts={
                 'compare_type': True,
                 'compare_server_default': True
             }
         )
-        connection = context.bind
-        cls.autogen_context = {
-            'imports': set(),
-            'connection': connection,
-            'dialect': connection.dialect,
-            'context': context,
-            'opts': {
-                'compare_type': True,
-                'compare_server_default': True,
-                'alembic_module_prefix': 'op.',
-                'sqlalchemy_module_prefix': 'sa.',
-            }
-        }
+
+    def setUp(self):
+        self.autogen_context = api.AutogenContext(self.migration_context)
 
     @classmethod
     def teardown_class(cls):
@@ -420,24 +403,26 @@ class PostgresqlDetectSerialTest(TestBase):
         self.metadata.create_all(config.db)
 
         insp = Inspector.from_engine(config.db)
-        diffs = []
+
+        uo = ops.UpgradeOps(ops=[])
         _compare_tables(
             set([(None, 't')]), set([]),
-            [],
-            insp, self.metadata, diffs, self.autogen_context)
+            insp, self.metadata, uo, self.autogen_context)
+        diffs = uo.as_diffs()
         tab = diffs[0][1]
+
         eq_(_render_server_default_for_compare(
             tab.c.x.server_default, tab.c.x, self.autogen_context),
             c_expected)
 
         insp = Inspector.from_engine(config.db)
-        diffs = []
+        uo = ops.UpgradeOps(ops=[])
         m2 = MetaData()
         Table('t', m2, Column('x', BigInteger()))
         _compare_tables(
             set([(None, 't')]), set([(None, 't')]),
-            [],
-            insp, m2, diffs, self.autogen_context)
+            insp, m2, uo, self.autogen_context)
+        diffs = uo.as_diffs()
         server_default = diffs[0][0][4]['existing_server_default']
         eq_(_render_server_default_for_compare(
             server_default, tab.c.x, self.autogen_context),