]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
reflection fully implemented for mysql
authorRandall Smith <randall@tnr.cc>
Thu, 12 Mar 2009 04:05:58 +0000 (04:05 +0000)
committerRandall Smith <randall@tnr.cc>
Thu, 12 Mar 2009 04:05:58 +0000 (04:05 +0000)
lib/sqlalchemy/dialects/mysql/base.py
test/reflection.py

index b3e52f9df6ddb080d18a3dd8cac5edecf7e5d6f2..627ed84a1d23eba5318f83bf7efb529cdd3f990e 100644 (file)
@@ -1835,6 +1835,38 @@ class MySQLDialect(default.DefaultDialect):
             self.preparer = MySQLIdentifierPreparer
         self.identifier_preparer = self.preparer(self)
 
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
+        rp = connection.execute("SHOW schemas")
+        return [r[0] for r in rp]
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        if schema is None:
+            schema = self.get_default_schema_name(connection)
+        if self.server_version_info < (5, 0, 2):
+            return self.table_names(connection, schema)
+        charset = self._connection_charset
+        rp = connection.execute("SHOW FULL TABLES FROM %s" %
+                self.identifier_preparer.quote_identifier(schema))
+        return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
+                                                    if row[1] == 'BASE TABLE']
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        charset = self._connection_charset
+        if self.server_version_info < (5, 0, 2):
+            raise NotImplementedError
+        if schema is None:
+            schema = self.get_default_schema_name(connection)
+        if self.server_version_info < (5, 0, 2):
+            return self.table_names(connection, schema)
+        charset = self._connection_charset
+        rp = connection.execute("SHOW FULL TABLES FROM %s" %
+                self.identifier_preparer.quote_identifier(schema))
+        return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
+                                                    if row[1] == 'VIEW']
+
     @reflection.cache
     def get_columns(self, connection, table_name, schema=None, **kw):
 
@@ -1881,7 +1913,7 @@ class MySQLDialect(default.DefaultDialect):
                     con_kw[opt] = spec[opt]
 
             fkey_d = {
-                'name' : None,
+                'name' : spec['name'],
                 'constrained_columns' : loc_names,
                 'referred_schema' : ref_schema,
                 'referred_table' : ref_name,
@@ -1892,7 +1924,7 @@ class MySQLDialect(default.DefaultDialect):
         return fkeys
 
     @reflection.cache
-    def get_indexes(self, connection, table_name, schema, **kw):
+    def get_indexes(self, connection, table_name, schema=None, **kw):
 
         parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
         
@@ -1918,11 +1950,21 @@ class MySQLDialect(default.DefaultDialect):
             indexes.append(index_d)
         return indexes
 
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+
+        charset = self._connection_charset
+        full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
+            schema, view_name))
+        sql = self._show_create_table(connection, None, charset,
+                                      full_name=full_name)
+        return sql
+
     def _parsed_state_or_create(self, connection, table_name, schema=None, **kw):
         if 'parsed_state' in kw:
             return kw['parsed_state']
         else:
-            return self._setup_parser(connection, table.name, schema)
+            return self._setup_parser(connection, table_name, schema)
         
     def _setup_parser(self, connection, table_name, schema=None):
 
index ee0d1ec6bfbd16a72a4d4320f6242d2d94b15e2b..45a17e1e22176cc33c63b697c587feb0469fb7cd 100644 (file)
@@ -128,9 +128,6 @@ class ReflectionTest(TestBase):
     def test_get_table_names_with_schema(self):
         self._test_get_table_names(getSchema())
 
-    def test_get_table_names_order_by_fk(self):
-        self._test_get_table_names(order_by='fk')
-
     def test_get_view_names(self):
         self._test_get_table_names(table_type='view')
 
@@ -263,6 +260,8 @@ class ReflectionTest(TestBase):
         meta.create_all()
         createIndexes(meta.bind, schema)
         try:
+            # The database may decide to create indexes for foreign keys, etc.
+            # so there may be more indexes than expected.
             insp = Inspector(meta.bind)
             indexes = insp.get_indexes('users', schema=schema)
             indexes.sort()
@@ -276,7 +275,13 @@ class ReflectionTest(TestBase):
                     {'unique': False,
                      'column_names': ['test1', 'test2'],
                      'name': 'users_t_idx'}]
-            self.assertEqual(indexes, expected_indexes)
+            index_names = [d['name'] for d in indexes]
+            for e_index in expected_indexes:
+                self.assertTrue(e_index['name'] in index_names)
+                index = indexes[index_names.index(e_index['name'])]
+                for key in e_index:
+                    self.assertEqual(e_index[key], index[key])
+
         finally:
             addresses.drop()
             users.drop()