]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Handle include_* arguments in compare_metadata()
authorRoman Podoliaka <roman.podoliaka@gmail.com>
Tue, 28 Jan 2014 07:08:58 +0000 (09:08 +0200)
committerRoman Podoliaka <roman.podoliaka@gmail.com>
Tue, 28 Jan 2014 16:46:04 +0000 (18:46 +0200)
include_object, include_symbol and include_schemas are very useful,
when you need to specify what objects you want to be compared.
Modify compare_metadata() public API function, so that it takes those
arguments into account and pass them to _produce_net_changes().

alembic/autogenerate/api.py
tests/test_autogenerate.py

index 34da7c9042056f74cd7f06830ee1f449e572f2f3..910d6eccf7a1526cbf0d5d837a37e35ddb96d443 100644 (file)
@@ -101,7 +101,13 @@ def compare_metadata(context, metadata):
     """
     autogen_context, connection = _autogen_context(context, None)
     diffs = []
-    _produce_net_changes(connection, metadata, diffs, autogen_context)
+
+    object_filters = _get_object_filters(context.opts)
+    include_schemas = context.opts.get('include_schemas', False)
+
+    _produce_net_changes(connection, metadata, diffs, autogen_context,
+                         object_filters, include_schemas)
+
     return diffs
 
 ###################################################
@@ -113,21 +119,9 @@ def _produce_migration_diffs(context, template_args,
                                 include_schemas=False):
     opts = context.opts
     metadata = opts['target_metadata']
-    include_object = opts.get('include_object', include_object)
-    include_symbol = opts.get('include_symbol', include_symbol)
     include_schemas = opts.get('include_schemas', include_schemas)
 
-    object_filters = []
-    if include_symbol:
-        def include_symbol_filter(object, name, type_, reflected, compare_to):
-            if type_ == "table":
-                return include_symbol(name, object.schema)
-            else:
-                return True
-        object_filters.append(include_symbol_filter)
-    if include_object:
-        object_filters.append(include_object)
-
+    object_filters = _get_object_filters(opts, include_symbol, include_object)
 
     if metadata is None:
         raise util.CommandError(
@@ -147,6 +141,25 @@ def _produce_migration_diffs(context, template_args,
             _indent(_produce_downgrade_commands(diffs, autogen_context))
     template_args['imports'] = "\n".join(sorted(imports))
 
+
+def _get_object_filters(context_opts, include_symbol=None, include_object=None):
+    include_symbol = context_opts.get('include_symbol', include_symbol)
+    include_object = context_opts.get('include_object', include_object)
+
+    object_filters = []
+    if include_symbol:
+        def include_symbol_filter(object, name, type_, reflected, compare_to):
+            if type_ == "table":
+                return include_symbol(name, object.schema)
+            else:
+                return True
+        object_filters.append(include_symbol_filter)
+    if include_object:
+        object_filters.append(include_object)
+
+    return object_filters
+
+
 def _autogen_context(context, imports):
     opts = context.opts
     connection = context.bind
index 793700415b3d148ca4a09f50b21860dd8f2d6062..eb04d37b1e1ee2728fe6fdd8e7f56809684e6260 100644 (file)
@@ -299,6 +299,123 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
         eq_(diffs[7][0][5], True)
         eq_(diffs[7][0][6], False)
 
+    def test_compare_metadata(self):
+        metadata = self.m2
+        connection = self.context.bind
+
+        diffs = autogenerate.compare_metadata(self.context, metadata)
+
+        eq_(
+            diffs[0],
+            ('add_table', metadata.tables['item'])
+        )
+
+        eq_(diffs[1][0], 'remove_table')
+        eq_(diffs[1][1].name, "extra")
+
+        eq_(diffs[2][0], "add_column")
+        eq_(diffs[2][1], None)
+        eq_(diffs[2][2], "address")
+        eq_(diffs[2][3], metadata.tables['address'].c.street)
+
+        eq_(diffs[3][0], "add_column")
+        eq_(diffs[3][1], None)
+        eq_(diffs[3][2], "order")
+        eq_(diffs[3][3], metadata.tables['order'].c.user_id)
+
+        eq_(diffs[4][0][0], "modify_type")
+        eq_(diffs[4][0][1], None)
+        eq_(diffs[4][0][2], "order")
+        eq_(diffs[4][0][3], "amount")
+        eq_(repr(diffs[4][0][5]), "NUMERIC(precision=8, scale=2)")
+        eq_(repr(diffs[4][0][6]), "Numeric(precision=10, scale=2)")
+
+        eq_(diffs[5][0], 'remove_column')
+        eq_(diffs[5][3].name, 'pw')
+
+        eq_(diffs[6][0][0], "modify_default")
+        eq_(diffs[6][0][1], None)
+        eq_(diffs[6][0][2], "user")
+        eq_(diffs[6][0][3], "a1")
+        eq_(diffs[6][0][6].arg, "x")
+
+        eq_(diffs[7][0][0], 'modify_nullable')
+        eq_(diffs[7][0][5], True)
+        eq_(diffs[7][0][6], False)
+
+    def test_compare_metadata_include_object(self):
+        metadata = self.m2
+
+        def include_object(obj, name, type_, reflected, compare_to):
+            if type_ == "table":
+                return name in ("extra", "order")
+            elif type_ == "column":
+                return name != "amount"
+            else:
+                return True
+
+        context = MigrationContext.configure(
+            connection=self.bind.connect(),
+            opts={
+                'compare_type': True,
+                'compare_server_default': True,
+                'target_metadata': self.m1,
+                'upgrade_token': "upgrades",
+                'downgrade_token': "downgrades",
+                'include_object': include_object,
+            }
+        )
+
+        diffs = autogenerate.compare_metadata(context, metadata)
+
+        eq_(diffs[0][0], 'remove_table')
+        eq_(diffs[0][1].name, "extra")
+
+        eq_(diffs[1][0], "add_column")
+        eq_(diffs[1][1], None)
+        eq_(diffs[1][2], "order")
+        eq_(diffs[1][3], metadata.tables['order'].c.user_id)
+
+    def test_compare_metadata_include_symbol(self):
+        metadata = self.m2
+
+        def include_symbol(table_name, schema_name):
+            return table_name in ('extra', 'order')
+
+        context = MigrationContext.configure(
+            connection=self.bind.connect(),
+            opts={
+                'compare_type': True,
+                'compare_server_default': True,
+                'target_metadata': self.m1,
+                'upgrade_token': "upgrades",
+                'downgrade_token': "downgrades",
+                'include_symbol': include_symbol,
+            }
+        )
+
+        diffs = autogenerate.compare_metadata(context, metadata)
+
+        eq_(diffs[0][0], 'remove_table')
+        eq_(diffs[0][1].name, "extra")
+
+        eq_(diffs[1][0], "add_column")
+        eq_(diffs[1][1], None)
+        eq_(diffs[1][2], "order")
+        eq_(diffs[1][3], metadata.tables['order'].c.user_id)
+
+        eq_(diffs[2][0][0], "modify_type")
+        eq_(diffs[2][0][1], None)
+        eq_(diffs[2][0][2], "order")
+        eq_(diffs[2][0][3], "amount")
+        eq_(repr(diffs[2][0][5]), "NUMERIC(precision=8, scale=2)")
+        eq_(repr(diffs[2][0][6]), "Numeric(precision=10, scale=2)")
+
+        eq_(diffs[2][1][0], 'modify_nullable')
+        eq_(diffs[2][1][2], 'order')
+        eq_(diffs[2][1][5], False)
+        eq_(diffs[2][1][6], True)
+
     def test_render_nothing(self):
         context = MigrationContext.configure(
             connection=self.bind.connect(),