]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Added new kw argument to :meth:`.EnvironmentContext.configure`
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Jul 2013 23:07:14 +0000 (19:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Jul 2013 23:07:14 +0000 (19:07 -0400)
``include_object``.  This is a more flexible version of the
``include_symbol`` argument which allows filtering of columns as well as tables
from the autogenerate process,
and in the future will also work for types, constraints and
other constructs.  The fully constructed schema object is passed,
including its name and type as well as a flag indicating if the object
is from the local application metadata or is reflected.

alembic/autogenerate.py
alembic/ddl/impl.py
alembic/ddl/postgresql.py
alembic/environment.py
alembic/migration.py
docs/build/changelog.rst
tests/test_autogenerate.py
tests/test_postgresql.py

index f1af020d0b116608e073fa8baf69e9e5f78d1fa4..70b2e91c803d6acd9c8b6c94af8f5601937b24ea 100644 (file)
@@ -109,12 +109,26 @@ def compare_metadata(context, metadata):
 
 def _produce_migration_diffs(context, template_args,
                                 imports, include_symbol=None,
+                                include_object=None,
                                 include_schemas=False):
     opts = context.opts
     metadata = opts['target_metadata']
+    include_object = opts.get('include_object', include_object)
     include_symbol = opts.get('include_symbol', include_symbol)
     include_schemas = opts.get('include_schemas', include_schemas)
 
+    object_filters = []
+    if include_symbol:
+        def include_symbol_filter(object, name, type_, reflected, compare_to):
+            if type_ == "table":
+                return include_symbol(name, object.schema)
+            else:
+                return True
+        object_filters.append(include_symbol_filter)
+    if include_object:
+        object_filters.append(include_object)
+
+
     if metadata is None:
         raise util.CommandError(
                 "Can't proceed with --autogenerate option; environment "
@@ -126,8 +140,7 @@ def _produce_migration_diffs(context, template_args,
 
     diffs = []
     _produce_net_changes(connection, metadata, diffs,
-                                autogen_context, include_symbol,
-                                include_schemas)
+                                autogen_context, object_filters, include_schemas)
     template_args[opts['upgrade_token']] = \
             _indent(_produce_upgrade_commands(diffs, autogen_context))
     template_args[opts['downgrade_token']] = \
@@ -155,8 +168,16 @@ def _indent(text):
 ###################################################
 # walk structures
 
+
+def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
+    for fn in object_filters:
+        if not fn(object_, name, type_, reflected, compare_to):
+            return False
+    else:
+        return True
+
 def _produce_net_changes(connection, metadata, diffs, autogen_context,
-                            include_symbol=None,
+                            object_filters=(),
                             include_schemas=False):
     inspector = Inspector.from_engine(connection)
     # TODO: not hardcode alembic_version here ?
@@ -179,53 +200,53 @@ def _produce_net_changes(connection, metadata, diffs, autogen_context,
     metadata_table_names = OrderedSet([(table.schema, table.name)
                                 for table in metadata.sorted_tables])
 
-    if include_symbol:
-        conn_table_names = set((s, name)
-                                for s, name in conn_table_names
-                                if include_symbol(name, s))
-        metadata_table_names = OrderedSet((s, name)
-                                for s, name in metadata_table_names
-                                if include_symbol(name, s))
-
     _compare_tables(conn_table_names, metadata_table_names,
+                    object_filters,
                     inspector, metadata, diffs, autogen_context)
 
 def _compare_tables(conn_table_names, metadata_table_names,
+                    object_filters,
                     inspector, metadata, diffs, autogen_context):
 
     for s, tname in metadata_table_names.difference(conn_table_names):
         name = '%s.%s' % (s, tname) if s else tname
-        diffs.append(("add_table", metadata.tables[name]))
-        log.info("Detected added table %r", name)
+        metadata_table = metadata.tables[sa_schema._get_table_key(tname, s)]
+        if _run_filters(metadata_table, tname, "table", False, None, object_filters):
+            diffs.append(("add_table", metadata.tables[name]))
+            log.info("Detected added table %r", name)
 
     removal_metadata = sa_schema.MetaData()
     for s, tname in conn_table_names.difference(metadata_table_names):
-        name = '%s.%s' % (s, tname) if s else tname
+        name = sa_schema._get_table_key(tname, s)
         exists = name in removal_metadata.tables
         t = sa_schema.Table(tname, removal_metadata, schema=s)
         if not exists:
             inspector.reflecttable(t, None)
-        diffs.append(("remove_table", t))
-        log.info("Detected removed table %r", name)
+        if _run_filters(t, tname, "table", True, None, object_filters):
+            diffs.append(("remove_table", t))
+            log.info("Detected removed table %r", name)
 
     existing_tables = conn_table_names.intersection(metadata_table_names)
 
-    conn_column_info = dict(
-        ((s, tname),
-            dict(
-                (rec["name"], rec)
-                for rec in inspector.get_columns(tname, schema=s)
-            )
-        )
-        for s, tname in existing_tables
-    )
+    existing_metadata = sa_schema.MetaData()
+    conn_column_info = {}
+    for s, tname in existing_tables:
+        name = sa_schema._get_table_key(tname, s)
+        exists = name in existing_metadata.tables
+        t = sa_schema.Table(tname, existing_metadata, schema=s)
+        if not exists:
+            inspector.reflecttable(t, None)
+        conn_column_info[(s, tname)] = t
 
     for s, tname in sorted(existing_tables):
         name = '%s.%s' % (s, tname) if s else tname
-        _compare_columns(s, tname,
-                conn_column_info[(s, tname)],
-                metadata.tables[name],
-                diffs, autogen_context)
+        metadata_table = metadata.tables[name]
+        conn_table = existing_metadata.tables[name]
+        if _run_filters(metadata_table, tname, "table", False, conn_table, object_filters):
+            _compare_columns(s, tname, object_filters,
+                    conn_table,
+                    metadata_table,
+                    diffs, autogen_context)
 
     # TODO:
     # index add/drop
@@ -235,33 +256,42 @@ def _compare_tables(conn_table_names, metadata_table_names,
 ###################################################
 # element comparison
 
-def _compare_columns(schema, tname, conn_table, metadata_table,
+def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
                                 diffs, autogen_context):
     name = '%s.%s' % (schema, tname) if schema else tname
     metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
-    conn_col_names = set(conn_table)
+    conn_col_names = dict((c.name, c) for c in conn_table.c)
     metadata_col_names = OrderedSet(sorted(metadata_cols_by_name))
 
-    for cname in metadata_col_names.difference(conn_col_names):
-        diffs.append(
-            ("add_column", schema, tname, metadata_cols_by_name[cname])
-        )
-        log.info("Detected added column '%s.%s'", name, cname)
 
-    for cname in conn_col_names.difference(metadata_col_names):
-        diffs.append(
-            ("remove_column", schema, tname, sa_schema.Column(
-                cname,
-                conn_table[cname]['type'],
-                nullable=conn_table[cname]['nullable'],
-                server_default=conn_table[cname]['default']
-            ))
-        )
-        log.info("Detected removed column '%s.%s'", name, cname)
+    for cname in metadata_col_names.difference(conn_col_names):
+        if _run_filters(metadata_cols_by_name[cname], cname,
+                                "column", False, None, object_filters):
+            diffs.append(
+                ("add_column", schema, tname, metadata_cols_by_name[cname])
+            )
+            log.info("Detected added column '%s.%s'", name, cname)
+
+    for cname in set(conn_col_names).difference(metadata_col_names):
+        rem_col = sa_schema.Column(
+                    cname,
+                    conn_table.c[cname].type,
+                    nullable=conn_table.c[cname].nullable,
+                    server_default=conn_table.c[cname].server_default
+                )
+        if _run_filters(rem_col, cname,
+                                "column", True, None, object_filters):
+            diffs.append(
+                ("remove_column", schema, tname, rem_col)
+            )
+            log.info("Detected removed column '%s.%s'", name, cname)
 
     for colname in metadata_col_names.intersection(conn_col_names):
         metadata_col = metadata_cols_by_name[colname]
-        conn_col = conn_table[colname]
+        conn_col = conn_table.c[colname]
+        if not _run_filters(
+                    metadata_col, colname, "column", False, conn_col, object_filters):
+            continue
         col_diff = []
         _compare_type(schema, tname, colname,
             conn_col,
@@ -284,13 +314,13 @@ def _compare_columns(schema, tname, conn_table, metadata_table,
 def _compare_nullable(schema, tname, cname, conn_col,
                             metadata_col_nullable, diffs,
                             autogen_context):
-    conn_col_nullable = conn_col['nullable']
+    conn_col_nullable = conn_col.nullable
     if conn_col_nullable is not metadata_col_nullable:
         diffs.append(
             ("modify_nullable", schema, tname, cname,
                 {
-                    "existing_type": conn_col['type'],
-                    "existing_server_default": conn_col['default'],
+                    "existing_type": conn_col.type,
+                    "existing_server_default": conn_col.server_default,
                 },
                 conn_col_nullable,
                 metadata_col_nullable),
@@ -305,7 +335,7 @@ def _compare_type(schema, tname, cname, conn_col,
                             metadata_col, diffs,
                             autogen_context):
 
-    conn_type = conn_col['type']
+    conn_type = conn_col.type
     metadata_type = metadata_col.type
     if conn_type._type_affinity is sqltypes.NullType:
         log.info("Couldn't determine database type "
@@ -323,8 +353,8 @@ def _compare_type(schema, tname, cname, conn_col,
         diffs.append(
             ("modify_type", schema, tname, cname,
                     {
-                        "existing_nullable": conn_col['nullable'],
-                        "existing_server_default": conn_col['default'],
+                        "existing_nullable": conn_col.nullable,
+                        "existing_server_default": conn_col.server_default,
                     },
                     conn_type,
                     metadata_type),
@@ -337,22 +367,25 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
                                 diffs, autogen_context):
 
     metadata_default = metadata_col.server_default
-    conn_col_default = conn_col['default']
+    conn_col_default = conn_col.server_default
     if conn_col_default is None and metadata_default is None:
         return False
     rendered_metadata_default = _render_server_default(
                             metadata_default, autogen_context)
+    rendered_conn_default = conn_col.server_default.arg.text \
+                            if conn_col.server_default else None
     isdiff = autogen_context['context']._compare_server_default(
                         conn_col, metadata_col,
-                        rendered_metadata_default
+                        rendered_metadata_default,
+                        rendered_conn_default
                     )
     if isdiff:
-        conn_col_default = conn_col['default']
+        conn_col_default = rendered_conn_default
         diffs.append(
             ("modify_default", schema, tname, cname,
                 {
-                    "existing_nullable": conn_col['nullable'],
-                    "existing_type": conn_col['type'],
+                    "existing_nullable": conn_col.nullable,
+                    "existing_type": conn_col.type,
                 },
                 conn_col_default,
                 metadata_default),
index 2bb672c937d8654f4940521d6e71c93c712d9ee1..6a7c688a3bab57120f46d35b9e0bf33042bcc336 100644 (file)
@@ -182,7 +182,7 @@ class DefaultImpl(with_metaclass(ImplMeta)):
 
     def compare_type(self, inspector_column, metadata_column):
 
-        conn_type = inspector_column['type']
+        conn_type = inspector_column.type
         metadata_type = metadata_column.type
 
         metadata_impl = metadata_type.dialect_impl(self.dialect)
@@ -202,9 +202,9 @@ class DefaultImpl(with_metaclass(ImplMeta)):
 
     def compare_server_default(self, inspector_column,
                             metadata_column,
-                            rendered_metadata_default):
-        conn_col_default = inspector_column['default']
-        return conn_col_default != rendered_metadata_default
+                            rendered_metadata_default,
+                            rendered_inspector_default):
+        return rendered_inspector_default != rendered_metadata_default
 
     def start_migrations(self):
         """A hook called when :meth:`.EnvironmentContext.run_migrations`
index d63958ba4c7ec2ceaf3f41e631e629c63982b45f..5ca0d1f592a8e07956d102cefefe98106ca61d91 100644 (file)
@@ -11,14 +11,15 @@ class PostgresqlImpl(DefaultImpl):
 
     def compare_server_default(self, inspector_column,
                             metadata_column,
-                            rendered_metadata_default):
+                            rendered_metadata_default,
+                            rendered_inspector_default):
 
         # don't do defaults for SERIAL columns
         if metadata_column.primary_key and \
             metadata_column is metadata_column.table._autoincrement_column:
             return False
 
-        conn_col_default = inspector_column['default']
+        conn_col_default = rendered_inspector_default
 
         if None in (conn_col_default, rendered_metadata_default):
             return conn_col_default != rendered_metadata_default
index 5d4e056105286a06cc79ff475213ea25b0b83209..641d435a42cc0c9401f5fff97151e35bc95a10cc 100644 (file)
@@ -264,6 +264,7 @@ class EnvironmentContext(object):
             template_args=None,
             target_metadata=None,
             include_symbol=None,
+            include_object=None,
             include_schemas=False,
             compare_type=False,
             compare_server_default=False,
@@ -425,9 +426,59 @@ class EnvironmentContext(object):
          execute
          the two defaults on the database side to compare for equivalence.
 
+        :param include_object: A callable function which is given
+         the chance to return ``True`` or ``False`` for any object,
+         indicating if the given object should be considered in the
+         autogenerate sweep.
+
+         The function accepts the following positional arguments:
+
+         * ``object``: a :class:`~sqlalchemy.schema.SchemaItem` object such as a
+           :class:`~sqlalchemy.schema.Table` or :class:`~sqlalchemy.schema.Column`
+           object
+         * ``name``: the name of the object. This is typically available
+           via ``object.name``.
+         * ``type``: a string describing the type of object; currently
+           ``"table"`` or ``"column"``, but will include other types in a
+           future release
+         * ``reflected``: ``True`` if the given object was produced based on
+           table reflection, ``False`` if it's from a local :class:`.MetaData`
+           object.
+         * ``compare_to``: the object being compared against, if available,
+           else ``None``.
+
+         E.g.::
+
+            def include_object(object, name, type_, reflected, compare_to):
+                if (type_ == "column" and
+                    not reflected and
+                    object.info.get("skip_autogenerate", False)):
+                    return False
+                else:
+                    return True
+
+            context.configure(
+                # ...
+                include_object = include_object
+            )
+
+         The ``include_object`` filter will be expanded in a future release
+         to also receive type, constraint, and default objects.
+
+         .. versionadded:: 0.6.0
+
+         .. seealso::
+
+            ``include_schemas``, ``include_symbol``
+
+
         :param include_symbol: A callable function which, given a table name
          and schema name (may be ``None``), returns ``True`` or ``False``, indicating
          if the given table should be considered in the autogenerate sweep.
+
+         .. deprecated:: 0.6.0 ``include_symbol`` is superceded by the
+            more generic ``include_object`` parameter.
+
          E.g.::
 
             def include_symbol(tablename, schema):
@@ -455,6 +506,24 @@ class EnvironmentContext(object):
          .. versionchanged:: 0.4.0  the ``include_symbol`` callable must now
             also accept a "schema" argument, which may be None.
 
+         .. seealso::
+
+            ``include_schemas``, ``include_object``
+
+        :param include_schemas: If True, autogenerate will scan across
+         all schemas located by the SQLAlchemy
+         :meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names`
+         method, and include all differences in tables found across all
+         those schemas.  When using this option, you may want to also
+         use the ``include_symbol`` option to specify a callable which
+         can filter the tables/schemas that get included.
+
+         .. versionadded :: 0.4.0
+
+         .. seealso::
+
+            ``include_symbol``, ``include_object``
+
         :param render_item: Callable that can be used to override how
          any schema item, i.e. column, constraint, type,
          etc., is rendered for autogenerate.  The callable receives a
@@ -506,15 +575,6 @@ class EnvironmentContext(object):
          will render them using the dialect module name, i.e. ``mssql.BIT()``,
          ``postgresql.UUID()``.
 
-        :param include_schemas: If True, autogenerate will scan across
-         all schemas located by the SQLAlchemy
-         :meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names`
-         method, and include all differences in tables found across all
-         those schemas.  When using this option, you may want to also
-         use the ``include_symbol`` option to specify a callable which
-         can filter the tables/schemas that get included.
-
-         .. versionadded :: 0.4.0
 
         Parameters specific to individual backends:
 
@@ -546,6 +606,7 @@ class EnvironmentContext(object):
             opts['template_args'].update(template_args)
         opts['target_metadata'] = target_metadata
         opts['include_symbol'] = include_symbol
+        opts['include_object'] = include_object
         opts['include_schemas'] = include_schemas
         opts['upgrade_token'] = upgrade_token
         opts['downgrade_token'] = downgrade_token
index cdd244a4220a6eff463aca0fea5ff2017b3d61e6..4838e66b135a6ef143fdc5d82668eb5b2d667f0a 100644 (file)
@@ -292,7 +292,8 @@ class MigrationContext(object):
 
     def _compare_server_default(self, inspector_column,
                             metadata_column,
-                            rendered_metadata_default):
+                            rendered_metadata_default,
+                            rendered_column_default):
 
         if self._user_compare_server_default is False:
             return False
@@ -302,7 +303,7 @@ class MigrationContext(object):
                     self,
                     inspector_column,
                     metadata_column,
-                    inspector_column['default'],
+                    rendered_column_default,
                     metadata_column.server_default,
                     rendered_metadata_default
             )
@@ -312,5 +313,6 @@ class MigrationContext(object):
         return self.impl.compare_server_default(
                                 inspector_column,
                                 metadata_column,
-                                rendered_metadata_default)
+                                rendered_metadata_default,
+                                rendered_column_default)
 
index bc39aea28788cf20039967c87e6667a78273bd61..6b98e3c42c896b174a3b7aa58c6e831cdfaee802 100644 (file)
@@ -7,6 +7,19 @@ Changelog
     :version: 0.6.0
     :released:
 
+    .. change::
+      :tags: feature
+      :tickets: 101
+
+      Added new kw argument to :meth:`.EnvironmentContext.configure`
+      ``include_object``.  This is a more flexible version of the
+      ``include_symbol`` argument which allows filtering of columns as well as tables
+      from the autogenerate process,
+      and in the future will also work for types, constraints and
+      other constructs.  The fully constructed schema object is passed,
+      including its name and type as well as a flag indicating if the object
+      is from the local application metadata or is reflected.
+
     .. change::
       :tags: feature
 
index 9bdc2decafaf9f29efc1ffcfce65c0ef91ccf970..a1b5ac3e4660714e853d4a58332057053a4f5cba 100644 (file)
@@ -94,9 +94,17 @@ def _model_four():
 
     return m
 
-_default_include_symbol = lambda name, schema=None: name in ("parent", "child",
+def _default_include_object(obj, name, type_, reflected, compare_to):
+    if type_ == "table":
+        return name in ("parent", "child",
                                 "user", "order", "item",
                                 "address", "extra")
+    else:
+        return True
+
+_default_object_filters = [
+    _default_include_object
+]
 
 class AutogenTest(object):
     @classmethod
@@ -244,9 +252,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
         metadata = self.m2
         connection = self.context.bind
         diffs = []
+        def include_object(obj, name, type_, reflected, compare_to):
+            if type_ == "table":
+                return name == "t3"
+            else:
+                return True
         autogenerate._produce_net_changes(connection, metadata, diffs,
                                           self.autogen_context,
-                                          include_symbol=lambda n, s: n == 't3',
+                                          object_filters=[include_object],
                                           include_schemas=True
                                           )
         eq_(diffs[0][0], "add_table")
@@ -256,9 +269,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
         metadata = self.m2
         connection = self.context.bind
         diffs = []
+        def include_object(obj, name, type_, reflected, compare_to):
+            if type_ == "table":
+                return name == "t4"
+            else:
+                return True
         autogenerate._produce_net_changes(connection, metadata, diffs,
                                           self.autogen_context,
-                                          include_symbol=lambda n, s: n == 't4',
+                                          object_filters=[include_object],
                                           include_schemas=True
                                           )
         eq_(diffs[0][0], "add_table")
@@ -268,9 +286,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
         metadata = self.m2
         connection = self.context.bind
         diffs = []
+        def include_object(obj, name, type_, reflected, compare_to):
+            if type_ == "table":
+                return name == "t1"
+            else:
+                return True
         autogenerate._produce_net_changes(connection, metadata, diffs,
                                           self.autogen_context,
-                                          include_symbol=lambda n, s: n == 't1',
+                                          object_filters=[include_object],
                                           include_schemas=True
                                           )
         eq_(diffs[0][0], "remove_table")
@@ -280,9 +303,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
         metadata = self.m2
         connection = self.context.bind
         diffs = []
+        def include_object(obj, name, type_, reflected, compare_to):
+            if type_ == "table":
+                return name == "t2"
+            else:
+                return True
         autogenerate._produce_net_changes(connection, metadata, diffs,
                                           self.autogen_context,
-                                          include_symbol=lambda n, s: n == 't2',
+                                          object_filters=[include_object],
                                           include_schemas=True
                                           )
         eq_(diffs[0][0], "remove_table")
@@ -315,7 +343,7 @@ class AutogenerateDiffTestWSchema(AutogenTest, TestCase):
         diffs = []
         autogenerate._produce_net_changes(connection, metadata, diffs,
                                           self.autogen_context,
-                                          include_symbol=_default_include_symbol,
+                                          object_filters=_default_object_filters,
                                           include_schemas=True
                                           )
 
@@ -390,7 +418,7 @@ class AutogenerateDiffTestWSchema(AutogenTest, TestCase):
         template_args = {}
         autogenerate._produce_migration_diffs(
                         self.context, template_args, set(),
-                        include_symbol=_default_include_symbol,
+                        include_object=_default_include_object,
                         include_schemas=True
                         )
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
@@ -472,7 +500,7 @@ class AutogenerateDiffTest(AutogenTest, TestCase):
         diffs = []
         autogenerate._produce_net_changes(connection, metadata, diffs,
                                           self.autogen_context,
-                                          include_symbol= _default_include_symbol
+                                          object_filters=_default_object_filters,
                                     )
 
         eq_(
@@ -620,11 +648,55 @@ class AutogenerateDiffTest(AutogenTest, TestCase):
         assert "alter_column('order'" in template_args['upgrades']
         assert "alter_column('order'" in template_args['downgrades']
 
+    def test_include_object(self):
+        def include_object(obj, name, type_, reflected, compare_to):
+            assert obj.name == name
+            if type_ == "table":
+                if reflected:
+                    assert obj.metadata is not self.m2
+                else:
+                    assert obj.metadata is self.m2
+                return name in ("address", "order")
+            elif type_ == "column":
+                if reflected:
+                    assert obj.table.metadata is not self.m2
+                else:
+                    assert obj.table.metadata is self.m2
+                return name != "street"
+            else:
+                return True
+
+
+        context = MigrationContext.configure(
+            connection=self.bind.connect(),
+            opts={
+                'compare_type': True,
+                'compare_server_default': True,
+                'target_metadata': self.m2,
+                'include_object': include_object,
+                'upgrade_token': "upgrades",
+                'downgrade_token': "downgrades",
+                'alembic_module_prefix': 'op.',
+                'sqlalchemy_module_prefix': 'sa.',
+            }
+        )
+        template_args = {}
+        autogenerate._produce_migration_diffs(context, template_args, set())
+        template_args['upgrades'] = template_args['upgrades'].replace("u'", "'")
+        template_args['downgrades'] = template_args['downgrades'].\
+                                        replace("u'", "'")
+
+        assert "alter_column('user'" not in template_args['upgrades']
+        assert "alter_column('user'" not in template_args['downgrades']
+        assert "'street'" not in template_args['upgrades']
+        assert "'street'" not in template_args['downgrades']
+        assert "alter_column('order'" in template_args['upgrades']
+        assert "alter_column('order'" in template_args['downgrades']
+
     def test_skip_null_type_comparison_reflected(self):
         diff = []
         autogenerate._compare_type(None, "sometable", "somecol",
-            {"name":"somecol", "type":NULLTYPE,
-            "nullable":True, "default":None},
+            Column("somecol", NULLTYPE),
             Column("somecol", Integer()),
             diff, self.autogen_context
         )
@@ -633,8 +705,7 @@ class AutogenerateDiffTest(AutogenTest, TestCase):
     def test_skip_null_type_comparison_local(self):
         diff = []
         autogenerate._compare_type(None, "sometable", "somecol",
-            {"name":"somecol", "type":Integer(),
-            "nullable":True, "default":None},
+            Column("somecol", Integer()),
             Column("somecol", NULLTYPE),
             diff, self.autogen_context
         )
@@ -652,8 +723,7 @@ class AutogenerateDiffTest(AutogenTest, TestCase):
 
         diff = []
         autogenerate._compare_type(None, "sometable", "somecol",
-            {"name":"somecol", "type":Integer(),
-            "nullable":True, "default":None},
+            Column("somecol", Integer, nullable=True),
             Column("somecol", MyType()),
             diff, self.autogen_context
         )
@@ -664,7 +734,8 @@ class AutogenerateDiffTest(AutogenTest, TestCase):
         from sqlalchemy.util import OrderedSet
         inspector = Inspector.from_engine(self.bind)
         autogenerate._compare_tables(
-            OrderedSet([(None, 'extra'), (None, 'user')]), OrderedSet(), inspector,
+            OrderedSet([(None, 'extra'), (None, 'user')]),
+            OrderedSet(), [], inspector,
                 MetaData(), diffs, self.autogen_context
         )
         eq_(
index 047ac92dc74e5cbc976d12b9e657ebff4a2025bc..9c530353fb50f6baba3a35240cbfda51d7df50bc 100644 (file)
@@ -197,9 +197,10 @@ class PostgresqlDefaultCompareTest(TestCase):
         cols = insp.get_columns(t1.name)
         ctx = self.autogen_context['context']
         return ctx.impl.compare_server_default(
-            cols[0],
+            None,
             col,
-            rendered)
+            rendered,
+            cols[0]['default'])
 
     def test_compare_current_timestamp(self):
         self._compare_default_roundtrip(