]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
expose on_version_apply callback to context users
authorJohn Passaro <john.a.passaro@gmail.com>
Tue, 27 Jun 2017 18:36:11 +0000 (14:36 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Jul 2017 18:56:39 +0000 (14:56 -0400)
Change-Id: I694e26f7d161dcaf4f035277c8317ff6ffe41680
Pull-request: https://bitbucket.org/zzzeek/alembic/pull-requests/67

alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/script/base.py
docs/build/api/runtime.rst
docs/build/changelog.rst
tests/test_script_consumption.py

index edabc06aa033f2634ceba2b9de4b64f3d081bee3..613b7459841c348a2a8b562095cd32894ada3c2f 100644 (file)
@@ -307,6 +307,7 @@ class EnvironmentContext(util.ModuleClsProxy):
                   alembic_module_prefix="op.",
                   sqlalchemy_module_prefix="sa.",
                   user_module_prefix=None,
+                  on_version_apply=None,
                   **kw
                   ):
         """Configure a :class:`.MigrationContext` within this
@@ -416,6 +417,24 @@ class EnvironmentContext(util.ModuleClsProxy):
             flag and additionally established that the Alembic version table
             has a primary key constraint by default.
 
+        :param on_version_apply: a callable or collection of callables to be
+            run for each migration step.
+            The callables will be run in the order they are given, once for
+            each migration step, after the respective operation has been
+            applied but before its transaction is finalized.
+            Each callable accepts no positional arguments and the following
+            keyword arguments:
+
+            * ``ctx``: the :class:`.MigrationContext` running the migration,
+            * ``step``: a :class:`.MigrationInfo` representing the
+              step currently being applied,
+            * ``heads``: a collection of version strings representing the
+              current heads,
+            * ``run_args``: the ``**kwargs`` passed to :meth:`.run_migrations`.
+
+            .. versionadded:: 0.9.3
+
+
         Parameters specific to the autogenerate feature, when
         ``alembic revision`` is run with the ``--autogenerate`` feature:
 
@@ -732,7 +751,6 @@ class EnvironmentContext(util.ModuleClsProxy):
 
              :paramref:`.command.revision.process_revision_directives`
 
-
         Parameters specific to individual backends:
 
         :param mssql_batch_separator: The "batch separator" which will
@@ -774,6 +792,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         opts['user_module_prefix'] = user_module_prefix
         opts['literal_binds'] = literal_binds
         opts['process_revision_directives'] = process_revision_directives
+        opts['on_version_apply'] = util.to_tuple(on_version_apply, default=())
 
         if render_item is not None:
             opts['render_item'] = render_item
index 13e6ebe4a9dfe530990861477c3b750da8b48dcd..5b95208c6822f2388d8296a785ac02dcd71d02ac 100644 (file)
@@ -70,6 +70,7 @@ class MigrationContext(object):
         transactional_ddl = opts.get("transactional_ddl")
         self._transaction_per_migration = opts.get(
             "transaction_per_migration", False)
+        self.on_version_apply_callbacks = opts.get('on_version_apply', ())
 
         if as_sql:
             self.connection = self._stdout_connection(connection)
@@ -334,6 +335,11 @@ class MigrationContext(object):
                 # and row-targeted updates and deletes, it's simpler for now
                 # just to run the operations on every version
                 head_maintainer.update_to_step(step)
+                for callback in self.on_version_apply_callbacks:
+                    callback(ctx=self,
+                             step=step.info,
+                             heads=set(head_maintainer.heads),
+                             run_args=kw)
 
             if not starting_in_transaction and not self.as_sql and \
                 not self.impl.transactional_ddl and \
@@ -534,6 +540,95 @@ class HeadMaintainer(object):
             self._update_version(from_, to_)
 
 
+class MigrationInfo(object):
+    """Exposes information about a migration step to a callback listener.
+
+    The :class:`.MigrationInfo` object is available exclusively for the
+    benefit of the :paramref:`.EnvironmentContext.on_version_apply`
+    callback hook.
+
+    .. versionadded:: 0.9.3
+
+    """
+
+    is_upgrade = None
+    """True/False: indicates whether this operation ascends or descends the
+    version tree."""
+
+    is_stamp = None
+    """True/False: indicates whether this operation is a stamp (i.e. whether
+    it results in any actual database operations)."""
+
+    up_revision_id = None
+    """Version string corresponding to :attr:`.Revision.revision`."""
+
+    down_revision_ids = None
+    """Tuple of strings representing the base revisions of this migration step.
+
+    If empty, this represents a root revision; otherwise, the first item
+    corresponds to :attr:`.Revision.down_revision`, and the rest are inferred
+    from dependencies.
+    """
+
+    revision_map = None
+    """The revision map inside of which this operation occurs."""
+
+    def __init__(self, revision_map, is_upgrade, is_stamp, up_revision,
+                 down_revisions):
+        self.revision_map = revision_map
+        self.is_upgrade = is_upgrade
+        self.is_stamp = is_stamp
+        self.up_revision_id = up_revision
+        self.down_revision_ids = util.to_tuple(down_revisions)
+
+    @property
+    def is_migration(self):
+        """True/False: indicates whether this operation is a migration.
+
+        At present this is true if and only the migration is not a stamp.
+        If other operation types are added in the future, both this attribute
+        and :attr:`~.MigrationInfo.is_stamp` will be false.
+        """
+        return not self.is_stamp
+
+    @property
+    def source_revision_ids(self):
+        """Active revisions before this migration step is applied."""
+        revs = self.down_revision_ids if self.is_upgrade \
+            else self.up_revision_id
+        return util.to_tuple(revs, default=())
+
+    @property
+    def destination_revision_ids(self):
+        """Active revisions after this migration step is applied."""
+        revs = self.up_revision_id if self.is_upgrade \
+            else self.down_revision_ids
+        return util.to_tuple(revs, default=())
+
+    @property
+    def up_revision(self):
+        """Get :attr:`~MigrationInfo.up_revision_id` as a :class:`.Revision`."""
+        return self.revision_map.get_revision(self.up_revision_id)
+
+    @property
+    def down_revisions(self):
+        """Get :attr:`~MigrationInfo.down_revision_ids` as a tuple of
+        :class:`Revisions <.Revision>`."""
+        return self.revision_map.get_revisions(self.down_revision_ids)
+
+    @property
+    def source_revisions(self):
+        """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
+        :class:`Revisions <.Revision>`."""
+        return self.revision_map.get_revisions(self.source_revision_ids)
+
+    @property
+    def destination_revisions(self):
+        """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
+        :class:`Revisions <.Revision>`."""
+        return self.revision_map.get_revisions(self.destination_revision_ids)
+
+
 class MigrationStep(object):
     @property
     def name(self):
@@ -759,14 +854,22 @@ class RevisionStep(MigrationStep):
     def insert_version_num(self):
         return self.revision.revision
 
+    @property
+    def info(self):
+        return MigrationInfo(revision_map=self.revision_map,
+                             up_revision=self.revision.revision,
+                             down_revisions=self.revision._all_down_revisions,
+                             is_upgrade=self.is_upgrade, is_stamp=False)
+
 
 class StampStep(MigrationStep):
-    def __init__(self, from_, to_, is_upgrade, branch_move):
+    def __init__(self, from_, to_, is_upgrade, branch_move, revision_map=None):
         self.from_ = util.to_tuple(from_, default=())
         self.to_ = util.to_tuple(to_, default=())
         self.is_upgrade = is_upgrade
         self.branch_move = branch_move
         self.migration_fn = self.stamp_revision
+        self.revision_map = revision_map
 
     doc = None
 
@@ -836,3 +939,10 @@ class StampStep(MigrationStep):
 
     def should_unmerge_branches(self, heads):
         return len(self.to_) > 1
+
+    @property
+    def info(self):
+        up, down = (self.to_, self.from_) if self.is_upgrade \
+            else (self.from_, self.to_)
+        return MigrationInfo(self.revision_map, up, down, self.is_upgrade,
+                             True)
index 17cb3de3fbd9a96f15fd509a59d012bb510088f4..6448685102ce6b528863c5499d0d8bd1079b1cc6 100644 (file)
@@ -370,7 +370,8 @@ class ScriptDirectory(object):
                     # dest is 'base'.  Return a "delete branch" migration
                     # for all applicable heads.
                     steps.extend([
-                        migration.StampStep(head.revision, None, False, True)
+                        migration.StampStep(head.revision, None, False, True,
+                                            self.revision_map)
                         for head in filtered_heads
                     ])
                     continue
@@ -390,7 +391,8 @@ class ScriptDirectory(object):
                     assert not ancestors.intersection(filtered_heads)
                     todo_heads = [head.revision for head in filtered_heads]
                     step = migration.StampStep(
-                        todo_heads, dest.revision, False, False)
+                        todo_heads, dest.revision, False, False,
+                        self.revision_map)
                     steps.append(step)
                     continue
                 elif ancestors.intersection(filtered_heads):
@@ -398,13 +400,15 @@ class ScriptDirectory(object):
                     # we can treat them as a "merge", single step.
                     todo_heads = [head.revision for head in filtered_heads]
                     step = migration.StampStep(
-                        todo_heads, dest.revision, True, False)
+                        todo_heads, dest.revision, True, False,
+                        self.revision_map)
                     steps.append(step)
                     continue
                 else:
                     # destination is in a branch not represented,
                     # treat it as new branch
-                    step = migration.StampStep((), dest.revision, True, True)
+                    step = migration.StampStep((), dest.revision, True, True,
+                                               self.revision_map)
                     steps.append(step)
                     continue
             return steps
index f32e943a4e81faeb583a51f0abda8a0dad7ae34a..cf707cd841c4576a2fc168070f6350170dcee126 100644 (file)
@@ -33,7 +33,8 @@ The Migration Context
 
 The :class:`.MigrationContext` handles the actual work to be performed
 against a database backend as migration operations proceed.  It is generally
-not exposed to the end-user.
+not exposed to the end-user, except when the
+:paramref:`~.EnvironmentContext.configure.on_version_apply` callback hook is used.
 
 .. automodule:: alembic.runtime.migration
     :members: MigrationContext
index 4b7a57d742331fe5ee1e2760d32eaf6824868577..c85662c55716dafe01bfaaf387eeccbafb2ac750 100644 (file)
@@ -7,6 +7,14 @@ Changelog
     :version: 0.9.3
     :released:
 
+    .. change::
+      :tags: feature, runtime
+
+      Added a new callback hook :paramref:`.EnvironmentContext.on_version_apply`,
+      which allows user-defined code to be invoked each time an individual
+      upgrade, downgrade, or stamp operation proceeds against a database.
+      Pull request courtesy John Passaro.
+
     .. change:: 433
       :tags: bug, autogenerate
       :tickets: 433
index b313273af78384b71526f562147e4d6c290ba11b..5ffa24a81ae438c7c412ccaf73a32d8272c012b2 100644 (file)
@@ -2,6 +2,7 @@
 
 import os
 import re
+import textwrap
 
 from alembic import command, util
 from alembic.util import compat
@@ -133,6 +134,87 @@ class SourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
     sourceless = True
 
 
+class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
+    exp_kwargs = frozenset(('ctx', 'heads', 'run_args', 'step'))
+
+    @staticmethod
+    def _env_file_fixture():
+        env_file_fixture(textwrap.dedent("""\
+            import alembic
+            from alembic import context
+            from sqlalchemy import engine_from_config, pool
+
+            config = context.config
+
+            target_metadata = None
+
+            def run_migrations_offline():
+                url = config.get_main_option('sqlalchemy.url')
+                context.configure(
+                    url=url, target_metadata=target_metadata,
+                    on_version_apply=alembic.mock_event_listener,
+                    literal_binds=True)
+
+                with context.begin_transaction():
+                    context.run_migrations()
+
+            def run_migrations_online():
+                connectable = engine_from_config(
+                    config.get_section(config.config_ini_section),
+                    prefix='sqlalchemy.',
+                    poolclass=pool.NullPool)
+                with connectable.connect() as connection:
+                    context.configure(
+                        connection=connection,
+                        on_version_apply=alembic.mock_event_listener,
+                        target_metadata=target_metadata,
+                    )
+                    with context.begin_transaction():
+                        context.run_migrations()
+
+            if context.is_offline_mode():
+                run_migrations_offline()
+            else:
+                run_migrations_online()
+            """))
+
+    def test_steps(self):
+        import alembic
+        alembic.mock_event_listener = None
+        self._env_file_fixture()
+        with mock.patch('alembic.mock_event_listener', mock.Mock()) as mymock:
+            super(CallbackEnvironmentTest, self).test_steps()
+        calls = mymock.call_args_list
+        assert calls
+        for call in calls:
+            args, kw = call
+            assert not args
+            assert set(kw.keys()) >= self.exp_kwargs
+            assert kw['run_args'] == {}
+            assert hasattr(kw['ctx'], 'get_current_revision')
+
+            step = kw['step']
+            assert isinstance(getattr(step, 'is_upgrade', None), bool)
+            assert isinstance(getattr(step, 'is_stamp', None), bool)
+            assert isinstance(getattr(step, 'is_migration', None), bool)
+            assert isinstance(getattr(step, 'up_revision_id', None),
+                              compat.string_types)
+            assert isinstance(getattr(step, 'up_revision', None), Script)
+            for revtype in 'down', 'source', 'destination':
+                revs = getattr(step, '%s_revisions' % revtype)
+                assert isinstance(revs, tuple)
+                for rev in revs:
+                    assert isinstance(rev, Script)
+                revids = getattr(step, '%s_revision_ids' % revtype)
+                for revid in revids:
+                    assert isinstance(revid, compat.string_types)
+
+            heads = kw['heads']
+            assert hasattr(heads, '__iter__')
+            for h in heads:
+                assert h is None or isinstance(h, compat.string_types)
+
+
 class OfflineTransactionalDDLTest(TestBase):
     def setUp(self):
         self.env = staging_env()