From 4cb18979a4e1ceef2ec592a00cc609b53ca37ef9 Mon Sep 17 00:00:00 2001 From: Roman Podoliaka Date: Tue, 28 Jan 2014 09:08:58 +0200 Subject: [PATCH] Handle include_* arguments in compare_metadata() include_object, include_symbol and include_schemas are very useful, when you need to specify what objects you want to be compared. Modify compare_metadata() public API function, so that it takes those arguments into account and pass them to _produce_net_changes(). --- alembic/autogenerate/api.py | 41 ++++++++----- tests/test_autogenerate.py | 117 ++++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 14 deletions(-) diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 34da7c90..910d6ecc 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -101,7 +101,13 @@ def compare_metadata(context, metadata): """ autogen_context, connection = _autogen_context(context, None) diffs = [] - _produce_net_changes(connection, metadata, diffs, autogen_context) + + object_filters = _get_object_filters(context.opts) + include_schemas = context.opts.get('include_schemas', False) + + _produce_net_changes(connection, metadata, diffs, autogen_context, + object_filters, include_schemas) + return diffs ################################################### @@ -113,21 +119,9 @@ def _produce_migration_diffs(context, template_args, include_schemas=False): opts = context.opts metadata = opts['target_metadata'] - include_object = opts.get('include_object', include_object) - include_symbol = opts.get('include_symbol', include_symbol) include_schemas = opts.get('include_schemas', include_schemas) - object_filters = [] - if include_symbol: - def include_symbol_filter(object, name, type_, reflected, compare_to): - if type_ == "table": - return include_symbol(name, object.schema) - else: - return True - object_filters.append(include_symbol_filter) - if include_object: - object_filters.append(include_object) - + object_filters = _get_object_filters(opts, include_symbol, include_object) if metadata is None: raise util.CommandError( @@ -147,6 +141,25 @@ def _produce_migration_diffs(context, template_args, _indent(_produce_downgrade_commands(diffs, autogen_context)) template_args['imports'] = "\n".join(sorted(imports)) + +def _get_object_filters(context_opts, include_symbol=None, include_object=None): + include_symbol = context_opts.get('include_symbol', include_symbol) + include_object = context_opts.get('include_object', include_object) + + object_filters = [] + if include_symbol: + def include_symbol_filter(object, name, type_, reflected, compare_to): + if type_ == "table": + return include_symbol(name, object.schema) + else: + return True + object_filters.append(include_symbol_filter) + if include_object: + object_filters.append(include_object) + + return object_filters + + def _autogen_context(context, imports): opts = context.opts connection = context.bind diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py index 79370041..eb04d37b 100644 --- a/tests/test_autogenerate.py +++ b/tests/test_autogenerate.py @@ -299,6 +299,123 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): eq_(diffs[7][0][5], True) eq_(diffs[7][0][6], False) + def test_compare_metadata(self): + metadata = self.m2 + connection = self.context.bind + + diffs = autogenerate.compare_metadata(self.context, metadata) + + eq_( + diffs[0], + ('add_table', metadata.tables['item']) + ) + + eq_(diffs[1][0], 'remove_table') + eq_(diffs[1][1].name, "extra") + + eq_(diffs[2][0], "add_column") + eq_(diffs[2][1], None) + eq_(diffs[2][2], "address") + eq_(diffs[2][3], metadata.tables['address'].c.street) + + eq_(diffs[3][0], "add_column") + eq_(diffs[3][1], None) + eq_(diffs[3][2], "order") + eq_(diffs[3][3], metadata.tables['order'].c.user_id) + + eq_(diffs[4][0][0], "modify_type") + eq_(diffs[4][0][1], None) + eq_(diffs[4][0][2], "order") + eq_(diffs[4][0][3], "amount") + eq_(repr(diffs[4][0][5]), "NUMERIC(precision=8, scale=2)") + eq_(repr(diffs[4][0][6]), "Numeric(precision=10, scale=2)") + + eq_(diffs[5][0], 'remove_column') + eq_(diffs[5][3].name, 'pw') + + eq_(diffs[6][0][0], "modify_default") + eq_(diffs[6][0][1], None) + eq_(diffs[6][0][2], "user") + eq_(diffs[6][0][3], "a1") + eq_(diffs[6][0][6].arg, "x") + + eq_(diffs[7][0][0], 'modify_nullable') + eq_(diffs[7][0][5], True) + eq_(diffs[7][0][6], False) + + def test_compare_metadata_include_object(self): + metadata = self.m2 + + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "table": + return name in ("extra", "order") + elif type_ == "column": + return name != "amount" + else: + return True + + context = MigrationContext.configure( + connection=self.bind.connect(), + opts={ + 'compare_type': True, + 'compare_server_default': True, + 'target_metadata': self.m1, + 'upgrade_token': "upgrades", + 'downgrade_token': "downgrades", + 'include_object': include_object, + } + ) + + diffs = autogenerate.compare_metadata(context, metadata) + + eq_(diffs[0][0], 'remove_table') + eq_(diffs[0][1].name, "extra") + + eq_(diffs[1][0], "add_column") + eq_(diffs[1][1], None) + eq_(diffs[1][2], "order") + eq_(diffs[1][3], metadata.tables['order'].c.user_id) + + def test_compare_metadata_include_symbol(self): + metadata = self.m2 + + def include_symbol(table_name, schema_name): + return table_name in ('extra', 'order') + + context = MigrationContext.configure( + connection=self.bind.connect(), + opts={ + 'compare_type': True, + 'compare_server_default': True, + 'target_metadata': self.m1, + 'upgrade_token': "upgrades", + 'downgrade_token': "downgrades", + 'include_symbol': include_symbol, + } + ) + + diffs = autogenerate.compare_metadata(context, metadata) + + eq_(diffs[0][0], 'remove_table') + eq_(diffs[0][1].name, "extra") + + eq_(diffs[1][0], "add_column") + eq_(diffs[1][1], None) + eq_(diffs[1][2], "order") + eq_(diffs[1][3], metadata.tables['order'].c.user_id) + + eq_(diffs[2][0][0], "modify_type") + eq_(diffs[2][0][1], None) + eq_(diffs[2][0][2], "order") + eq_(diffs[2][0][3], "amount") + eq_(repr(diffs[2][0][5]), "NUMERIC(precision=8, scale=2)") + eq_(repr(diffs[2][0][6]), "Numeric(precision=10, scale=2)") + + eq_(diffs[2][1][0], 'modify_nullable') + eq_(diffs[2][1][2], 'order') + eq_(diffs[2][1][5], False) + eq_(diffs[2][1][6], True) + def test_render_nothing(self): context = MigrationContext.configure( connection=self.bind.connect(), -- 2.47.2