]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Patch up MySQL reflection issues with old server versions, alpha drivers,
authorJason Kirtland <jek@discorporate.us>
Thu, 12 Jul 2007 01:11:38 +0000 (01:11 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 12 Jul 2007 01:11:38 +0000 (01:11 +0000)
  quoting, and connection encoding.

lib/sqlalchemy/databases/mysql.py

index 769e69f962e72bffaaa370e0f15658fd3f9172c5..2fd6e48d0d1090ea09fb16fed48589a23d00a946 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import sys, StringIO, string, types, re, datetime, inspect
+import sys, StringIO, string, types, re, datetime, inspect, warnings
 
 from sqlalchemy import sql,engine,schema,ansisql
 from sqlalchemy.engine import default
@@ -1044,25 +1044,11 @@ class MySQLDialect(ansisql.ANSIDialect):
         return self._default_schema_name
 
     def has_table(self, connection, table_name, schema=None):
-        # TODO: this does not work for table names that contain multibyte characters.
-
-        # http://dev.mysql.com/doc/refman/5.0/en/error-messages-server.html
-
-        # Error: 1146 SQLSTATE: 42S02 (ER_NO_SUCH_TABLE)
-        # Message: Table '%s.%s' doesn't exist
-
-        # Error: 1046 SQLSTATE: 3D000 (ER_NO_DB_ERROR)
-        # Message: No database selected
-
-        try:
-            name = schema and ("%s.%s" % (schema, table_name)) or table_name
-            connection.execute("DESCRIBE `%s`" % name)
-            return True
-        except exceptions.SQLError, e:
-            if e.orig.args[0] in (1146, 1046): 
-                return False
-            else:
-                raise
+        if schema is not None:
+            st = 'SHOW TABLE STATUS FROM `%s` LIKE %%s' % schema
+        else:
+            st = 'SHOW TABLE STATUS LIKE %s'
+        return connection.execute(st, table_name).rowcount != 0
 
     def get_version_info(self, connectable):
         if hasattr(connectable, 'connect'):
@@ -1078,34 +1064,33 @@ class MySQLDialect(ansisql.ANSIDialect):
         return tuple(version)
 
     def reflecttable(self, connection, table):
-        # reference:  http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
-        cs = connection.execute("show variables like 'lower_case_table_names'").fetchone()[1]
-        if isinstance(cs, array):
-            cs = cs.tostring()
-        case_sensitive = int(cs) == 0
+        """Load column definitions from the server."""
+
+        decode_from = self._detect_charset(connection)
 
-        decode_from = connection.execute("show variables like 'character_set_results'").fetchone()[1]
+        # reference:
+        # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
+        row = _compat_fetch(connection.execute(
+            "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+                            one=True, charset=decode_from)
+        if not row:
+            case_sensitive = True
+        else:
+            case_sensitive = row[1] in ('0', 'OFF' 'off')
 
         if not case_sensitive:
             table.name = table.name.lower()
             table.metadata.tables[table.name]= table
+
         try:
-            c = connection.execute("describe " + table.fullname, {})
+            rp = connection.execute("describe " + self._escape_table_name(table),
+                                   {})
         except:
-            raise exceptions.NoSuchTableError(table.name)
-        found_table = False
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            #print "row! " + repr(row)
-            if not found_table:
-                found_table = True
-
-            # these can come back as unicode if use_unicode=1 in the mysql connection
-            (name, type, nullable, primary_key, default) = (row[0], str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4])
-            if not isinstance(name, unicode):
-                name = name.decode(decode_from)
+            raise exceptions.NoSuchTableError(table.fullname)
+
+        for row in _compat_fetch(rp, charset=decode_from):
+            (name, type, nullable, primary_key, default) = \
+                   (row[0], str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4])
 
             match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
             col_type = match.group(1)
@@ -1113,7 +1098,6 @@ class MySQLDialect(ansisql.ANSIDialect):
             extra_1 = match.group(3)
             extra_2 = match.group(4)
 
-            #print "coltype: " + repr(col_type) + " args: " + repr(args) + "extras:" + repr(extra_1) + ' ' + repr(extra_2)
             coltype = ischema_names.get(col_type, MSString)
 
             kw = {}
@@ -1136,30 +1120,25 @@ class MySQLDialect(ansisql.ANSIDialect):
                 if col_type == 'timestamp' and default == 'CURRENT_TIMESTAMP':
                     arg = sql.text(default)
                 else:
-                    arg = default
+                    # leave defaults in the connection charset
+                    arg = default.encode(decode_from)
                 colargs.append(schema.PassiveDefault(arg))
             table.append_column(schema.Column(name, coltype, *colargs,
                                             **dict(primary_key=primary_key,
                                                    nullable=nullable,
                                                    )))
 
-        tabletype = self.moretableinfo(connection, table=table)
+        tabletype = self.moretableinfo(connection, table, decode_from)
         table.kwargs['mysql_engine'] = tabletype
 
-        if not found_table:
-            raise exceptions.NoSuchTableError(table.name)
-
-    def moretableinfo(self, connection, table):
-        """runs SHOW CREATE TABLE to get foreign key/options information about the table.
-        
-        """
-        c = connection.execute("SHOW CREATE TABLE " + table.fullname, {})
-        desc_fetched = c.fetchone()[1]
+    def moretableinfo(self, connection, table, charset=None):
+        """SHOW CREATE TABLE to get foreign key/table options."""
 
-        if not isinstance(desc_fetched, basestring):
-            # may get array.array object here, depending on version (such as mysql 4.1.14 vs. 4.1.11)
-            desc_fetched = desc_fetched.tostring()
-        desc = desc_fetched.strip()
+        rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {})
+        row = _compat_fetch(rp, one=True, charset=charset)
+        if not row:
+            raise exceptions.NoSuchTableError(table.fullname)
+        desc = row[1].strip()
 
         tabletype = ''
         lastparen = re.search(r'\)[^\)]*\Z', desc)
@@ -1179,10 +1158,70 @@ class MySQLDialect(ansisql.ANSIDialect):
 
         return tabletype
 
+    def _escape_table_name(self, table):
+        if table.schema is not None:
+            return '`%s`.`%s`' % (table.schema. table.name)
+        else:
+            return '`%s`' % table.name
+
+    def _detect_charset(self, connection):
+        """Sniff out the character set in use for connection results."""
+
+        # Note: MySQL-python 1.2.1c7 seems to ignore changes made
+        # on a connection via set_character_set()
+        
+        rs = connection.execute("show variables like 'character_set%%'")
+        opts = dict([(row[0], row[1]) for row in _compat_fetch(rs)])
+
+        if 'character_set_results' in opts:
+            return opts['character_set_results']
+        try:
+            return connection.connection.character_set_name()
+        except AttributeError:
+            # < 1.2.1 final MySQL-python drivers have no charset support
+            if 'character_set' in opts:
+                return opts['character_set']
+            else:
+                warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python.  MySQL-python >= 1.2.2 is recommended.  Assuming latin1."))
+                return 'latin1'
+
+def _compat_fetch(rp, one=False, charset=None):
+    """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
+
+    if one:
+        return _MySQLPythonRowProxy(rp.fetchone(), charset)
+    else:
+        return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
+        
+
+class _MySQLPythonRowProxy(object):
+    """Return consistent column values for all versions of MySQL-python (esp. alphas) and unicode settings."""
+
+    def __init__(self, rowproxy, charset):
+        self.rowproxy = rowproxy
+        self.charset = charset
+    def __getitem__(self, index):
+        item = self.rowproxy[index]
+        if isinstance(item, array):
+            item = item.tostring()
+        if self.charset and isinstance(item, str):
+            return item.decode(self.charset)
+        else:
+            return item
+    def __getattr__(self, attr):
+        item = getattr(self.rowproxy, attr)
+        if isinstance(item, array):
+            item = item.tostring()
+        if self.charset and isinstance(item, str):
+            return item.decode(self.charset)
+        else:
+            return item
+
+
 class MySQLCompiler(ansisql.ANSICompiler):
     def visit_cast(self, cast):
-        """hey ho MySQL supports almost no types at all for CAST"""
-        if (isinstance(cast.type, sqltypes.Date) or isinstance(cast.type, sqltypes.Time) or isinstance(cast.type, sqltypes.DateTime)):
+
+        if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
             return super(MySQLCompiler, self).visit_cast(cast)
         else:
             # so just skip the CAST altogether for now.