]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- refactor the FK merge a bit
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Nov 2014 23:02:57 +0000 (18:02 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Nov 2014 23:02:57 +0000 (18:02 -0500)
- getting at attributes of FKs varies a bit on SQLA versions,
so implement an _fk_spec() called for all FK inspection
- to enable include_object() filters and allow the FK constraint
code to flow like that of indexes/uniques, change the approach
so that we deal with an _fk_constraint_sig() object again which
contains the real ForeignKeyConstraint() within; we need this
anyway for include_object, but also allows us to use the standard
"drop_constraint" call for rendering.
- enhance tests in test_autogen_fks to support real FK databases like
Postgresql, MySQL, add in InnoDB flags and ensure that FKs refer
to real primary key constraints for PG support
- implement and test include_object() support for FKs
- inspectors all have get_foreign_keys(), no need to check
- repair the drop_constraint call to quote the "type" and table
name correctly, run all constraint drops through drop_constraint()
for rendering
- fix up schema identifiers for foreign key autogens

alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/ddl/base.py
alembic/ddl/mysql.py
tests/requirements.py
tests/test_autogen_fks.py
tests/test_autogenerate.py

index 60e2100c4c5cc12f744187e2aed966a3f7e417a9..cc5debe8b4720985b81e719f3265c5581e70d784 100644 (file)
@@ -13,7 +13,7 @@ from sqlalchemy.util import OrderedSet
 from .compare import _compare_tables
 from .render import _drop_table, _drop_column, _drop_index, _drop_constraint, \
     _add_table, _add_column, _add_index, _add_constraint, _modify_col, \
-    _add_fk_constraint, _drop_fk_constraint
+    _add_fk_constraint
 from .. import util
 
 log = logging.getLogger(__name__)
@@ -294,7 +294,7 @@ def _invoke_adddrop_command(updown, args, autogen_context):
         "column": (_drop_column, _add_column),
         "index": (_drop_index, _add_index),
         "constraint": (_drop_constraint, _add_constraint),
-        "fk":(_drop_fk_constraint, _add_fk_constraint)
+        "fk": (_drop_constraint, _add_fk_constraint)
     }
 
     cmd_callables = _commands[cmd_type]
@@ -346,8 +346,7 @@ def _group_diffs_by_table(diffs):
         "column": lambda diff: (diff[0], diff[1]),
         "index": lambda diff: (diff[0].table.schema, diff[0].table.name),
         "constraint": lambda diff: (diff[0].table.schema, diff[0].table.name),
-        "fk": lambda diff:
-        (diff[0].parent.table.schema, diff[0].parent.table.name)
+        "fk": lambda diff: (diff[0].parent.schema, diff[0].parent.name)
     }
 
     def _derive_table(diff):
index eb649ac8d7be7918af07ca32ab265b94acab3e1e..65770253400245a379d17a97d043993ea7aee483 100644 (file)
@@ -1,4 +1,3 @@
-import collections
 from sqlalchemy import schema as sa_schema, types as sqltypes
 from sqlalchemy import event
 import logging
@@ -7,6 +6,7 @@ from sqlalchemy.util import OrderedSet
 import re
 from .render import _user_defined_render
 import contextlib
+from alembic.ddl.base import _fk_spec
 
 log = logging.getLogger(__name__)
 
@@ -139,6 +139,25 @@ def _make_unique_constraint(params, conn_table):
     )
 
 
+def _make_foreign_key(params, conn_table):
+    tname = params['referred_table']
+    if params['referred_schema']:
+        tname = "%s.%s" % (params['referred_schema'], tname)
+
+    const = sa_schema.ForeignKeyConstraint(
+        [conn_table.c[cname] for cname in params['constrained_columns']],
+        ["%s.%s" % (tname, n) for n in params['referred_columns']],
+        onupdate=params.get('onupdate'),
+        ondelete=params.get('ondelete'),
+        deferrable=params.get('deferrable'),
+        initially=params.get('initially'),
+        name=params['name']
+    )
+    # needed by 0.7
+    conn_table.append_constraint(const)
+    return const
+
+
 @contextlib.contextmanager
 def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
                      diffs, autogen_context, inspector):
@@ -194,7 +213,6 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
             log.info("Detected removed column '%s.%s'", name, cname)
 
 
-
 class _constraint_sig(object):
 
     def __eq__(self, other):
@@ -235,6 +253,20 @@ class _ix_constraint_sig(_constraint_sig):
         return _get_index_column_names(self.const)
 
 
+class _fk_constraint_sig(_constraint_sig):
+    def __init__(self, const):
+        self.const = const
+        self.name = const.name
+        self.source_schema, self.source_table, \
+            self.source_columns, self.target_schema, self.target_table, \
+            self.target_columns = _fk_spec(const)
+
+        self.sig = (
+            self.source_schema, self.source_table, tuple(self.source_columns),
+            self.target_schema, self.target_table, tuple(self.target_columns)
+        )
+
+
 def _get_index_column_names(idx):
     if compat.sqla_08:
         return [getattr(exp, "name", None) for exp in idx.expressions]
@@ -571,63 +603,79 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
                  cname
                  )
 
-FKInfo = collections.namedtuple('fk_info', ['constrained_columns',
-                                            'referred_table',
-                                            'referred_columns'])
-
 
 def _compare_foreign_keys(schema, tname, object_filters, conn_table,
                           metadata_table, diffs, autogen_context, inspector):
 
-    # This methods checks foreign keys that tables contain in models with
-    # foreign keys that are in db.
-    # Get all necessary information about key of current table from db
+    # if we're doing CREATE TABLE, all FKs are created
+    # inline within the table def
     if conn_table is None:
         return
 
-    fk_db = {}
-    if hasattr(inspector, "get_foreign_keys"):
-        try:
-            fk_db = dict((_get_fk_info_from_db(i), i['name']) for i in
-                         inspector.get_foreign_keys(tname, schema=schema))
-        except NotImplementedError:
-            pass
-
+    metadata_fks = set(
+        fk for fk in metadata_table.constraints
+        if isinstance(fk, sa_schema.ForeignKeyConstraint)
+    )
+    metadata_fks = set(_fk_constraint_sig(fk) for fk in metadata_fks)
 
-    # Get all necessary information about key of current table from
-    # models
-    fk_models = dict((_get_fk_info_from_model(fk), fk) for fk in
-                     metadata_table.foreign_keys)
-    fk_models_set = set(fk_models.keys())
-    fk_db_set = set(fk_db.keys())
-    for key in (fk_db_set - fk_models_set):
-            diffs.append(('drop_fk', fk_db[key], conn_table, key))
-            log.info(("Detected removed foreign key %(fk)r on "
-                      "table %(table)r"), {'fk': fk_db[key],
-                                           'table': conn_table})
-    for key in (fk_models_set - fk_db_set):
-            diffs.append(('add_fk', fk_models[key], key))
-            log.info((
-                "Detected added foreign key for column %(fk)r on table "
-                "%(table)r"), {'fk': fk_models[key].column.name,
-                               'table': conn_table})
-    return diffs
+    conn_fks = inspector.get_foreign_keys(tname, schema=schema)
+    conn_fks = set(_fk_constraint_sig(_make_foreign_key(const, conn_table))
+                   for const in conn_fks)
 
+    conn_fks_by_sig = dict(
+        (c.sig, c) for c in conn_fks
+    )
+    metadata_fks_by_sig = dict(
+        (c.sig, c) for c in metadata_fks
+    )
 
-def _get_fk_info_from_db(fk):
-    return FKInfo(tuple(fk['constrained_columns']),
-                  fk['referred_table'],
-                  tuple(fk['referred_columns']))
+    metadata_fks_by_name = dict(
+        (c.name, c) for c in metadata_fks if c.name is not None
+    )
+    conn_fks_by_name = dict(
+        (c.name, c) for c in conn_fks if c.name is not None
+    )
 
+    def _add_fk(obj, compare_to):
+        if _run_filters(
+                obj.const, obj.name, "foreignkey", False,
+                compare_to, object_filters):
+            diffs.append(('add_fk', const.const))
+
+            log.info(
+                "Detected added foreign key (%s)(%s) on table %s%s",
+                ", ".join(obj.source_columns),
+                ", ".join(obj.target_columns),
+                "%s." % obj.source_schema if obj.source_schema else "",
+                obj.source_table)
+
+    def _remove_fk(obj, compare_to):
+        if _run_filters(
+                obj.const, obj.name, "foreignkey", True,
+                compare_to, object_filters):
+            diffs.append(('remove_fk', obj.const))
+            log.info(
+                "Detected removed foreign key (%s)(%s) on table %s%s",
+                ", ".join(obj.source_columns),
+                ", ".join(obj.target_columns),
+                "%s." % obj.source_schema if obj.source_schema else "",
+                obj.source_table)
+
+    # so far it appears we don't need to do this by name at all.
+    # SQLite doesn't preserve constraint names anyway
+
+    for removed_sig in set(conn_fks_by_sig).difference(metadata_fks_by_sig):
+        const = conn_fks_by_sig[removed_sig]
+        if removed_sig not in metadata_fks_by_sig:
+            compare_to = metadata_fks_by_name[const.name].const \
+                if const.name in metadata_fks_by_name else None
+            _remove_fk(const, compare_to)
+
+    for added_sig in set(metadata_fks_by_sig).difference(conn_fks_by_sig):
+        const = metadata_fks_by_sig[added_sig]
+        if added_sig not in conn_fks_by_sig:
+            compare_to = conn_fks_by_name[const.name].const \
+                if const.name in conn_fks_by_name else None
+            _add_fk(const, compare_to)
 
-def _get_fk_info_from_model(fk):
-    constrained_columns = []
-    for column in fk.constraint.columns:
-        if not isinstance(column, basestring):
-            constrained_columns.append(column.name)
-        else:
-            constrained_columns.append(column)
-    return FKInfo(
-        tuple(constrained_columns),
-        fk.column.table.name,
-        tuple(k.column.name for k in fk.constraint._elements.values()))
+    return diffs
index f6f7d9d39d23856d28d02e72648e42140c9fe79b..ec6165ba2a54e48aa521b71ef99f8ec03b6a14e5 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy import schema as sa_schema, types as sqltypes, sql
 import logging
 from .. import compat
+from ..ddl.base import _table_for_constraint, _fk_spec
 import re
 from ..compat import string_types
 
@@ -241,15 +242,27 @@ def _uq_constraint(constraint, autogen_context, alter):
         }
 
 
-def _add_fk_constraint(constraint, fk_info, autogen_context):
+def _add_fk_constraint(constraint, autogen_context):
+    source_schema, source_table, \
+        source_columns, target_schema, \
+        target_table, target_columns = _fk_spec(constraint)
+
     args = [
         repr(_render_gen_name(autogen_context, constraint.name)),
-        constraint.parent.table.name,
-        fk_info.referred_table,
-        str(list(fk_info.constrained_columns)),
-        str(list(fk_info.referred_columns)),
-        "%s=%r" % ('schema', constraint.parent.table.schema),
+        source_table,
+        target_table,
+        repr(source_columns),
+        repr(target_columns)
     ]
+    if source_schema:
+        args.append(
+            "%s=%r" % ('source_schema', source_schema),
+        )
+    if target_schema:
+        args.append(
+            "%s=%r" % ('referent_schema', target_schema)
+        )
+
     if constraint.deferrable:
         args.append("%s=%r" % ("deferrable", str(constraint.deferrable)))
     if constraint.initially:
@@ -260,15 +273,6 @@ def _add_fk_constraint(constraint, fk_info, autogen_context):
     }
 
 
-def _drop_fk_constraint(constraint, fk_info, autogen_context):
-    args = [repr(_render_gen_name(autogen_context, constraint.name)),
-            constraint.parent.table.name, "type_='foreignkey'"]
-    return "%(prefix)sdrop_constraint(%(args)s)" % {
-        'prefix': _alembic_autogenerate_prefix(autogen_context),
-        'args': ", ".join(args)
-    }
-
-
 def _add_pk_constraint(constraint, autogen_context):
     raise NotImplementedError()
 
@@ -296,19 +300,30 @@ def _drop_constraint(constraint, autogen_context):
     Generate Alembic operations for the ALTER TABLE ... DROP CONSTRAINT
     of a  :class:`~sqlalchemy.schema.UniqueConstraint` instance.
     """
+
+    types = {
+        "unique_constraint": "unique",
+        "foreign_key_constraint": "foreignkey",
+        "primary_key_constraint": "primary",
+        "check_constraint": "check",
+        "column_check_constraint": "check",
+    }
+
     if 'batch_prefix' in autogen_context:
         template = "%(prefix)sdrop_constraint"\
-            "(%(name)r)"
+            "(%(name)r, type_=%(type)r)"
     else:
         template = "%(prefix)sdrop_constraint"\
-            "(%(name)r, '%(table_name)s'%(schema)s)"
+            "(%(name)r, '%(table_name)s'%(schema)s, type_=%(type)r)"
 
+    constraint_table = _table_for_constraint(constraint)
     text = template % {
         'prefix': _alembic_autogenerate_prefix(autogen_context),
         'name': _render_gen_name(autogen_context, constraint.name),
-        'table_name': _ident(constraint.table.name),
-        'schema': (", schema='%s'" % _ident(constraint.table.schema))
-        if constraint.table.schema else '',
+        'table_name': _ident(constraint_table.name),
+        'type': types[constraint.__visit_name__],
+        'schema': (", schema='%s'" % _ident(constraint_table.schema))
+        if constraint_table.schema else '',
     }
     return text
 
index 32878b10efb22e9defa7ff41f5343b36e718d39e..d497253c3f7b183fbbc579d00298aa2000b1ffb8 100644 (file)
@@ -154,6 +154,13 @@ def visit_column_default(element, compiler, **kw):
     )
 
 
+def _table_for_constraint(constraint):
+    if isinstance(constraint, ForeignKeyConstraint):
+        return constraint.parent
+    else:
+        return constraint.table
+
+
 def _columns_for_constraint(constraint):
     if isinstance(constraint, ForeignKeyConstraint):
         return [fk.parent for fk in constraint.elements]
@@ -163,6 +170,24 @@ def _columns_for_constraint(constraint):
         return list(constraint.columns)
 
 
+def _fk_spec(constraint):
+    if util.sqla_100:
+        source_columns = constraint.column_keys
+    else:
+        source_columns = [
+            element.parent.key for element in constraint.elements]
+
+    source_table = constraint.parent.name
+    source_schema = constraint.parent.schema
+    target_schema = constraint.elements[0].column.table.schema
+    target_table = constraint.elements[0].column.table.name
+    target_columns = [element.column.name for element in constraint.elements]
+
+    return (
+        source_schema, source_table,
+        source_columns, target_schema, target_table, target_columns)
+
+
 def _is_type_bound(constraint):
     # this deals with SQLAlchemy #3260, don't copy CHECK constraints
     # that will be generated by the type.
index aac8184bdf190da8608b1ce05597cea59286a4ee..b93d29fc492555f493844e1fb95fabe0890902a0 100644 (file)
@@ -95,7 +95,6 @@ class MySQLImpl(DefaultImpl):
 
         # TODO: if SQLA 1.0, make use of "duplicates_index"
         # metadata
-
         removed = set()
         for idx in list(conn_indexes):
             # MySQL puts implicit indexes on FK columns, even if
index 7512adf1f456d748cd1807b06b704c794f0bc67f..afacd00be26fdcb2a63969b30740ed9d36a3aeb5 100644 (file)
@@ -29,3 +29,13 @@ class DefaultRequirements(SuiteRequirements):
                 lambda config: config.db.dialect.supports_native_boolean
             )
         )
+
+    @property
+    def no_fk_names(self):
+        """foreign key constraints have no names in the DB"""
+        return exclusions.only_on(['sqlite'])
+
+    @property
+    def fk_names(self):
+        """foreign key constraints always have names in the DB"""
+        return exclusions.fails_on('sqlite')
index d65b1aa57e024dc9c6055768e1d85e49439539b5..307cca7fcc8ca15c2bc3c39f2f71d7a1118d570a 100644 (file)
@@ -1,5 +1,5 @@
 import sys
-from alembic.testing import TestBase
+from alembic.testing import TestBase, config
 
 from sqlalchemy import MetaData, Column, Table, Integer, String, \
     ForeignKeyConstraint
@@ -11,74 +11,81 @@ from .test_autogenerate import AutogenFixtureTest
 
 
 class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
-    __only_on__ = 'sqlite'
+    __backend__ = True
 
-    def test_extra_fk(self):
+    def test_added_fk(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table('table', m1,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)))
+              Column('test', String(10), primary_key=True),
+              mysql_engine='InnoDB')
 
         Table('user', m1,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
               Column('test2', String(10)),
-              ForeignKeyConstraint(['test2'], ['table.test']))
+              ForeignKeyConstraint(['test2'], ['table.test']),
+              mysql_engine='InnoDB')
 
         Table('table', m2,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)))
+              Column('test', String(10), primary_key=True),
+              mysql_engine='InnoDB')
 
         Table('user', m2,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
-              Column('test2', String(10))
+              Column('test2', String(10)),
+              mysql_engine='InnoDB'
               )
 
         diffs = self._fixture(m1, m2)
 
-        eq_(diffs[0][0], "drop_fk")
-        eq_(diffs[0][2].name, "user")
-        eq_(diffs[0][3].constrained_columns, ('test2',))
-        eq_(diffs[0][3].referred_table, 'table')
-        eq_(diffs[0][3].referred_columns, ('test',))
+        self._assert_fk_diff(
+            diffs[0], "remove_fk",
+            "user", ['test2'],
+            'table', ['test'],
+            conditional_name="servergenerated"
+        )
 
-    def test_missing_fk(self):
+    def test_removed_fk(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table('table', m1,
               Column('id', Integer, primary_key=True),
-              Column('test', String(10)))
+              Column('test', String(10)),
+              mysql_engine='InnoDB')
 
         Table('user', m1,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
-              Column('test2', String(10)))
+              Column('test2', String(10)),
+              mysql_engine='InnoDB')
 
         Table('table', m2,
               Column('id', Integer, primary_key=True),
-              Column('test', String(10)))
+              Column('test', String(10)),
+              mysql_engine='InnoDB')
 
         Table('user', m2,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
               Column('test2', String(10)),
-              ForeignKeyConstraint(['test2'], ['table.test']))
+              ForeignKeyConstraint(['test2'], ['table.test']),
+              mysql_engine='InnoDB')
 
         diffs = self._fixture(m1, m2)
 
-        eq_(diffs[0][0], "add_fk")
-        eq_(diffs[0][1].parent.table.name, "user")
-        eq_(diffs[0][2].constrained_columns, ('test2',))
-        eq_(diffs[0][2].referred_table, 'table')
-        eq_(diffs[0][2].referred_columns, ('test',))
+        self._assert_fk_diff(
+            diffs[0], "add_fk",
+            "user", ["test2"],
+            "table", ["test"]
+        )
 
     def test_no_change(self):
         m1 = MetaData()
@@ -86,25 +93,29 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
 
         Table('table', m1,
               Column('id', Integer, primary_key=True),
-              Column('test', String(10)))
+              Column('test', String(10)),
+              mysql_engine='InnoDB')
 
         Table('user', m1,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
-              Column('test2', String(10)),
-              ForeignKeyConstraint(['test2'], ['table.test']))
+              Column('test2', Integer),
+              ForeignKeyConstraint(['test2'], ['table.id']),
+              mysql_engine='InnoDB')
 
         Table('table', m2,
               Column('id', Integer, primary_key=True),
-              Column('test', String(10)))
+              Column('test', String(10)),
+              mysql_engine='InnoDB')
 
         Table('user', m2,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
-              Column('test2', String(10)),
-              ForeignKeyConstraint(['test2'], ['table.test']))
+              Column('test2', Integer),
+              ForeignKeyConstraint(['test2'], ['table.id']),
+              mysql_engine='InnoDB')
 
         diffs = self._fixture(m1, m2)
 
@@ -115,9 +126,9 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         Table('table', m1,
-              Column('id', Integer, primary_key=True),
-              Column('id_1', String(10)),
-              Column('id_2', String(10)))
+              Column('id_1', String(10), primary_key=True),
+              Column('id_2', String(10), primary_key=True),
+              mysql_engine='InnoDB')
 
         Table('user', m1,
               Column('id', Integer, primary_key=True),
@@ -126,12 +137,14 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
               Column('other_id_1', String(10)),
               Column('other_id_2', String(10)),
               ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['table.id_1','table.id_2']))
+                                   ['table.id_1', 'table.id_2']),
+              mysql_engine='InnoDB')
 
         Table('table', m2,
-              Column('id', Integer, primary_key=True),
-              Column('id_1', String(10)),
-              Column('id_2', String(10)))
+              Column('id_1', String(10), primary_key=True),
+              Column('id_2', String(10), primary_key=True),
+              mysql_engine='InnoDB'
+              )
 
         Table('user', m2,
               Column('id', Integer, primary_key=True),
@@ -140,32 +153,36 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
               Column('other_id_1', String(10)),
               Column('other_id_2', String(10)),
               ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['table.id_1','table.id_2']))
+                                   ['table.id_1', 'table.id_2']),
+              mysql_engine='InnoDB')
 
         diffs = self._fixture(m1, m2)
 
         eq_(diffs, [])
 
-    def test_missing_composite_fk_with_name(self):
+    def test_removed_composite_fk_with_name(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table('table', m1,
               Column('id', Integer, primary_key=True),
               Column('id_1', String(10)),
-              Column('id_2', String(10)))
+              Column('id_2', String(10)),
+              mysql_engine='InnoDB')
 
         Table('user', m1,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
               Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)))
+              Column('other_id_2', String(10)),
+              mysql_engine='InnoDB')
 
         Table('table', m2,
               Column('id', Integer, primary_key=True),
               Column('id_1', String(10)),
-              Column('id_2', String(10)))
+              Column('id_2', String(10)),
+              mysql_engine='InnoDB')
 
         Table('user', m2,
               Column('id', Integer, primary_key=True),
@@ -174,26 +191,27 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
               Column('other_id_1', String(10)),
               Column('other_id_2', String(10)),
               ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['table.id_1','table.id_2'],
-                                   name='fk_test_name'))
+                                   ['table.id_1', 'table.id_2'],
+                                   name='fk_test_name'),
+              mysql_engine='InnoDB')
 
         diffs = self._fixture(m1, m2)
 
-        eq_(diffs[0][0], "add_fk")
-        eq_(diffs[0][1].parent.table.name, "user")
-        eq_(diffs[0][1].name, "fk_test_name")
-        eq_(diffs[0][2].constrained_columns, ('other_id_1', 'other_id_2'))
-        eq_(diffs[0][2].referred_table, 'table')
-        eq_(diffs[0][2].referred_columns, ('id_1', 'id_2'))
+        self._assert_fk_diff(
+            diffs[0], "add_fk",
+            "user", ['other_id_1', 'other_id_2'],
+            'table', ['id_1', 'id_2'],
+            name="fk_test_name"
+        )
 
-    def test_extra_composite_fk(self):
+    def test_added_composite_fk(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table('table', m1,
-              Column('id', Integer, primary_key=True),
-              Column('id_1', String(10)),
-              Column('id_2', String(10)))
+              Column('id_1', String(10), primary_key=True),
+              Column('id_2', String(10), primary_key=True),
+              mysql_engine='InnoDB')
 
         Table('user', m1,
               Column('id', Integer, primary_key=True),
@@ -203,24 +221,178 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
               Column('other_id_2', String(10)),
               ForeignKeyConstraint(['other_id_1', 'other_id_2'],
                                    ['table.id_1', 'table.id_2'],
-                                   name='fk_test_name'))
+                                   name='fk_test_name'),
+              mysql_engine='InnoDB')
 
         Table('table', m2,
-              Column('id', Integer, primary_key=True),
-              Column('id_1', String(10)),
-              Column('id_2', String(10)))
+              Column('id_1', String(10), primary_key=True),
+              Column('id_2', String(10), primary_key=True),
+              mysql_engine='InnoDB')
 
         Table('user', m2,
               Column('id', Integer, primary_key=True),
               Column('name', String(50), nullable=False),
               Column('a1', String(10), server_default="x"),
               Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)))
+              Column('other_id_2', String(10)),
+              mysql_engine='InnoDB')
 
         diffs = self._fixture(m1, m2)
 
-        eq_(diffs[0][0], "drop_fk")
-        eq_(diffs[0][2].name, "user")
-        eq_(diffs[0][3].constrained_columns, ('other_id_1', 'other_id_2'))
-        eq_(diffs[0][3].referred_table, 'table')
-        eq_(diffs[0][3].referred_columns, ('id_1', 'id_2'))
\ No newline at end of file
+        self._assert_fk_diff(
+            diffs[0], "remove_fk",
+            "user", ['other_id_1', 'other_id_2'],
+            "table", ['id_1', 'id_2'],
+            conditional_name="fk_test_name"
+        )
+
+
+class IncludeHooksTest(AutogenFixtureTest, TestBase):
+    __backend__ = True
+    __requires__ = 'fk_names',
+
+    def test_remove_connection_fk(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        ref = Table(
+            'ref', m1, Column('id', Integer, primary_key=True),
+            mysql_engine='InnoDB')
+        t1 = Table(
+            't', m1, Column('x', Integer), Column('y', Integer),
+            mysql_engine='InnoDB')
+        t1.append_constraint(
+            ForeignKeyConstraint([t1.c.x], [ref.c.id], name="fk1")
+        )
+        t1.append_constraint(
+            ForeignKeyConstraint([t1.c.y], [ref.c.id], name="fk2")
+        )
+
+        ref = Table(
+            'ref', m2, Column('id', Integer, primary_key=True),
+            mysql_engine='InnoDB')
+        Table(
+            't', m2, Column('x', Integer), Column('y', Integer),
+            mysql_engine='InnoDB')
+
+        def include_object(object_, name, type_, reflected, compare_to):
+            return not (
+                isinstance(object_, ForeignKeyConstraint) and
+                type_ == 'foreignkey' and reflected and name == 'fk1')
+
+        diffs = self._fixture(m1, m2, object_filters=[include_object])
+
+        self._assert_fk_diff(
+            diffs[0], "remove_fk",
+            't', ['y'], 'ref', ['id'],
+            conditional_name='fk2'
+        )
+        eq_(len(diffs), 1)
+
+    def test_add_metadata_fk(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        Table(
+            'ref', m1,
+            Column('id', Integer, primary_key=True), mysql_engine='InnoDB')
+        Table(
+            't', m1,
+            Column('x', Integer), Column('y', Integer), mysql_engine='InnoDB')
+
+        ref = Table(
+            'ref', m2, Column('id', Integer, primary_key=True),
+            mysql_engine='InnoDB')
+        t2 = Table(
+            't', m2, Column('x', Integer), Column('y', Integer),
+            mysql_engine='InnoDB')
+        t2.append_constraint(
+            ForeignKeyConstraint([t2.c.x], [ref.c.id], name="fk1")
+        )
+        t2.append_constraint(
+            ForeignKeyConstraint([t2.c.y], [ref.c.id], name="fk2")
+        )
+
+        def include_object(object_, name, type_, reflected, compare_to):
+            return not (
+                isinstance(object_, ForeignKeyConstraint) and
+                type_ == 'foreignkey' and not reflected and name == 'fk1')
+
+        diffs = self._fixture(m1, m2, object_filters=[include_object])
+
+        self._assert_fk_diff(
+            diffs[0], "add_fk",
+            't', ['y'], 'ref', ['id'],
+            name='fk2'
+        )
+        eq_(len(diffs), 1)
+
+    def test_change_fk(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        r1a = Table(
+            'ref_a', m1,
+            Column('a', Integer, primary_key=True),
+            mysql_engine='InnoDB'
+        )
+        Table(
+            'ref_b', m1,
+            Column('a', Integer, primary_key=True),
+            Column('b', Integer, primary_key=True),
+            mysql_engine='InnoDB'
+        )
+        t1 = Table(
+            't', m1, Column('x', Integer),
+            Column('y', Integer), Column('z', Integer),
+            mysql_engine='InnoDB')
+        t1.append_constraint(
+            ForeignKeyConstraint([t1.c.x], [r1a.c.a], name="fk1")
+        )
+        t1.append_constraint(
+            ForeignKeyConstraint([t1.c.y], [r1a.c.a], name="fk2")
+        )
+
+        Table(
+            'ref_a', m2,
+            Column('a', Integer, primary_key=True),
+            mysql_engine='InnoDB'
+        )
+        r2b = Table(
+            'ref_b', m2,
+            Column('a', Integer, primary_key=True),
+            Column('b', Integer, primary_key=True),
+            mysql_engine='InnoDB'
+        )
+        t2 = Table(
+            't', m2, Column('x', Integer),
+            Column('y', Integer), Column('z', Integer),
+            mysql_engine='InnoDB')
+        t2.append_constraint(
+            ForeignKeyConstraint(
+                [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1")
+        )
+        t2.append_constraint(
+            ForeignKeyConstraint(
+                [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2")
+        )
+
+        def include_object(object_, name, type_, reflected, compare_to):
+            return not (
+                isinstance(object_, ForeignKeyConstraint) and
+                type_ == 'foreignkey' and name == 'fk1'
+            )
+
+        diffs = self._fixture(m1, m2, object_filters=[include_object])
+
+        self._assert_fk_diff(
+            diffs[0], "remove_fk",
+            't', ['y'], 'ref_a', ['a'],
+            name='fk2'
+        )
+        self._assert_fk_diff(
+            diffs[1], "add_fk",
+            't', ['y', 'z'], 'ref_b', ['a', 'b'],
+            name='fk2'
+        )
+        eq_(len(diffs), 2)
index 7351f7907f93ff0ba6834ec3f0f9ba0f287e4576..4fb3a532345f52eb51821920e04ced38916bd0bd 100644 (file)
@@ -14,6 +14,7 @@ from alembic.testing import config
 from alembic.testing.mock import Mock
 from alembic.testing.env import staging_env, clear_staging_env
 from alembic.testing import eq_
+from alembic.ddl.base import _fk_spec
 
 py3k = sys.version_info >= (3, )
 
@@ -37,7 +38,41 @@ def new_table(table, parent):
     names_in_this_test.add(table.name)
 
 
-class AutogenTest(object):
+class _ComparesFKs(object):
+    def _assert_fk_diff(
+            self, diff, type_, source_table, source_columns,
+            target_table, target_columns, name=None, conditional_name=None,
+            source_schema=None):
+        # the public API for ForeignKeyConstraint was not very rich
+        # in 0.7, 0.8, so here we use the well-known but slightly
+        # private API to get at its elements
+        (fk_source_schema, fk_source_table,
+         fk_source_columns, fk_target_schema, fk_target_table,
+         fk_target_columns) = _fk_spec(diff[1])
+
+        eq_(diff[0], type_)
+        eq_(fk_source_table, source_table)
+        eq_(fk_source_columns, source_columns)
+        eq_(fk_target_table, target_table)
+        eq_(fk_source_schema, source_schema)
+
+        eq_([elem.column.name for elem in diff[1].elements],
+            target_columns)
+        if conditional_name is not None:
+            if config.requirements.no_fk_names.enabled:
+                eq_(diff[1].name, None)
+            elif conditional_name == 'servergenerated':
+                fks = Inspector.from_engine(self.bind).\
+                    get_foreign_keys(source_table)
+                server_fk_name = fks[0]['name']
+                eq_(diff[1].name, server_fk_name)
+            else:
+                eq_(diff[1].name, conditional_name)
+        else:
+            eq_(diff[1].name, name)
+
+
+class AutogenTest(_ComparesFKs):
 
     @classmethod
     def _get_bind(cls):
@@ -88,7 +123,7 @@ class AutogenTest(object):
         self.conn.close()
 
 
-class AutogenFixtureTest(object):
+class AutogenFixtureTest(_ComparesFKs):
 
     def _fixture(
             self, m1, m2, include_schemas=False,
@@ -425,11 +460,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         eq_(repr(diffs[5][0][5]), "NUMERIC(precision=8, scale=2)")
         eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
 
-        eq_(diffs[6][0], 'add_fk')
-        eq_(diffs[6][1].column.name, 'id')
-        eq_(diffs[6][1].parent.table.name, 'order')
-        eq_(diffs[6][2].referred_table, 'user')
-        eq_(diffs[6][2].constrained_columns, ('user_id',))
+        self._assert_fk_diff(
+            diffs[6], "add_fk",
+            "order", ["user_id"],
+            "user", ["id"]
+        )
 
         eq_(diffs[7][0][0], "modify_default")
         eq_(diffs[7][0][1], None)
@@ -524,7 +559,7 @@ nullable=True))
                type_=sa.Numeric(precision=10, scale=2),
                nullable=True,
                existing_server_default=sa.text('0'))
-    op.create_foreign_key(None, order, user, ['user_id'], ['id'], schema=None)
+    op.create_foreign_key(None, order, user, ['user_id'], ['id'])
     op.alter_column('user', 'a1',
                existing_type=sa.TEXT(),
                server_default='x',
@@ -548,14 +583,14 @@ nullable=True))
                existing_type=sa.TEXT(),
                server_default=None,
                existing_nullable=True)
-    op.drop_constraint(None, order, type_='foreignkey')
+    op.drop_constraint(None, 'order', type_='foreignkey')
     op.alter_column('order', 'amount',
                existing_type=sa.Numeric(precision=10, scale=2),
                type_=sa.NUMERIC(precision=8, scale=2),
                nullable=False,
                existing_server_default=sa.text('0'))
     op.drop_column('order', 'user_id')
-    op.drop_constraint('uq_email', 'address')
+    op.drop_constraint('uq_email', 'address', type_='unique')
     op.drop_column('address', 'street')
     op.create_table('extra',
     sa.Column('x', sa.CHAR(), nullable=True),
@@ -595,7 +630,7 @@ nullable=True))
                type_=sa.Numeric(precision=10, scale=2),
                nullable=True,
                existing_server_default=sa.text('0'))
-        batch_op.create_foreign_key(None, order, user, ['user_id'], ['id'], schema=None)
+        batch_op.create_foreign_key(None, order, user, ['user_id'], ['id'])
 
     with op.batch_alter_table('user', schema=None) as batch_op:
         batch_op.alter_column('a1',
@@ -624,7 +659,7 @@ nullable=True))
                existing_nullable=True)
 
     with op.batch_alter_table('order', schema=None) as batch_op:
-        batch_op.drop_constraint(None, order, type_='foreignkey')
+        batch_op.drop_constraint(None, type_='foreignkey')
         batch_op.alter_column('amount',
                existing_type=sa.Numeric(precision=10, scale=2),
                type_=sa.NUMERIC(precision=8, scale=2),
@@ -633,7 +668,7 @@ nullable=True))
         batch_op.drop_column('user_id')
 
     with op.batch_alter_table('address', schema=None) as batch_op:
-        batch_op.drop_constraint('uq_email')
+        batch_op.drop_constraint('uq_email', type_='unique')
         batch_op.drop_column('street')
 
     op.create_table('extra',
@@ -815,12 +850,12 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
         eq_(repr(diffs[5][0][5]), "NUMERIC(precision=8, scale=2)")
         eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
 
-        eq_(diffs[6][0], 'add_fk')
-        eq_(diffs[6][1].column.name, 'id')
-        eq_(diffs[6][1].parent.table.name, 'order')
-        eq_(diffs[6][1].parent.table.schema, config.test_schema)
-        eq_(diffs[6][2].referred_table, 'user')
-        eq_(diffs[6][2].constrained_columns, ('user_id',))
+        self._assert_fk_diff(
+            diffs[6], "add_fk",
+            "order", ["user_id"],
+            "user", ["id"],
+            source_schema=config.test_schema
+        )
 
         eq_(diffs[7][0][0], "modify_default")
         eq_(diffs[7][0][1], self.schema)
@@ -900,7 +935,7 @@ schema='%(schema)s')
                existing_server_default=sa.text('0'),
                schema='%(schema)s')
     op.create_foreign_key(None, order, user, ['user_id'], ['id'], \
-schema='%(schema)s')
+source_schema='%(schema)s', referent_schema='%(schema)s')
     op.alter_column('user', 'a1',
                existing_type=sa.TEXT(),
                server_default='x',
@@ -928,7 +963,7 @@ autoincrement=False, nullable=True), schema='%(schema)s')
                server_default=None,
                existing_nullable=True,
                schema='%(schema)s')
-    op.drop_constraint(None, order, type_='foreignkey')
+    op.drop_constraint(None, 'order', schema='%(schema)s', type_='foreignkey')
     op.alter_column('order', 'amount',
                existing_type=sa.Numeric(precision=10, scale=2),
                type_=sa.NUMERIC(precision=8, scale=2),
@@ -936,7 +971,7 @@ autoincrement=False, nullable=True), schema='%(schema)s')
                existing_server_default=sa.text('0'),
                schema='%(schema)s')
     op.drop_column('order', 'user_id', schema='%(schema)s')
-    op.drop_constraint('uq_email', 'address', schema='test_schema')
+    op.drop_constraint('uq_email', 'address', schema='test_schema', type_='unique')
     op.drop_column('address', 'street', schema='%(schema)s')
     op.create_table('extra',
     sa.Column('x', sa.CHAR(length=1), autoincrement=False, nullable=True),
@@ -1226,11 +1261,11 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
         eq_(repr(diffs[5][0][5]), "NUMERIC(precision=8, scale=2)")
         eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
 
-        eq_(diffs[6][0], 'add_fk')
-        eq_(diffs[6][1].column.name, 'id')
-        eq_(diffs[6][1].parent.table.name, 'order')
-        eq_(diffs[6][2].referred_table, 'user')
-        eq_(diffs[6][2].constrained_columns, ('user_id',))
+        self._assert_fk_diff(
+            diffs[6], "add_fk",
+            "order", ["user_id"],
+            "user", ["id"]
+        )
 
         eq_(diffs[7][0][0], "modify_default")
         eq_(diffs[7][0][1], None)