From: John Passaro Date: Tue, 27 Jun 2017 18:36:11 +0000 (-0400) Subject: expose on_version_apply callback to context users X-Git-Tag: rel_0_9_3~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6a043d10d87c69ffc7449148677e17a4b94a74db;p=thirdparty%2Fsqlalchemy%2Falembic.git expose on_version_apply callback to context users Change-Id: I694e26f7d161dcaf4f035277c8317ff6ffe41680 Pull-request: https://bitbucket.org/zzzeek/alembic/pull-requests/67 --- diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index edabc06a..613b7459 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -307,6 +307,7 @@ class EnvironmentContext(util.ModuleClsProxy): alembic_module_prefix="op.", sqlalchemy_module_prefix="sa.", user_module_prefix=None, + on_version_apply=None, **kw ): """Configure a :class:`.MigrationContext` within this @@ -416,6 +417,24 @@ class EnvironmentContext(util.ModuleClsProxy): flag and additionally established that the Alembic version table has a primary key constraint by default. + :param on_version_apply: a callable or collection of callables to be + run for each migration step. + The callables will be run in the order they are given, once for + each migration step, after the respective operation has been + applied but before its transaction is finalized. + Each callable accepts no positional arguments and the following + keyword arguments: + + * ``ctx``: the :class:`.MigrationContext` running the migration, + * ``step``: a :class:`.MigrationInfo` representing the + step currently being applied, + * ``heads``: a collection of version strings representing the + current heads, + * ``run_args``: the ``**kwargs`` passed to :meth:`.run_migrations`. + + .. versionadded:: 0.9.3 + + Parameters specific to the autogenerate feature, when ``alembic revision`` is run with the ``--autogenerate`` feature: @@ -732,7 +751,6 @@ class EnvironmentContext(util.ModuleClsProxy): :paramref:`.command.revision.process_revision_directives` - Parameters specific to individual backends: :param mssql_batch_separator: The "batch separator" which will @@ -774,6 +792,7 @@ class EnvironmentContext(util.ModuleClsProxy): opts['user_module_prefix'] = user_module_prefix opts['literal_binds'] = literal_binds opts['process_revision_directives'] = process_revision_directives + opts['on_version_apply'] = util.to_tuple(on_version_apply, default=()) if render_item is not None: opts['render_item'] = render_item diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 13e6ebe4..5b95208c 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -70,6 +70,7 @@ class MigrationContext(object): transactional_ddl = opts.get("transactional_ddl") self._transaction_per_migration = opts.get( "transaction_per_migration", False) + self.on_version_apply_callbacks = opts.get('on_version_apply', ()) if as_sql: self.connection = self._stdout_connection(connection) @@ -334,6 +335,11 @@ class MigrationContext(object): # and row-targeted updates and deletes, it's simpler for now # just to run the operations on every version head_maintainer.update_to_step(step) + for callback in self.on_version_apply_callbacks: + callback(ctx=self, + step=step.info, + heads=set(head_maintainer.heads), + run_args=kw) if not starting_in_transaction and not self.as_sql and \ not self.impl.transactional_ddl and \ @@ -534,6 +540,95 @@ class HeadMaintainer(object): self._update_version(from_, to_) +class MigrationInfo(object): + """Exposes information about a migration step to a callback listener. + + The :class:`.MigrationInfo` object is available exclusively for the + benefit of the :paramref:`.EnvironmentContext.on_version_apply` + callback hook. + + .. versionadded:: 0.9.3 + + """ + + is_upgrade = None + """True/False: indicates whether this operation ascends or descends the + version tree.""" + + is_stamp = None + """True/False: indicates whether this operation is a stamp (i.e. whether + it results in any actual database operations).""" + + up_revision_id = None + """Version string corresponding to :attr:`.Revision.revision`.""" + + down_revision_ids = None + """Tuple of strings representing the base revisions of this migration step. + + If empty, this represents a root revision; otherwise, the first item + corresponds to :attr:`.Revision.down_revision`, and the rest are inferred + from dependencies. + """ + + revision_map = None + """The revision map inside of which this operation occurs.""" + + def __init__(self, revision_map, is_upgrade, is_stamp, up_revision, + down_revisions): + self.revision_map = revision_map + self.is_upgrade = is_upgrade + self.is_stamp = is_stamp + self.up_revision_id = up_revision + self.down_revision_ids = util.to_tuple(down_revisions) + + @property + def is_migration(self): + """True/False: indicates whether this operation is a migration. + + At present this is true if and only the migration is not a stamp. + If other operation types are added in the future, both this attribute + and :attr:`~.MigrationInfo.is_stamp` will be false. + """ + return not self.is_stamp + + @property + def source_revision_ids(self): + """Active revisions before this migration step is applied.""" + revs = self.down_revision_ids if self.is_upgrade \ + else self.up_revision_id + return util.to_tuple(revs, default=()) + + @property + def destination_revision_ids(self): + """Active revisions after this migration step is applied.""" + revs = self.up_revision_id if self.is_upgrade \ + else self.down_revision_ids + return util.to_tuple(revs, default=()) + + @property + def up_revision(self): + """Get :attr:`~MigrationInfo.up_revision_id` as a :class:`.Revision`.""" + return self.revision_map.get_revision(self.up_revision_id) + + @property + def down_revisions(self): + """Get :attr:`~MigrationInfo.down_revision_ids` as a tuple of + :class:`Revisions <.Revision>`.""" + return self.revision_map.get_revisions(self.down_revision_ids) + + @property + def source_revisions(self): + """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of + :class:`Revisions <.Revision>`.""" + return self.revision_map.get_revisions(self.source_revision_ids) + + @property + def destination_revisions(self): + """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of + :class:`Revisions <.Revision>`.""" + return self.revision_map.get_revisions(self.destination_revision_ids) + + class MigrationStep(object): @property def name(self): @@ -759,14 +854,22 @@ class RevisionStep(MigrationStep): def insert_version_num(self): return self.revision.revision + @property + def info(self): + return MigrationInfo(revision_map=self.revision_map, + up_revision=self.revision.revision, + down_revisions=self.revision._all_down_revisions, + is_upgrade=self.is_upgrade, is_stamp=False) + class StampStep(MigrationStep): - def __init__(self, from_, to_, is_upgrade, branch_move): + def __init__(self, from_, to_, is_upgrade, branch_move, revision_map=None): self.from_ = util.to_tuple(from_, default=()) self.to_ = util.to_tuple(to_, default=()) self.is_upgrade = is_upgrade self.branch_move = branch_move self.migration_fn = self.stamp_revision + self.revision_map = revision_map doc = None @@ -836,3 +939,10 @@ class StampStep(MigrationStep): def should_unmerge_branches(self, heads): return len(self.to_) > 1 + + @property + def info(self): + up, down = (self.to_, self.from_) if self.is_upgrade \ + else (self.from_, self.to_) + return MigrationInfo(self.revision_map, up, down, self.is_upgrade, + True) diff --git a/alembic/script/base.py b/alembic/script/base.py index 17cb3de3..64486851 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -370,7 +370,8 @@ class ScriptDirectory(object): # dest is 'base'. Return a "delete branch" migration # for all applicable heads. steps.extend([ - migration.StampStep(head.revision, None, False, True) + migration.StampStep(head.revision, None, False, True, + self.revision_map) for head in filtered_heads ]) continue @@ -390,7 +391,8 @@ class ScriptDirectory(object): assert not ancestors.intersection(filtered_heads) todo_heads = [head.revision for head in filtered_heads] step = migration.StampStep( - todo_heads, dest.revision, False, False) + todo_heads, dest.revision, False, False, + self.revision_map) steps.append(step) continue elif ancestors.intersection(filtered_heads): @@ -398,13 +400,15 @@ class ScriptDirectory(object): # we can treat them as a "merge", single step. todo_heads = [head.revision for head in filtered_heads] step = migration.StampStep( - todo_heads, dest.revision, True, False) + todo_heads, dest.revision, True, False, + self.revision_map) steps.append(step) continue else: # destination is in a branch not represented, # treat it as new branch - step = migration.StampStep((), dest.revision, True, True) + step = migration.StampStep((), dest.revision, True, True, + self.revision_map) steps.append(step) continue return steps diff --git a/docs/build/api/runtime.rst b/docs/build/api/runtime.rst index f32e943a..cf707cd8 100644 --- a/docs/build/api/runtime.rst +++ b/docs/build/api/runtime.rst @@ -33,7 +33,8 @@ The Migration Context The :class:`.MigrationContext` handles the actual work to be performed against a database backend as migration operations proceed. It is generally -not exposed to the end-user. +not exposed to the end-user, except when the +:paramref:`~.EnvironmentContext.configure.on_version_apply` callback hook is used. .. automodule:: alembic.runtime.migration :members: MigrationContext diff --git a/docs/build/changelog.rst b/docs/build/changelog.rst index 4b7a57d7..c85662c5 100644 --- a/docs/build/changelog.rst +++ b/docs/build/changelog.rst @@ -7,6 +7,14 @@ Changelog :version: 0.9.3 :released: + .. change:: + :tags: feature, runtime + + Added a new callback hook :paramref:`.EnvironmentContext.on_version_apply`, + which allows user-defined code to be invoked each time an individual + upgrade, downgrade, or stamp operation proceeds against a database. + Pull request courtesy John Passaro. + .. change:: 433 :tags: bug, autogenerate :tickets: 433 diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py index b313273a..5ffa24a8 100644 --- a/tests/test_script_consumption.py +++ b/tests/test_script_consumption.py @@ -2,6 +2,7 @@ import os import re +import textwrap from alembic import command, util from alembic.util import compat @@ -133,6 +134,87 @@ class SourcelessApplyVersionsTest(ApplyVersionsFunctionalTest): sourceless = True +class CallbackEnvironmentTest(ApplyVersionsFunctionalTest): + exp_kwargs = frozenset(('ctx', 'heads', 'run_args', 'step')) + + @staticmethod + def _env_file_fixture(): + env_file_fixture(textwrap.dedent("""\ + import alembic + from alembic import context + from sqlalchemy import engine_from_config, pool + + config = context.config + + target_metadata = None + + def run_migrations_offline(): + url = config.get_main_option('sqlalchemy.url') + context.configure( + url=url, target_metadata=target_metadata, + on_version_apply=alembic.mock_event_listener, + literal_binds=True) + + with context.begin_transaction(): + context.run_migrations() + + def run_migrations_online(): + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool) + with connectable.connect() as connection: + context.configure( + connection=connection, + on_version_apply=alembic.mock_event_listener, + target_metadata=target_metadata, + ) + with context.begin_transaction(): + context.run_migrations() + + if context.is_offline_mode(): + run_migrations_offline() + else: + run_migrations_online() + """)) + + def test_steps(self): + import alembic + alembic.mock_event_listener = None + self._env_file_fixture() + with mock.patch('alembic.mock_event_listener', mock.Mock()) as mymock: + super(CallbackEnvironmentTest, self).test_steps() + calls = mymock.call_args_list + assert calls + for call in calls: + args, kw = call + assert not args + assert set(kw.keys()) >= self.exp_kwargs + assert kw['run_args'] == {} + assert hasattr(kw['ctx'], 'get_current_revision') + + step = kw['step'] + assert isinstance(getattr(step, 'is_upgrade', None), bool) + assert isinstance(getattr(step, 'is_stamp', None), bool) + assert isinstance(getattr(step, 'is_migration', None), bool) + assert isinstance(getattr(step, 'up_revision_id', None), + compat.string_types) + assert isinstance(getattr(step, 'up_revision', None), Script) + for revtype in 'down', 'source', 'destination': + revs = getattr(step, '%s_revisions' % revtype) + assert isinstance(revs, tuple) + for rev in revs: + assert isinstance(rev, Script) + revids = getattr(step, '%s_revision_ids' % revtype) + for revid in revids: + assert isinstance(revid, compat.string_types) + + heads = kw['heads'] + assert hasattr(heads, '__iter__') + for h in heads: + assert h is None or isinstance(h, compat.string_types) + + class OfflineTransactionalDDLTest(TestBase): def setUp(self): self.env = staging_env()