]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added order_by='foreign_key' option to help with dependency checking
authorRandall Smith <randall@tnr.cc>
Thu, 5 Feb 2009 06:03:49 +0000 (06:03 +0000)
committerRandall Smith <randall@tnr.cc>
Thu, 5 Feb 2009 06:03:49 +0000 (06:03 +0000)
lib/sqlalchemy/engine/reflection.py
test/reflection.py

index 99e86499e5633ef771b28f3feda3ebb1cce78a61..746e2e94b1c602270c63b7270f3029b1837f5215 100644 (file)
@@ -53,7 +53,7 @@ class Inspector(object):
                                                         self.info_cache)
         return []
 
-    def get_table_names(self, schemaname=None):
+    def get_table_names(self, schemaname=None, order_by=None):
         """Return all table names in `schemaname`.
         schemaname:
           Optional, retrieve names from a non-default schema.
@@ -63,9 +63,28 @@ class Inspector(object):
 
         """
         if hasattr(self.engine.dialect, 'get_table_names'):
-            return self.engine.dialect.get_table_names(self.conn, schemaname,
+            tnames = self.engine.dialect.get_table_names(self.conn, schemaname,
                                                        self.info_cache)
-        return self.engine.table_names(schemaname)
+        else:
+            tnames = self.engine.table_names(schemaname)
+        if order_by == 'foreign_key':
+            ordered_tnames = tnames[:]
+            # Order based on foreign key dependencies.
+            for tname in tnames:
+                table_pos = tnames.index(tname)
+                fkeys = self.get_foreign_keys(tname, schemaname)
+                for fkey in fkeys:
+                    rtable = fkey['referred_table']
+                    if rtable in ordered_tnames:
+                        ref_pos = ordered_tnames.index(rtable)
+                        # Make sure it's lower in the list than anything it
+                        # references.
+                        if table_pos > ref_pos:
+                            ordered_tnames.pop(table_pos) # rtable moves up 1
+                            # insert just below rtable
+                            ordered_tnames.index(ref_pos, tname)
+            tnames = ordered_tnames
+        return tnames
 
     def get_view_names(self, schemaname=None):
         """Return all view names in `schemaname`.
@@ -122,7 +141,7 @@ class Inspector(object):
         """Return information about primary keys in `tablename`.
 
         Given a string `tablename`, and an optional string `schemaname`, return 
-        primary key information as a list of column names:
+        primary key information as a list of column names.
 
         """
 
index e3d911a4c92ab68f09b6d93a8fe9d3f57b072950..863dc526ba764ae16a6ea7dc9d9fbab90cfc03b0 100644 (file)
@@ -89,7 +89,8 @@ class ReflectionTest(TestBase):
         insp = Inspector(meta.bind)
         self.assert_(getSchema() in insp.get_schema_names())
 
-    def _test_get_table_names(self, schemaname=None, table_type='table'):
+    def _test_get_table_names(self, schemaname=None, table_type='table',
+                              order_by=None):
         meta = MetaData(testing.db)
         (users, addresses) = createTables(meta, schemaname)
         meta.create_all()
@@ -101,9 +102,13 @@ class ReflectionTest(TestBase):
                 table_names.sort()
                 answer = [u'email_addresses_v', u'users_v']
             else:
-                table_names = insp.get_table_names(schemaname)
+                table_names = insp.get_table_names(schemaname,
+                                                   order_by=order_by)
                 table_names.sort()
-                answer = [u'email_addresses', 'users']
+                if order_by == 'foreign_key':
+                    answer = [u'users', 'email_addresses']
+                else:
+                    answer = [u'email_addresses', 'users']
             self.assertEqual(table_names, answer)
         finally:
             dropViews(meta.bind, schemaname)
@@ -116,6 +121,9 @@ class ReflectionTest(TestBase):
     def test_get_table_names_with_schema(self):
         self._test_get_table_names(getSchema())
 
+    def test_get_table_names_order_by_fk(self):
+        self._test_get_table_names(order_by='fk')
+
     def test_get_view_names(self):
         self._test_get_table_names(table_type='view')