]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- add a helper object for autogen rewriting called Rewriter.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 7 Aug 2015 21:58:12 +0000 (17:58 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 7 Aug 2015 21:58:12 +0000 (17:58 -0400)
this provides for operation-specific handler functions.
docs are based on the example requested in references #313.

alembic/autogenerate/__init__.py
alembic/autogenerate/rewriter.py [new file with mode: 0644]
alembic/runtime/environment.py
alembic/util/langhelpers.py
docs/build/api/autogenerate.rst
tests/test_script_production.py

index 78520a8567c7a0f5444d549df6a14b7a45271830..142f55d04fee32dc1af56f86b7e0dc13646f448b 100644 (file)
@@ -4,4 +4,5 @@ from .api import ( # noqa
     RevisionContext
     )
 from .compare import _produce_net_changes, comparators  # noqa
-from .render import render_op_text, renderers  # noqa
\ No newline at end of file
+from .render import render_op_text, renderers  # noqa
+from .rewriter import Rewriter  # noqa
\ No newline at end of file
diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py
new file mode 100644 (file)
index 0000000..c84712c
--- /dev/null
@@ -0,0 +1,142 @@
+from alembic import util
+from alembic.operations import ops
+
+
+class Rewriter(object):
+    """A helper object that allows easy 'rewriting' of ops streams.
+
+    The :class:`.Rewriter` object is intended to be passed along
+    to the
+    :paramref:`.EnvironmentContext.configure.process_revision_directives`
+    parameter in an ``env.py`` script.    Once constructed, any number
+    of "rewrites" functions can be associated with it, which will be given
+    the opportunity to modify the structure without having to have explicit
+    knowledge of the overall structure.
+
+    The function is passed the :class:`.MigrationContext` object and
+    ``revision`` tuple that are passed to the  :paramref:`.Environment
+    Context.configure.process_revision_directives` function normally,
+    and the third argument is an individual directive of the type
+    noted in the decorator.  The function has the choice of  returning
+    a single op directive, which normally can be the directive that
+    was actually passed, or a new directive to replace it, or a list
+    of zero or more directives to replace it.
+
+    .. seealso::
+
+        :ref:`autogen_rewriter` - usage example
+
+    .. versionadded:: 0.8
+
+    """
+
+    _traverse = util.Dispatcher()
+
+    _chained = None
+
+    def __init__(self):
+        self.dispatch = util.Dispatcher()
+
+    def chain(self, other):
+        """Produce a "chain" of this :class:`.Rewriter` to another.
+
+        This allows two rewriters to operate serially on a stream,
+        e.g.::
+
+            writer1 = autogenerate.Rewriter()
+            writer2 = autogenerate.Rewriter()
+
+            @writer1.rewrites(ops.AddColumnOp)
+            def add_column_nullable(context, revision, op):
+                op.column.nullable = True
+                return op
+
+            @writer2.rewrites(ops.AddColumnOp)
+            def add_column_idx(context, revision, op):
+                idx_op = ops.CreateIndexOp(
+                    'ixc', op.table_name, [op.column.name])
+                return [
+                    op,
+                    idx_op
+                ]
+
+            writer = writer1.chain(writer2)
+
+        :param other: a :class:`.Rewriter` instance
+        :return: a new :class:`.Rewriter` that will run the operations
+         of this writer, then the "other" writer, in succession.
+
+        """
+        wr = self.__class__.__new__(self.__class__)
+        wr.__dict__.update(self.__dict__)
+        wr._chained = other
+        return wr
+
+    def rewrites(self, operator):
+        """Register a function as rewriter for a given type.
+
+        The function should receive three arguments, which are
+        the :class:`.MigrationContext`, a ``revision`` tuple, and
+        an op directive of the type indicated.  E.g.::
+
+            @writer1.rewrites(ops.AddColumnOp)
+            def add_column_nullable(context, revision, op):
+                op.column.nullable = True
+                return op
+
+        """
+        return self.dispatch.dispatch_for(operator)
+
+    def _rewrite(self, context, revision, directive):
+        try:
+            _rewriter = self.dispatch.dispatch(directive)
+        except ValueError:
+            _rewriter = None
+            yield directive
+        else:
+            for r_directive in util.to_list(
+                    _rewriter(context, revision, directive)):
+                yield r_directive
+
+    def __call__(self, context, revision, directives):
+        self.process_revision_directives(context, revision, directives)
+        if self._chained:
+            self._chained(context, revision, directives)
+
+    @_traverse.dispatch_for(ops.MigrationScript)
+    def _traverse_script(self, context, revision, directive):
+        ret = self._traverse_for(context, revision, directive.upgrade_ops)
+        if len(ret) != 1:
+            raise ValueError(
+                "Can only return single object for UpgradeOps traverse")
+        directive.upgrade_ops = ret[0]
+        ret = self._traverse_for(context, revision, directive.downgrade_ops)
+        if len(ret) != 1:
+            raise ValueError(
+                "Can only return single object for DowngradeOps traverse")
+        directive.downgrade_ops = ret[0]
+
+    @_traverse.dispatch_for(ops.OpContainer)
+    def _traverse_op_container(self, context, revision, directive):
+        self._traverse_list(context, revision, directive.ops)
+
+    @_traverse.dispatch_for(ops.MigrateOperation)
+    def _traverse_any_directive(self, context, revision, directive):
+        pass
+
+    def _traverse_for(self, context, revision, directive):
+        directives = list(self._rewrite(context, revision, directive))
+        for directive in directives:
+            traverser = self._traverse.dispatch(directive)
+            traverser(self, context, revision, directive)
+        return directives
+
+    def _traverse_list(self, context, revision, directives):
+        dest = []
+        for directive in directives:
+            dest.extend(self._traverse_for(context, revision, directive))
+
+        directives[:] = dest
+
+    def process_revision_directives(self, context, revision, directives):
+        self._traverse_list(context, revision, directives)
index 7eb06edc500c4653ea4c4eafb4eeb161636f736c..3b6252c1e256a5a4f7c2440bea518162b03b80bc 100644 (file)
@@ -690,6 +690,10 @@ class EnvironmentContext(util.ModuleClsProxy):
          ``--autogenerate`` option itself can be inferred by inspecting
          ``context.config.cmd_opts.autogenerate``.
 
+         The callable function may optionally be an instance of
+         a :class:`.Rewriter` object.  This is a helper object that
+         assists in the production of autogenerate-stream rewriter functions.
+
 
          .. versionadded:: 0.8.0
 
@@ -697,6 +701,8 @@ class EnvironmentContext(util.ModuleClsProxy):
 
              :ref:`customizing_revision`
 
+             :ref:`autogen_rewriter`
+
 
         Parameters specific to individual backends:
 
index 9445949ed14850793f886d0ce61d73e52ea46678..54e5e806ca09177219f2c918c0ecf1a1d2c8ce56 100644 (file)
@@ -194,7 +194,7 @@ def to_list(x, default=None):
     elif isinstance(x, collections.Iterable):
         return list(x)
     else:
-        raise ValueError("Don't know how to turn %r into a list" % x)
+        return [x]
 
 
 def to_tuple(x, default=None):
@@ -205,7 +205,7 @@ def to_tuple(x, default=None):
     elif isinstance(x, collections.Iterable):
         return tuple(x)
     else:
-        raise ValueError("Don't know how to turn %r into a tuple" % x)
+        return (x, )
 
 
 def unique_list(seq, hashfunc=None):
@@ -282,10 +282,9 @@ class Dispatcher(object):
     def dispatch_for(self, target, qualifier='default'):
         def decorate(fn):
             if self.uselist:
-                assert target not in self._registry
                 self._registry.setdefault((target, qualifier), []).append(fn)
             else:
-                assert target not in self._registry
+                assert (target, qualifier) not in self._registry
                 self._registry[(target, qualifier)] = fn
             return fn
         return decorate
@@ -301,9 +300,11 @@ class Dispatcher(object):
 
         for spcls in targets:
             if qualifier != 'default' and (spcls, qualifier) in self._registry:
-                return self._fn_or_list(self._registry[(spcls, qualifier)])
+                return self._fn_or_list(
+                    self._registry[(spcls, qualifier)])
             elif (spcls, 'default') in self._registry:
-                return self._fn_or_list(self._registry[(spcls, 'default')])
+                return self._fn_or_list(
+                    self._registry[(spcls, 'default')])
         else:
             raise ValueError("no dispatch function for object: %s" % obj)
 
index 7376915ef6e44e44f1b661c339fbac99cd3d2ea3..9773d396a1c5d3d5975511fe773e28678e28f551 100644 (file)
@@ -205,6 +205,81 @@ to whatever is in this list.
 
 .. autofunction:: alembic.autogenerate.render_python_code
 
+.. _autogen_rewriter:
+
+Fine-Grained Autogenerate Generation with Rewriters
+---------------------------------------------------
+
+The preceding example illustrated how we can make a simple change to the
+structure of the operation directives to produce new autogenerate output.
+For the case where we want to affect very specific parts of the autogenerate
+stream, we can make a function for
+:paramref:`.EnvironmentContext.configure.process_revision_directives`
+which traverses through the whole :class:`.MigrationScript` structure, locates
+the elements we care about and modifies them in-place as needed.  However,
+to reduce the boilerplate associated with this task, we can use the
+:class:`.Rewriter` object to make this easier.  :class:`.Rewriter` gives
+us an object that we can pass directly to
+:paramref:`.EnvironmentContext.configure.process_revision_directives` which
+we can also attach handler functions onto, keyed to specific types of
+constructs.
+
+Below is an example where we rewrite :class:`.ops.AddColumnOp` directives;
+based on whether or not the new column is "nullable", we either return
+the existing directive, or we return the existing directive with
+the nullable flag changed, inside of a list with a second directive
+to alter the nullable flag in a second step::
+
+    # ... fragmented env.py script ....
+
+    from alembic.autogenerate import rewriter
+    from alembic import ops
+
+    writer = rewriter.Rewriter()
+
+    @writer.rewrites(ops.AddColumnOp)
+    def add_column(context, revision, op):
+        if op.column.nullable:
+            return op
+        else:
+            op.column.nullable = True
+            return [
+                op,
+                ops.AlterColumnOp(
+                    op.table_name,
+                    op.column_name,
+                    modify_nullable=False,
+                    existing_type=op.column.type,
+                )
+            ]
+
+    # ... later ...
+
+    def run_migrations_online():
+        # ...
+
+        with connectable.connect() as connection:
+            context.configure(
+                connection=connection,
+                target_metadata=target_metadata,
+                process_revision_directives=writer
+            )
+
+            with context.begin_transaction():
+                context.run_migrations()
+
+Above, in a full :class:`.ops.MigrationScript` structure, the
+:class:`.AddColumn` directives would be present within
+the paths ``MigrationScript->UpgradeOps->ModifyTableOps``
+and ``MigrationScript->DowngradeOps->ModifyTableOps``.   The
+:class:`.Rewriter` handles traversing into these structures as well
+as rewriting them as needed so that we only need to code for the specific
+object we care about.
+
+
+.. autoclass:: alembic.autogenerate.rewriter.Rewriter
+    :members:
+
 .. _autogen_custom_ops:
 
 Autogenerating Custom Operation Directives
index 3ce6200ca53dcc13a936e82e2defc264693a04cd..bf0d06584149153dc1be9b48a1399da22d5e30c7 100644 (file)
@@ -1,5 +1,5 @@
 from alembic.testing.fixtures import TestBase
-from alembic.testing import eq_, ne_, assert_raises_message
+from alembic.testing import eq_, ne_, assert_raises_message, is_
 from alembic.testing.env import clear_staging_env, staging_env, \
     _get_staging_directory, _no_sql_testing_config, env_file_fixture, \
     script_file_fixture, _testing_config, _sqlite_testing_config, \
@@ -11,6 +11,7 @@ from alembic.environment import EnvironmentContext
 from alembic.testing import mock
 from alembic import util
 from alembic.operations import ops
+from alembic import autogenerate
 import os
 import datetime
 import sqlalchemy as sa
@@ -387,6 +388,150 @@ def downgrade():
             )
 
 
+class RewriterTest(TestBase):
+    def test_all_traverse(self):
+        writer = autogenerate.Rewriter()
+
+        mocker = mock.Mock(side_effect=lambda context, revision, op: op)
+        writer.rewrites(ops.MigrateOperation)(mocker)
+
+        addcolop = ops.AddColumnOp(
+            't1', sa.Column('x', sa.Integer())
+        )
+
+        directives = [
+            ops.MigrationScript(
+                util.rev_id(),
+                ops.UpgradeOps(ops=[
+                    ops.ModifyTableOps('t1', ops=[
+                        addcolop
+                    ])
+                ]),
+                ops.DowngradeOps(ops=[
+                ]),
+            )
+        ]
+
+        ctx, rev = mock.Mock(), mock.Mock()
+        writer(ctx, rev, directives)
+        eq_(
+            mocker.mock_calls,
+            [
+                mock.call(ctx, rev, directives[0]),
+                mock.call(ctx, rev, directives[0].upgrade_ops),
+                mock.call(ctx, rev, directives[0].upgrade_ops.ops[0]),
+                mock.call(ctx, rev, addcolop),
+                mock.call(ctx, rev, directives[0].downgrade_ops),
+            ]
+        )
+
+    def test_double_migrate_table(self):
+        writer = autogenerate.Rewriter()
+
+        idx_ops = []
+
+        @writer.rewrites(ops.ModifyTableOps)
+        def second_table(context, revision, op):
+            return [
+                op,
+                ops.ModifyTableOps('t2', ops=[
+                    ops.AddColumnOp('t2', sa.Column('x', sa.Integer()))
+                ])
+            ]
+
+        @writer.rewrites(ops.AddColumnOp)
+        def add_column(context, revision, op):
+            idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name])
+            idx_ops.append(idx_op)
+            return [
+                op,
+                idx_op
+            ]
+
+        directives = [
+            ops.MigrationScript(
+                util.rev_id(),
+                ops.UpgradeOps(ops=[
+                    ops.ModifyTableOps('t1', ops=[
+                        ops.AddColumnOp('t1', sa.Column('x', sa.Integer()))
+                    ])
+                ]),
+                ops.DowngradeOps(ops=[]),
+            )
+        ]
+
+        ctx, rev = mock.Mock(), mock.Mock()
+        writer(ctx, rev, directives)
+        eq_(
+            [d.table_name for d in directives[0].upgrade_ops.ops],
+            ['t1', 't2']
+        )
+        is_(
+            directives[0].upgrade_ops.ops[0].ops[1],
+            idx_ops[0]
+        )
+        is_(
+            directives[0].upgrade_ops.ops[1].ops[1],
+            idx_ops[1]
+        )
+
+    def test_chained_ops(self):
+        writer1 = autogenerate.Rewriter()
+        writer2 = autogenerate.Rewriter()
+
+        @writer1.rewrites(ops.AddColumnOp)
+        def add_column_nullable(context, revision, op):
+            if op.column.nullable:
+                return op
+            else:
+                op.column.nullable = True
+                return [
+                    op,
+                    ops.AlterColumnOp(
+                        op.table_name,
+                        op.column.name,
+                        modify_nullable=False,
+                        existing_type=op.column.type,
+                    )
+                ]
+
+        @writer2.rewrites(ops.AddColumnOp)
+        def add_column_idx(context, revision, op):
+            idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name])
+            return [
+                op,
+                idx_op
+            ]
+
+        directives = [
+            ops.MigrationScript(
+                util.rev_id(),
+                ops.UpgradeOps(ops=[
+                    ops.ModifyTableOps('t1', ops=[
+                        ops.AddColumnOp(
+                            't1', sa.Column('x', sa.Integer(), nullable=False))
+                    ])
+                ]),
+                ops.DowngradeOps(ops=[]),
+            )
+        ]
+
+        ctx, rev = mock.Mock(), mock.Mock()
+        writer1.chain(writer2)(ctx, rev, directives)
+
+        eq_(
+            autogenerate.render_python_code(directives[0].upgrade_ops),
+            "### commands auto generated by Alembic - please adjust! ###\n"
+            "    op.add_column('t1', "
+            "sa.Column('x', sa.Integer(), nullable=True))\n"
+            "    op.create_index('ixt', 't1', ['x'], unique=False)\n"
+            "    op.alter_column('t1', 'x',\n"
+            "               existing_type=sa.Integer(),\n"
+            "               nullable=False)\n"
+            "    ### end Alembic commands ###"
+        )
+
+
 class MultiDirRevisionCommandTest(TestBase):
     def setUp(self):
         self.env = staging_env()