]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed bug where using the ``column_reflect`` event to change the ``.key``
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Aug 2013 21:25:44 +0000 (17:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Aug 2013 21:26:11 +0000 (17:26 -0400)
of the incoming :class:`.Column` would prevent primary key constraints,
indexes, and foreign key constraints from being correctly reflected.
Also in 0.8.3. [ticket:2811]

Conflicts:
doc/build/changelog/changelog_09.rst

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/engine/reflection.py
test/engine/test_reflection.py

index 91aa2f7e9658a5b3a4da57fa8295f7179279e067..34fa58e0cd888609be19c275992f351f97126555 100644 (file)
@@ -6,6 +6,14 @@
 .. changelog::
     :version: 0.8.3
 
+    .. change::
+        :tags: bug, sql
+        :tickets: 2811
+
+        Fixed bug where using the ``column_reflect`` event to change the ``.key``
+        of the incoming :class:`.Column` would prevent primary key constraints,
+        indexes, and foreign key constraints from being correctly reflected.
+
     .. change::
         :tags: feature
 
index 90f21db0948b5b5ffff40d57d618bea4e81f1b03..c1c546d843932d7cd9a530aaaff8076fd7c4a9e7 100644 (file)
@@ -368,7 +368,8 @@ class Inspector(object):
 
         # table attributes we might need.
         reflection_options = dict(
-            (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
+            (k, table.kwargs.get(k))
+            for k in dialect.reflection_options if k in table.kwargs)
 
         schema = table.schema
         table_name = table.name
@@ -394,8 +395,12 @@ class Inspector(object):
 
         # columns
         found_table = False
+        cols_by_orig_name = {}
+
         for col_d in self.get_columns(table_name, schema, **tblkw):
             found_table = True
+            orig_name = col_d['name']
+
             table.dispatch.column_reflect(self, table, col_d)
 
             name = col_d['name']
@@ -433,7 +438,9 @@ class Inspector(object):
                     sequence.increment = seq['increment']
                 colargs.append(sequence)
 
-            col = sa_schema.Column(name, coltype, *colargs, **col_kw)
+            cols_by_orig_name[orig_name] = col = \
+                        sa_schema.Column(name, coltype, *colargs, **col_kw)
+
             table.append_column(col)
 
         if not found_table:
@@ -443,9 +450,9 @@ class Inspector(object):
         pk_cons = self.get_pk_constraint(table_name, schema, **tblkw)
         if pk_cons:
             pk_cols = [
-                table.c[pk]
+                cols_by_orig_name[pk]
                 for pk in pk_cons['constrained_columns']
-                if pk in table.c and pk not in exclude_columns
+                if pk in cols_by_orig_name and pk not in exclude_columns
             ]
             pk_cols += [
                 pk
@@ -463,7 +470,11 @@ class Inspector(object):
         fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
         for fkey_d in fkeys:
             conname = fkey_d['name']
-            constrained_columns = fkey_d['constrained_columns']
+            constrained_columns = [
+                                    cols_by_orig_name[c].key
+                                    if c in cols_by_orig_name else c
+                                    for c in fkey_d['constrained_columns']
+                                ]
             if exclude_columns and set(constrained_columns).intersection(
                                 exclude_columns):
                 continue
@@ -503,5 +514,5 @@ class Inspector(object):
                     "Omitting %s KEY for (%s), key covers omitted columns." %
                     (flavor, ', '.join(columns)))
                 continue
-            sa_schema.Index(name, *[table.columns[c] for c in columns],
+            sa_schema.Index(name, *[cols_by_orig_name[c] for c in columns],
                          **dict(unique=unique))
index c490efff2cb77d99ba119c3a778759f05bbb1099..5aa1f7a3dabcfe2f3236cbd39cc3ab2dd8a6e9d4 100644 (file)
@@ -1426,6 +1426,12 @@ class ColumnEventsTest(fixtures.TestBase):
             cls.metadata,
             Column('x', sa.Integer, primary_key=True),
         )
+        cls.related = Table(
+            'related',
+            cls.metadata,
+            Column('q', sa.Integer, sa.ForeignKey('to_reflect.x'))
+        )
+        sa.Index("some_index", cls.to_reflect.c.x)
         cls.metadata.create_all(testing.db)
 
     @classmethod
@@ -1435,7 +1441,7 @@ class ColumnEventsTest(fixtures.TestBase):
     def teardown(self):
         events.SchemaEventTarget.dispatch._clear()
 
-    def _do_test(self, col, update, assert_):
+    def _do_test(self, col, update, assert_, tablename="to_reflect"):
         # load the actual Table class, not the test
         # wrapper
         from sqlalchemy.schema import Table
@@ -1445,22 +1451,46 @@ class ColumnEventsTest(fixtures.TestBase):
             if column_info['name'] == col:
                 column_info.update(update)
 
-        t = Table('to_reflect', m, autoload=True, listeners=[
+        t = Table(tablename, m, autoload=True, listeners=[
             ('column_reflect', column_reflect),
         ])
         assert_(t)
 
         m = MetaData(testing.db)
         event.listen(Table, 'column_reflect', column_reflect)
-        t2 = Table('to_reflect', m, autoload=True)
+        t2 = Table(tablename, m, autoload=True)
         assert_(t2)
 
     def test_override_key(self):
+        def assertions(table):
+            eq_(table.c.YXZ.name, "x")
+            eq_(set(table.primary_key), set([table.c.YXZ]))
+            idx = list(table.indexes)[0]
+            eq_(idx.columns, [table.c.YXZ])
+
         self._do_test(
             "x", {"key": "YXZ"},
-            lambda table: eq_(table.c.YXZ.name, "x")
+            assertions
         )
 
+    def test_override_key_fk(self):
+        m = MetaData(testing.db)
+        def column_reflect(insp, table, column_info):
+
+            if column_info['name'] == 'q':
+                column_info['key'] = 'qyz'
+            elif column_info['name'] == 'x':
+                column_info['key'] = 'xyz'
+
+        to_reflect = Table("to_reflect", m, autoload=True, listeners=[
+            ('column_reflect', column_reflect),
+        ])
+        related = Table("related", m, autoload=True, listeners=[
+            ('column_reflect', column_reflect),
+            ])
+
+        assert related.c.qyz.references(to_reflect.c.xyz)
+
     def test_override_type(self):
         def assert_(table):
             assert isinstance(table.c.x.type, sa.String)