]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Restored reflection for mysql VIEWs [ticket:748]
authorJason Kirtland <jek@discorporate.us>
Tue, 28 Aug 2007 23:44:00 +0000 (23:44 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 28 Aug 2007 23:44:00 +0000 (23:44 +0000)
- Fixed anonymous pk reflection for mysql 5.1
- Tested table and view reflection against the 'sakila' database from
  MySQL AB on 3.23 - 6.0. (with some schema adjustments, obviously)
  Maybe this will go into the SA test suite someday.
- Tweaked mysql server version tuplification, now also splitting on hyphens
- Light janitorial

CHANGES
lib/sqlalchemy/databases/mysql.py
test/dialect/mysql.py

diff --git a/CHANGES b/CHANGES
index cba42858953a0b861cc220d7d7b7fed0fa2a40a7..20945f98f682482ed3aa0e9975c32ff36bb3b8df 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -2,6 +2,11 @@
 CHANGES
 =======
 
+0.4.0beta5
+----------
+
+- MySQL views can be reflected again [ticket:748]
+
 0.4.0beta4
 ----------
 
index 298823329c9a1d6c96e23497d1d4248174c57180..af01cf84e3102a3dde8cc382e7bd5f61012583f5 100644 (file)
@@ -1543,7 +1543,8 @@ class MySQLDialect(default.DefaultDialect):
         """Convert a MySQL-python server_info string into a tuple."""
 
         version = []
-        for n in dbapi_con.get_server_info().split('.'):
+        r = re.compile('[.\-]')
+        for n in r.split(dbapi_con.get_server_info()):
             try:
                 version.append(int(n))
             except ValueError:
@@ -1567,14 +1568,6 @@ class MySQLDialect(default.DefaultDialect):
         """Load column definitions from the server."""
 
         charset = self._detect_charset(connection)
-        casing = self._detect_casing(connection, charset)
-        # is this really needed?
-        if casing == 1 and table.name != table.name.lower():
-            table.name = table.name.lower()
-            lc_alias = schema._get_table_key(table.name, table.schema)
-            table.metadata.tables[lc_alias] = table
-
-        sql = self._show_create_table(connection, table, charset)
 
         try:
             reflector = self.reflector
@@ -1582,7 +1575,31 @@ class MySQLDialect(default.DefaultDialect):
             self.reflector = reflector = \
                 MySQLSchemaReflector(self.identifier_preparer)
 
-        reflector.reflect(connection, table, sql, charset, only=include_columns)
+        sql = self._show_create_table(connection, table, charset) 
+        if sql.startswith('CREATE ALGORITHM'):
+            # Adapt views to something table-like.
+            columns = self._describe_table(connection, table, charset)
+            sql = reflector._describe_to_create(table, columns)
+
+        self._adjust_casing(connection, table, charset)
+
+        return reflector.reflect(connection, table, sql, charset,
+                                 only=include_columns)
+
+    def _adjust_casing(self, connection, table, charset=None):
+        """Adjust Table name to the server case sensitivity, if needed."""
+        
+        if charset is None:
+            charset = self._detect_charset(connection)
+            
+        casing = self._detect_casing(connection, charset)
+
+        # For winxx database hosts.  TODO: is this really needed?
+        if casing == 1 and table.name != table.name.lower():
+            table.name = table.name.lower()
+            lc_alias = schema._get_table_key(table.name, table.schema)
+            table.metadata.tables[lc_alias] = table
+        
 
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
@@ -1701,9 +1718,36 @@ class MySQLDialect(default.DefaultDialect):
 
         return sql
 
+    def _describe_table(self, connection, table, charset=None,
+                             full_name=None):
+        """Run DESCRIBE for a ``Table`` and return processed rows."""
+
+        if full_name is None:
+            full_name = self.identifier_preparer.format_table(table)
+        st = "DESCRIBE %s" % full_name
+
+        rp, rows = None, None
+        try:
+            try:
+                rp = connection.execute(st)
+            except exceptions.SQLError, e:
+                if e.orig.args[0] == 1146:
+                    raise exceptions.NoSuchTableError(full_name)
+                else:
+                    raise
+            rows = _compat_fetchall(rp, charset=charset)
+        finally:
+            if rp:
+                rp.close()
+        return rows
 
 class _MySQLPythonRowProxy(object):
-    """Return consistent column values for all versions of MySQL-python (esp. alphas) and Unicode settings."""
+    """Return consistent column values for all versions of MySQL-python.
+
+    Smooth over data type issues (esp. with alpha driver versions) and
+    normalize strings as Unicode regardless of user-configured driver
+    encoding settings.
+    """
 
     # Some MySQL-python versions can return some columns as
     # sets.Set(['value']) (seriously) but thankfully that doesn't
@@ -1732,13 +1776,10 @@ class _MySQLPythonRowProxy(object):
 
 class MySQLCompiler(compiler.DefaultCompiler):
     operators = compiler.DefaultCompiler.operators.copy()
-    operators.update(
-        {
-            sql_operators.concat_op: \
-              lambda x, y: "concat(%s, %s)" % (x, y),
-            sql_operators.mod: '%%'
-        }
-    )
+    operators.update({
+        sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
+        sql_operators.mod: '%%'
+    })
 
     def visit_cast(self, cast, **kwargs):
         if isinstance(cast.type, (sqltypes.Date, sqltypes.Time,
@@ -1890,7 +1931,7 @@ class MySQLSchemaReflector(object):
             elif not line:
                 pass
             else:
-                type_, spec = self.constraints(line)
+                type_, spec = self.parse_constraints(line)
                 if type_ is None:
                     warnings.warn(
                         RuntimeWarning("Unknown schema content: %s" %
@@ -1917,10 +1958,10 @@ class MySQLSchemaReflector(object):
 
         # Don't override by default.
         if table.name is None:
-            table.name = self.name(line)
+            table.name = self.parse_name(line)
 
     def _add_column(self, table, line, charset, only=None):
-        spec = self.column(line)
+        spec = self.parse_column(line)
         if not spec:
             warnings.warn(RuntimeWarning(
                 "Unknown column definition %s" % line))
@@ -2097,7 +2138,7 @@ class MySQLSchemaReflector(object):
           The final line of SHOW CREATE TABLE output.
         """
 
-        options = self.table_options(line)
+        options = self.parse_table_options(line)
         for nope in ('auto_increment', 'data_directory', 'index_directory'):
             options.pop(nope, None)
 
@@ -2184,8 +2225,8 @@ class MySQLSchemaReflector(object):
         # KEY_BLOCK_SIZE size | WITH PARSER name
         self._re_key = _re_compile(
             r'  '
-            r'(?:(?P<type>\S+) )?KEY +'
-            r'(?:%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?'
+            r'(?:(?P<type>\S+) )?KEY'
+            r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?'
             r'(?: +USING +(?P<using>\S+))?'
             r' +\((?P<columns>.+?)\)'
             r'(?: +KEY_BLOCK_SIZE +(?P<keyblock>\S+))?'
@@ -2260,7 +2301,7 @@ class MySQLSchemaReflector(object):
         self._pr_options.append(_pr_compile(regex))
 
 
-    def name(self, line):
+    def parse_name(self, line):
         """Extract the table name.
 
         line
@@ -2273,7 +2314,7 @@ class MySQLSchemaReflector(object):
             return None
         return cleanup(m.group('name'))
 
-    def column(self, line):
+    def parse_column(self, line):
         """Extract column details.
 
         Falls back to a 'minimal support' variant if full parse fails.
@@ -2294,7 +2335,7 @@ class MySQLSchemaReflector(object):
             return spec
         return None
 
-    def constraints(self, line):
+    def parse_constraints(self, line):
         """Parse a KEY or CONSTRAINT line.
 
         line
@@ -2330,7 +2371,7 @@ class MySQLSchemaReflector(object):
         # No match.
         return (None, line)
         
-    def table_options(self, line):
+    def parse_table_options(self, line):
         """Build a dictionary of all reflected table-level options.
 
         line
@@ -2356,6 +2397,51 @@ class MySQLSchemaReflector(object):
 
         return options
 
+    def _describe_to_create(self, table, columns):
+        """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
+
+        DESCRIBE is a much simpler reflection and is sufficient for
+        reflecting views for runtime use.  This method formats DDL
+        for columns only- keys are omitted.
+
+        `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
+        SHOW FULL COLUMNS FROM rows must be rearranged for use with
+        this function.
+        """
+
+        buffer = []
+        for row in columns:
+            (name, col_type, nullable, default, extra) = \
+                   [row[i] for i in (0, 1, 2, 4, 5)]
+
+            line = [' ']
+            line.append(self.preparer.quote_identifier(name))
+            line.append(col_type)
+            if not nullable:
+                line.append('NOT NULL')
+            if default:
+                if 'auto_increment' in default:
+                    pass
+                elif (col_type.startswith('timestamp') and
+                      default.startswith('C')):
+                    line.append('DEFAULT')
+                    line.append(default)
+                elif default == 'NULL':
+                    line.append('DEFAULT')
+                    line.append(default)
+                else:
+                    line.append('DEFAULT')
+                    line.append("'%s'" % default.replace("'", "''"))
+            if extra:
+                line.append(extra)
+
+            buffer.append(' '.join(line))
+
+        return ''.join([('CREATE TABLE %s (\n' %
+                         self.preparer.quote_identifier(table.name)),
+                        ',\n'.join(buffer),
+                        '\n) '])
+
     def _parse_keyexprs(self, identifiers):
         """Unpack '"col"(2),"col" ASC'-ish strings into components."""
 
index 7e0a5547458a379d5fa5aa12939c917d46d8ce35..1cf6b437c0098bbfaf1108032134a26ab6b8b67d 100644 (file)
@@ -587,19 +587,37 @@ class TypesTest(AssertMixin):
 
         columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
 
-        m = MetaData(testbase.db)
+        db = testbase.db
+        m = MetaData(db)
         t_table = Table('mysql_types', m, *columns)
-        m.drop_all()
-        m.create_all()
+        try:
+            m.create_all()
         
-        m2 = MetaData(testbase.db)
-        rt = Table('mysql_types', m2, autoload=True)
-
-        expected = [len(c) > 1 and c[1] or c[0] for c in specs]
-        for i, reflected in enumerate(rt.c):
-            assert isinstance(reflected.type, type(expected[i]))
+            m2 = MetaData(db)
+            rt = Table('mysql_types', m2, autoload=True)
+            try:
+                db.execute('CREATE OR REPLACE VIEW mysql_types_v '
+                           'AS SELECT * from mysql_types')
+                rv = Table('mysql_types_v', m2, autoload=True)
+        
+                expected = [len(c) > 1 and c[1] or c[0] for c in specs]
+
+                # Early 5.0 releases seem to report more "general" for columns
+                # in a view, e.g. char -> varchar, tinyblob -> mediumblob
+                #
+                # Not sure exactly which point version has the fix.
+                if db.dialect.server_version_info(db.connect()) < (5, 0, 11):
+                    tables = rt,
+                else:
+                    tables = rt, rv
 
-        m.drop_all()
+                for table in tables:
+                    for i, reflected in enumerate(table.c):
+                        assert isinstance(reflected.type, type(expected[i]))
+            finally:
+                db.execute('DROP VIEW mysql_types_v')
+        finally:
+            m.drop_all()
 
     @testing.supported('mysql')
     def test_autoincrement(self):