From: Mike Bayer Date: Sat, 4 Aug 2012 11:18:01 +0000 (-0400) Subject: - [bug] Repaired create_foreign_key() for X-Git-Tag: rel_0_3_6~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7ee8f96b6f531ebdfb2d588e13415efc9d59a4ff;p=thirdparty%2Fsqlalchemy%2Falembic.git - [bug] Repaired create_foreign_key() for self-referential foreign keys, which weren't working at all. --- diff --git a/CHANGES b/CHANGES index ef50e63d..4fe2e7dd 100644 --- a/CHANGES +++ b/CHANGES @@ -11,6 +11,10 @@ config option is being used but SQL access isn't desired. +- [bug] Repaired create_foreign_key() for + self-referential foreign keys, which weren't working + at all. + - [bug] 'alembic' command reports an informative error message when the configuration is missing the 'script_directory' key. #63 diff --git a/alembic/operations.py b/alembic/operations.py index d044ccdc..d1f7835a 100644 --- a/alembic/operations.py +++ b/alembic/operations.py @@ -54,11 +54,16 @@ class Operations(object): local_cols, remote_cols, onupdate=None, ondelete=None): m = schema.MetaData() - t1 = schema.Table(source, m, - *[schema.Column(n, NULLTYPE) for n in local_cols]) - t2 = schema.Table(referent, m, + if source == referent: + t1_cols = local_cols + remote_cols + else: + t1_cols = local_cols + schema.Table(referent, m, *[schema.Column(n, NULLTYPE) for n in remote_cols]) + t1 = schema.Table(source, m, + *[schema.Column(n, NULLTYPE) for n in t1_cols]) + f = schema.ForeignKeyConstraint(local_cols, ["%s.%s" % (referent, n) for n in remote_cols], diff --git a/tests/test_op.py b/tests/test_op.py index b7abaef9..b5dbc26d 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -196,6 +196,14 @@ def test_add_foreign_key_ondelete(): "REFERENCES t2 (bat, hoho) ON DELETE CASCADE" ) +def test_add_foreign_key_self_referential(): + context = op_fixture() + op.create_foreign_key("fk_test", "t1", "t1", ["foo"], ["bar"]) + context.assert_( + "ALTER TABLE t1 ADD CONSTRAINT fk_test " + "FOREIGN KEY(foo) REFERENCES t1 (bar)" + ) + def test_add_check_constraint(): context = op_fixture() op.create_check_constraint(