]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Big MySQL dialect update, mostly efficiency and style.
authorJason Kirtland <jek@discorporate.us>
Sun, 29 Jul 2007 16:13:23 +0000 (16:13 +0000)
committerJason Kirtland <jek@discorporate.us>
Sun, 29 Jul 2007 16:13:23 +0000 (16:13 +0000)
Added TINYINT [ticket:691]- whoa, how did that one go missing for so long?
Added a charset-fixing pool listener. The driver-level option doesn't help everyone with this one.
New reflector code not quite done and omiited from this commit.

lib/sqlalchemy/databases/mysql.py
test/dialect/mysql.py

index 53ef1a95b2313d4c76de444674c4f8683b39e230..2c54c251269ffaf612711bf0122f9758c12a3105 100644 (file)
@@ -5,25 +5,21 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import re, datetime, inspect, warnings, weakref, operator
+from array import array as _array
+from decimal import Decimal
 
 from sqlalchemy import sql, schema, ansisql
 from sqlalchemy.engine import default
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 import sqlalchemy.util as util
-from array import array as _array
-from decimal import Decimal
 
-try:
-    from threading import Lock
-except ImportError:
-    from dummy_threading import Lock
 
 RESERVED_WORDS = util.Set(
     ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
      'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
      'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check',
-     'collate', 'column', 'condition', 'constraint', 'continue', 'convert',
+     'collate', 'column', 'con dition', 'constraint', 'continue', 'convert',
      'create', 'cross', 'current_date', 'current_time', 'current_timestamp',
      'current_user', 'cursor', 'database', 'databases', 'day_hour',
      'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal',
@@ -60,7 +56,7 @@ RESERVED_WORDS = util.Set(
      'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
      'read_only', 'read_write', # 5.1
      ])
-_per_connection_mutex = Lock()
+
 
 class _NumericType(object):
     "Base for MySQL numeric types."
@@ -78,6 +74,7 @@ class _NumericType(object):
             spec += ' ZEROFILL'
         return spec
 
+
 class _StringType(object):
     "Base for MySQL string types."
 
@@ -133,10 +130,11 @@ class _StringType(object):
         return "%s(%s)" % (self.__class__.__name__,
                            ','.join(['%s=%s' % (k, params[k]) for k in params]))
 
+
 class MSNumeric(sqltypes.Numeric, _NumericType):
     """MySQL NUMERIC type"""
     
-    def __init__(self, precision = 10, length = 2, asdecimal=True, **kw):
+    def __init__(self, precision=10, length=2, asdecimal=True, **kw):
         """Construct a NUMERIC.
 
         precision
@@ -173,6 +171,7 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
         else:
             return value
 
+
 class MSDecimal(MSNumeric):
     """MySQL DECIMAL type"""
 
@@ -205,6 +204,7 @@ class MSDecimal(MSNumeric):
         else:
             return self._extend("DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
 
+
 class MSDouble(MSNumeric):
     """MySQL DOUBLE type"""
 
@@ -240,6 +240,7 @@ class MSDouble(MSNumeric):
         else:
             return self._extend('DOUBLE')
 
+
 class MSFloat(sqltypes.Float, _NumericType):
     """MySQL FLOAT type"""
 
@@ -307,6 +308,7 @@ class MSInteger(sqltypes.Integer, _NumericType):
         else:
             return self._extend("INTEGER")
 
+
 class MSBigInteger(MSInteger):
     """MySQL BIGINTEGER type"""
 
@@ -333,6 +335,34 @@ class MSBigInteger(MSInteger):
         else:
             return self._extend("BIGINT")
 
+
+class MSTinyInteger(MSInteger):
+    """MySQL TINYINT type"""
+
+    def __init__(self, length=None, **kw):
+        """Construct a SMALLINTEGER.
+
+        length
+          Optional, maximum display width for this number.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
+        super(MSTinyInteger, self).__init__(length, **kw)
+
+    def get_col_spec(self):
+        if self.length is not None:
+            return self._extend("TINYINT(%s)" % self.length)
+        else:
+            return self._extend("TINYINT")
+
+
 class MSSmallInteger(sqltypes.Smallinteger, _NumericType):
     """MySQL SMALLINTEGER type"""
 
@@ -361,18 +391,21 @@ class MSSmallInteger(sqltypes.Smallinteger, _NumericType):
         else:
             return self._extend("SMALLINT")
 
+
 class MSDateTime(sqltypes.DateTime):
     """MySQL DATETIME type"""
 
     def get_col_spec(self):
         return "DATETIME"
 
+
 class MSDate(sqltypes.Date):
     """MySQL DATE type"""
 
     def get_col_spec(self):
         return "DATE"
 
+
 class MSTime(sqltypes.Time):
     """MySQL TIME type"""
 
@@ -386,6 +419,7 @@ class MSTime(sqltypes.Time):
         else:
             return None
 
+
 class MSTimeStamp(sqltypes.TIMESTAMP):
     """MySQL TIMESTAMP type
 
@@ -399,6 +433,7 @@ class MSTimeStamp(sqltypes.TIMESTAMP):
     def get_col_spec(self):
         return "TIMESTAMP"
 
+
 class MSYear(sqltypes.String):
     """MySQL YEAR type, for single byte storage of years 1901-2155"""
 
@@ -408,6 +443,7 @@ class MSYear(sqltypes.String):
         else:
             return "YEAR(%d)" % self.length
 
+
 class MSText(_StringType, sqltypes.TEXT):
     """MySQL TEXT type, for text up to 2^16 characters""" 
     
@@ -495,6 +531,7 @@ class MSTinyText(MSText):
     def get_col_spec(self):
         return self._extend("TINYTEXT")
 
+
 class MSMediumText(MSText):
     """MySQL MEDIUMTEXT type, for text up to 2^24 characters""" 
 
@@ -533,6 +570,7 @@ class MSMediumText(MSText):
     def get_col_spec(self):
         return self._extend("MEDIUMTEXT")
 
+
 class MSLongText(MSText):
     """MySQL LONGTEXT type, for text up to 2^32 characters""" 
 
@@ -571,6 +609,7 @@ class MSLongText(MSText):
     def get_col_spec(self):
         return self._extend("LONGTEXT")
 
+
 class MSString(_StringType, sqltypes.String):
     """MySQL VARCHAR type, for variable-length character data."""
 
@@ -617,6 +656,7 @@ class MSString(_StringType, sqltypes.String):
         else:
             return self._extend("TEXT")
 
+
 class MSChar(_StringType, sqltypes.CHAR):
     """MySQL CHAR type, for fixed-length character data."""
     
@@ -642,6 +682,7 @@ class MSChar(_StringType, sqltypes.CHAR):
     def get_col_spec(self):
         return self._extend("CHAR(%(length)s)" % {'length' : self.length})
 
+
 class MSNVarChar(_StringType, sqltypes.String):
     """MySQL NVARCHAR type, for variable-length character data in the
     server's configured national character set.
@@ -673,6 +714,7 @@ class MSNVarChar(_StringType, sqltypes.String):
         # of "NVARCHAR".
         return self._extend("VARCHAR(%(length)s)" % {'length': self.length})
     
+
 class MSNChar(_StringType, sqltypes.CHAR):
     """MySQL NCHAR type, for fixed-length character data in the
     server's configured national character set.
@@ -701,6 +743,7 @@ class MSNChar(_StringType, sqltypes.CHAR):
         # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
         return self._extend("CHAR(%(length)s)" % {'length': self.length})
 
+
 class _BinaryType(sqltypes.Binary):
     """MySQL binary types"""
 
@@ -716,6 +759,7 @@ class _BinaryType(sqltypes.Binary):
         else:
             return buffer(value)
 
+
 class MSVarBinary(_BinaryType):
     """MySQL VARBINARY type, for variable length binary data"""
 
@@ -733,6 +777,7 @@ class MSVarBinary(_BinaryType):
         else:
             return "BLOB"
 
+
 class MSBinary(_BinaryType):
     """MySQL BINARY type, for fixed length binary data"""
 
@@ -760,10 +805,10 @@ class MSBinary(_BinaryType):
         else:
             return buffer(value)
 
+
 class MSBlob(_BinaryType):
     """MySQL BLOB type, for binary data up to 2^16 bytes""" 
 
-
     def __init__(self, length=None, **kw):
         """Construct a BLOB.  Arguments are:
 
@@ -790,24 +835,28 @@ class MSBlob(_BinaryType):
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
+
 class MSTinyBlob(MSBlob):
     """MySQL TINYBLOB type, for binary data up to 2^8 bytes""" 
 
     def get_col_spec(self):
         return "TINYBLOB"
 
+
 class MSMediumBlob(MSBlob): 
     """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes"""
 
     def get_col_spec(self):
         return "MEDIUMBLOB"
 
+
 class MSLongBlob(MSBlob):
     """MySQL LONGBLOB type, for binary data up to 2^32 bytes"""
 
     def get_col_spec(self):
         return "LONGBLOB"
 
+
 class MSEnum(MSString):
     """MySQL ENUM type."""
     
@@ -877,6 +926,7 @@ class MSEnum(MSString):
     def get_col_spec(self):
         return self._extend("ENUM(%s)" % ",".join(self.__ddl_values))
 
+
 class MSBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOL"
@@ -899,17 +949,17 @@ class MSBoolean(sqltypes.Boolean):
 # TODO: SET, BIT
 
 colspecs = {
-    sqltypes.Integer : MSInteger,
-    sqltypes.Smallinteger : MSSmallInteger,
-    sqltypes.Numeric : MSNumeric,
-    sqltypes.Float : MSFloat,
-    sqltypes.DateTime : MSDateTime,
-    sqltypes.Date : MSDate,
-    sqltypes.Time : MSTime,
-    sqltypes.String : MSString,
-    sqltypes.Binary : MSBlob,
-    sqltypes.Boolean : MSBoolean,
-    sqltypes.TEXT : MSText,
+    sqltypes.Integer: MSInteger,
+    sqltypes.Smallinteger: MSSmallInteger,
+    sqltypes.Numeric: MSNumeric,
+    sqltypes.Float: MSFloat,
+    sqltypes.DateTime: MSDateTime,
+    sqltypes.Date: MSDate,
+    sqltypes.Time: MSTime,
+    sqltypes.String: MSString,
+    sqltypes.Binary: MSBlob,
+    sqltypes.Boolean: MSBoolean,
+    sqltypes.TEXT: MSText,
     sqltypes.CHAR: MSChar,
     sqltypes.NCHAR: MSNChar,
     sqltypes.TIMESTAMP: MSTimeStamp,
@@ -919,37 +969,37 @@ colspecs = {
 
 
 ischema_names = {
-    'bigint' : MSBigInteger,
-    'binary' : MSBinary,
-    'blob' : MSBlob,
+    'bigint': MSBigInteger,
+    'binary': MSBinary,
+    'blob': MSBlob,
     'boolean':MSBoolean,
-    'char' : MSChar,
-    'date' : MSDate,
-    'datetime' : MSDateTime,
-    'decimal' : MSDecimal,
-    'double' : MSDouble,
+    'char': MSChar,
+    'date': MSDate,
+    'datetime': MSDateTime,
+    'decimal': MSDecimal,
+    'double': MSDouble,
     'enum': MSEnum,
     'fixed': MSDecimal,
-    'float' : MSFloat,
-    'int' : MSInteger,
-    'integer' : MSInteger,
+    'float': MSFloat,
+    'int': MSInteger,
+    'integer': MSInteger,
     'longblob': MSLongBlob,
     'longtext': MSLongText,
     'mediumblob': MSMediumBlob,
-    'mediumint' : MSInteger,
+    'mediumint': MSInteger,
     'mediumtext': MSMediumText,
     'nchar': MSNChar,
     'nvarchar': MSNVarChar,
-    'numeric' : MSNumeric,
-    'smallint' : MSSmallInteger,
-    'text' : MSText,
-    'time' : MSTime,
-    'timestamp' : MSTimeStamp,
+    'numeric': MSNumeric,
+    'smallint': MSSmallInteger,
+    'text': MSText,
+    'time': MSTime,
+    'timestamp': MSTimeStamp,
     'tinyblob': MSTinyBlob,
-    'tinyint' : MSSmallInteger,
-    'tinytext' : MSTinyText,
-    'varbinary' : MSVarBinary,
-    'varchar' : MSString,
+    'tinyint': MSTinyInteger,
+    'tinytext': MSTinyText,
+    'varbinary': MSVarBinary,
+    'varchar': MSString,
 }
 
 def descriptor():
@@ -962,19 +1012,26 @@ def descriptor():
         ('host',"Hostname", None),
     ]}
 
+
 class MySQLExecutionContext(default.DefaultExecutionContext):
+    _my_is_select = re.compile(r'\s*(?:SELECT|SHOW|DESCRIBE|XA RECOVER)',
+                               re.I | re.UNICODE)
+
     def post_exec(self):
         if self.compiled.isinsert:
-            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
-                self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
+            if (not len(self._last_inserted_ids) or
+                self._last_inserted_ids[0] is None):
+                self._last_inserted_ids = ([self.cursor.lastrowid] +
+                                           self._last_inserted_ids[1:])
             
     def is_select(self):
-        return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None
+        return self._my_is_select.match(self.statement) is not None
+
 
 class MySQLDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
-        ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
-        self.per_connection = weakref.WeakKeyDictionary()
+        ansisql.ANSIDialect.__init__(self, default_paramstyle='format',
+                                     **kwargs)
 
     def dbapi(cls):
         import MySQLdb as mysql
@@ -989,12 +1046,15 @@ class MySQLDialect(ansisql.ANSIDialect):
         util.coerce_kw_type(opts, 'connect_timeout', int)
         util.coerce_kw_type(opts, 'client_flag', int)
         util.coerce_kw_type(opts, 'local_infile', int)
-        # note: these two could break SA Unicode type
+        # Note: using either of the below will cause all strings to be returned
+        # as Unicode, both in raw SQL operations and with column types like
+        # String and MSString.
         util.coerce_kw_type(opts, 'use_unicode', bool)   
         util.coerce_kw_type(opts, 'charset', str)
-        # TODO: cursorclass and conv:  support via query string or punt?
+
+        # Rich values 'cursorclass' and 'conv' are not supported via
+        # query string.
         
-        # ssl
         ssl = {}
         for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
             if key in opts:
@@ -1004,7 +1064,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         if len(ssl):
             opts['ssl'] = ssl
         
-        # FOUND_ROWS must be set in CLIENT_FLAGS for to enable
+        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
         # supports_sane_rowcount.
         client_flag = opts.get('client_flag', 0)
         if self.dbapi is not None:
@@ -1041,7 +1101,8 @@ class MySQLDialect(ansisql.ANSIDialect):
     def preparer(self):
         return MySQLIdentifierPreparer(self)
 
-    def do_executemany(self, cursor, statement, parameters, context=None, **kwargs):
+    def do_executemany(self, cursor, statement, parameters,
+                       context=None, **kwargs):
         rowcount = cursor.executemany(statement, parameters)
         if context is not None:
             context._rowcount = rowcount
@@ -1060,28 +1121,31 @@ class MySQLDialect(ansisql.ANSIDialect):
             pass
 
     def do_begin_twophase(self, connection, xid):
-        connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)]))
+        connection.execute("XA BEGIN %s", xid)
 
     def do_prepare_twophase(self, connection, xid):
-        connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
-        connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)]))
+        connection.execute("XA END %s", xid)
+        connection.execute("XA PREPARE %s", xid)
 
-    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+    def do_rollback_twophase(self, connection, xid, is_prepared=True,
+                             recover=False):
         if not is_prepared:
-            connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
-        connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)]))
+            connection.execute("XA END %s", xid)
+        connection.execute("XA ROLLBACK %s", xid)
 
-    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+    def do_commit_twophase(self, connection, xid, is_prepared=True,
+                           recover=False):
         if not is_prepared:
             self.do_prepare_twophase(connection, xid)
-        connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)]))
+        connection.execute("XA COMMIT %s", xid)
     
     def do_recover_twophase(self, connection):
-        resultset = connection.execute(sql.text("XA RECOVER"))
+        resultset = connection.execute("XA RECOVER")
         return [row['data'][0:row['gtrid_length']] for row in resultset]
 
     def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
+        return isinstance(e, self.dbapi.OperationalError) and \
+               e.args[0] in (2006, 2013, 2014, 2045, 2055)
 
     def get_default_schema_name(self, connection):
         try:
@@ -1102,7 +1166,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         # on macosx (and maybe win?) with multibyte table names.
         #
         # TODO: if this is not a problem on win, make the strategy swappable
-        # based on platform.  DESCRIBE is much slower.
+        # based on platform.  DESCRIBE is slower.
         if schema is not None:
             st = "DESCRIBE `%s`.`%s`" % (schema, table_name)
         else:
@@ -1118,12 +1182,14 @@ class MySQLDialect(ansisql.ANSIDialect):
             raise
 
     def get_version_info(self, connectable):
+        """A tuple of the database server version."""
+        
         if hasattr(connectable, 'connect'):
-            con = connectable.connect().connection
+            dbapi_con = connectable.connect().connection
         else:
-            con = connectable
+            dbapi_con = connectable
         version = []
-        for n in con.get_server_info().split('.'):
+        for n in dbapi_con.get_server_info().split('.'):
             try:
                 version.append(int(n))
             except ValueError:
@@ -1140,8 +1206,9 @@ class MySQLDialect(ansisql.ANSIDialect):
             table.name = table.name.lower()
             table.metadata.tables[table.name]= table
 
+        table_name = '.'.join(self.identifier_preparer.format_table_seq(table))
         try:
-            rp = connection.execute("DESCRIBE " + self._escape_table_name(table))
+            rp = connection.execute("DESCRIBE " + table_name)
         except exceptions.SQLError, e:
             if e.orig.args[0] == 1146:
                 raise exceptions.NoSuchTableError(table.fullname)
@@ -1166,7 +1233,9 @@ class MySQLDialect(ansisql.ANSIDialect):
             try:
                 coltype = ischema_names[col_type]
             except KeyError:
-                warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name)))
+                warnings.warn(RuntimeWarning(
+                        "Did not recognize type '%s' of column '%s'" %
+                        (col_type, name)))
                 coltype = sqltypes.NULLTYPE
 
             kw = {}
@@ -1194,24 +1263,28 @@ class MySQLDialect(ansisql.ANSIDialect):
                                                    nullable=nullable,
                                                    )))
 
-        tabletype = self.moretableinfo(connection, table, decode_from)
-        table.kwargs['mysql_engine'] = tabletype
+        table_options = self.moretableinfo(connection, table, decode_from)
+        table.kwargs.update(table_options)
 
     def moretableinfo(self, connection, table, charset=None):
         """SHOW CREATE TABLE to get foreign key/table options."""
 
-        rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {})
+        table_name = '.'.join(self.identifier_preparer.format_table_seq(table))
+        rp = connection.execute("SHOW CREATE TABLE " + table_name)
         row = _compat_fetchone(rp, charset=charset)
         if not row:
             raise exceptions.NoSuchTableError(table.fullname)
         desc = row[1].strip()
+        row.close()
+
+        table_options = {}
 
-        tabletype = ''
         lastparen = re.search(r'\)[^\)]*\Z', desc)
         if lastparen:
-            match = re.search(r'\b(?:TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I)
+            match = re.search(r'\b(?P<spec>TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I)
             if match:
-                tabletype = match.group('ttype')
+                table_options["mysql_%s" % match.group('spec')] = \
+                    match.group('ttype')
 
         # \x27 == ' (single quote)  (avoid xemacs syntax highlighting issue)
         fkpat = r'''CONSTRAINT [`"\x27](?P<name>.+?)[`"\x27] FOREIGN KEY \((?P<columns>.+?)\) REFERENCES [`"\x27](?P<reftable>.+?)[`"\x27] \((?P<refcols>.+?)\)'''
@@ -1222,21 +1295,32 @@ class MySQLDialect(ansisql.ANSIDialect):
             constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name'))
             table.append_constraint(constraint)
 
-        return tabletype
-
-    def _escape_table_name(self, table):
-        if table.schema is not None:
-            return '`%s`.`%s`' % (table.schema, table.name)
-        else:
-            return '`%s`' % table.name
+        return table_options
 
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
 
+        # Allow user override, won't sniff if force_charset is set.
+        if 'force_charset' in connection.properties:
+            return connection.properties['force_charset']
+
         # Note: MySQL-python 1.2.1c7 seems to ignore changes made
         # on a connection via set_character_set()
-        
-        rs = connection.execute("show variables like 'character_set%%'")
+        if self.get_version_info(connection) < (4, 1, 0):
+            try:
+                return connection.connection.character_set_name()
+            except AttributeError:
+                # < 1.2.1 final MySQL-python drivers have no charset support.
+                # a query is needed.
+                pass
+
+        # Prefer 'character_set_results' for the current connection over the
+        # value in the driver.  SET NAMES or individual variable SETs will
+        # change the charset without updating the driver's view of the world.
+        # 
+        # If it's decided that issuing that sort of SQL leaves you SOL, then
+        # this can prefer the driver value.
+        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
         opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)])
 
         if 'character_set_results' in opts:
@@ -1244,11 +1328,14 @@ class MySQLDialect(ansisql.ANSIDialect):
         try:
             return connection.connection.character_set_name()
         except AttributeError:
-            # < 1.2.1 final MySQL-python drivers have no charset support
+            # Still no charset on < 1.2.1 final...
             if 'character_set' in opts:
                 return opts['character_set']
             else:
-                warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python.  MySQL-python >= 1.2.2 is recommended.  Assuming latin1."))
+                warnings.warn(RuntimeWarning(
+                    "Could not detect the connection character set with this "
+                    "combination of MySQL server and MySQL-python. "
+                    "MySQL-python >= 1.2.2 is recommended.  Assuming latin1."))
                 return 'latin1'
 
     def _detect_case_sensitive(self, connection, charset=None):
@@ -1257,25 +1344,41 @@ class MySQLDialect(ansisql.ANSIDialect):
         Cached per-connection. This value can not change without a server
         restart.
         """
+
         # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
 
-        _per_connection_mutex.acquire()
         try:
-            raw_connection = connection.connection.connection
-            cache = self.per_connection.get(raw_connection, {})
-            if 'lower_case_table_names' not in cache:
-                row = _compat_fetchone(connection.execute(
-                        "SHOW VARIABLES LIKE 'lower_case_table_names'"),
-                        charset=charset)
-                if not row:
-                    cs = True
-                else:
-                    cs = row[1] in ('0', 'OFF' 'off')
-                cache['lower_case_table_names'] = cs
-                self.per_connection[raw_connection] = cache
-            return cache.get('lower_case_table_names')
-        finally:
-            _per_connection_mutex.release()
+            return connection.properties['lower_case_table_names']
+        except KeyError:
+            row = _compat_fetchone(connection.execute(
+                    "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+                                   charset=charset)
+            if not row:
+                cs = True
+            else:
+                cs = row[1] in ('0', 'OFF' 'off')
+                row.close()
+            connection.properties['lower_case_table_names'] = cs
+            return cs
+
+    def _detect_collations(self, connection, charset=None):
+        """Pull the active COLLATIONS list from the server.
+
+        Cached per-connection.
+        """
+        
+        try:
+            return connection.properties['collations']
+        except KeyError:
+            collations = {}
+            if self.get_version_info(connection) < (4, 1, 0):
+                pass
+            else:
+                rs = connection.execute('SHOW COLLATION')
+                for row in _compat_fetchall(rs, charset):
+                    collations[row[0]] = row[1]
+            connection.properties['collations'] = collations
+            return collations
 
 def _compat_fetchall(rp, charset=None):
     """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
@@ -1291,6 +1394,10 @@ def _compat_fetchone(rp, charset=None):
 class _MySQLPythonRowProxy(object):
     """Return consistent column values for all versions of MySQL-python (esp. alphas) and unicode settings."""
 
+    # Some MySQL-python versions can return some columns as
+    # sets.Set(['value']) (seriously) but thankfully that doesn't
+    # seem to come up in DDL queries.
+
     def __init__(self, rowproxy, charset):
         self.rowproxy = rowproxy
         self.charset = charset
@@ -1316,13 +1423,15 @@ class MySQLCompiler(ansisql.ANSICompiler):
     operators = ansisql.ANSICompiler.operators.copy()
     operators.update(
         {
-            sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
-            operator.mod : '%%'
+            sql.ColumnOperators.concat_op: \
+              lambda x, y: "concat(%s, %s)" % (x, y),
+            operator.mod: '%%'
         }
     )
 
     def visit_cast(self, cast, **kwargs):
-        if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
+        if isinstance(cast.type, (sqltypes.Date, sqltypes.Time,
+                                  sqltypes.DateTime)):
             return super(MySQLCompiler, self).visit_cast(cast, **kwargs)
         else:
             # so just skip the CAST altogether for now.
@@ -1348,26 +1457,45 @@ class MySQLCompiler(ansisql.ANSICompiler):
         
 
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column, override_pk=False, first_pk=False):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
+    def get_column_specification(self, column, override_pk=False,
+                                 first_pk=False):
+        """Builds column DDL."""
+        
+        colspec = [self.preparer.format_column(column),
+                   column.type.dialect_impl(self.dialect).get_col_spec()]
+
         default = self.get_column_default_string(column)
         if default is not None:
-            colspec += " DEFAULT " + default
+            colspec.append('DEFAULT ' + default)
 
         if not column.nullable:
-            colspec += " NOT NULL"
+            colspec.append('NOT NULL')
+
+        # FIXME: #649, also #612 with regard to SHOW CREATE
         if column.primary_key:
-            if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer):
-                colspec += " AUTO_INCREMENT"
-        return colspec
+            if (len(column.foreign_keys)==0
+                and first_pk
+                and column.autoincrement
+                and isinstance(column.type, sqltypes.Integer)):
+                colspec.append('AUTO_INCREMENT')
+
+        return ' '.join(colspec)
 
     def post_create_table(self, table):
-        args = ""
+        """Build table-level CREATE options like ENGINE and COLLATE."""
+
+        table_opts = []
         for k in table.kwargs:
             if k.startswith('mysql_'):
-                opt = k[6:]
-                args += " %s=%s" % (opt.upper(), table.kwargs[k])
-        return args
+                opt = k[6:].upper()
+                joiner = '='
+                if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
+                           'CHARACTER SET', 'COLLATE'):
+                    joiner = ' '
+                
+                table_opts.append(joiner.join((opt, table.kwargs[k])))
+        return ' '.join(table_opts)
+
 
 class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_index(self, index):
@@ -1382,9 +1510,11 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
                      self.preparer.format_constraint(constraint)))
         self.execute()
 
+
 class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
     def __init__(self, dialect):
-        super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote='`')
+        super(MySQLIdentifierPreparer, self).__init__(dialect,
+                                                      initial_quote='`')
 
     def _reserved_words(self):
         return RESERVED_WORDS
@@ -1393,7 +1523,69 @@ class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
         return value.replace('`', '``')
 
     def _fold_identifier_case(self, value):
-        #TODO: determine MySQL's case folding rules
+        # TODO: determine MySQL's case folding rules
+        #
+        # For compatability with sql.text() issued statements, maybe it's best
+        # to just leave things as-is.  When lower_case_table_names > 0 it
+        # looks a good idea to lc everything, but more importantly the casing
+        # of all identifiers in an expression must be consistent.  So for now,
+        # just leave everything as-is.
         return value
 
+    def format_table_seq(self, table, use_schema=True):
+        """Format table name and schema as a tuple."""
+
+        if use_schema and getattr(table, 'schema', None):
+            return (self.quote_identifier(table.schema),
+                    self.format_table(table, use_schema=False))
+        else:
+            return (self.format_table(table, use_schema=False), )
+
+
+class MySQLCharsetOnConnect(object):
+    """Use an alternate connection character set automatically."""
+
+    def __init__(self, charset, collation=None):
+        """Creates a pool listener that decorates new database connections.
+
+        Sets the connection character set on MySQL connections.  Strings
+        sent to and from the server will use this encoding, and if a collation
+        is provided it will be used as the default.
+
+        There is also a MySQL-python 'charset' keyword for connections,
+        however that keyword has the side-effect of turning all strings into
+        Unicode.
+
+        This class is a ``Pool`` listener.  To use, pass an insstance to the
+        ``listeners`` argument to create_engine or Pool constructor, or
+        manually add it to a pool with ``add_listener()``.
+
+        charset:
+          The character set to use
+
+        collation:
+          Optional, use a non-default collation for the given charset
+        """
+
+        self.charset = charset
+        self.collation = collation
+        
+    def connect(self, dbapi_con, con_record):
+        cr = dbapi_con.cursor()
+        try:
+            if self.collation is None:
+                if hasattr(dbapi_con, 'set_character_set'):
+                    dbapi_con.set_character_set(self.charset)
+                else:
+                    cr.execute("SET NAMES %s" % self.charset)
+            else:
+                if hasattr(dbapi_con, 'set_character_set'):
+                    dbapi_con.set_character_set(self.charset)
+                cr.execute("SET NAMES %s COLLATE %s" % (self.charset,
+                                                        self.collation))
+        # let SQL errors (1064 if SET NAMES is not supported) raise
+        finally:
+            cr.close()
+
+        
 dialect = MySQLDialect
index dbba78893d94c8dd8c33e84a65816b89241032ff..484022bd166f9adc438ff472459093072d5175ac 100644 (file)
@@ -90,6 +90,17 @@ class TypesTest(AssertMixin):
             (mysql.MSBigInteger, [4], {'zerofill':True, 'unsigned':True},
              'BIGINT(4) UNSIGNED ZEROFILL'),
 
+            (mysql.MSTinyInteger, [], {},
+             'TINYINT'),
+            (mysql.MSTinyInteger, [1], {},
+             'TINYINT(1)'),
+            (mysql.MSTinyInteger, [1], {'unsigned':True},
+             'TINYINT(1) UNSIGNED'),
+            (mysql.MSTinyInteger, [1], {'zerofill':True},
+             'TINYINT(1) ZEROFILL'),
+            (mysql.MSTinyInteger, [1], {'zerofill':True, 'unsigned':True},
+             'TINYINT(1) UNSIGNED ZEROFILL'),
+
             (mysql.MSSmallInteger, [], {},
              'SMALLINT'),
             (mysql.MSSmallInteger, [4], {},
@@ -320,5 +331,38 @@ class TypesTest(AssertMixin):
 
         m.drop_all()
 
+class CharsetHelperTest(PersistTest):
+    @testing.supported('mysql')
+    def test_basic(self):
+        if testbase.db.dialect.get_version_info(testbase.db) < (4, 1):
+            return
+
+        helper = mysql.MySQLCharsetOnConnect('utf8')
+
+        e = create_engine(testbase.db.url, listeners=[helper])
+
+        rs = e.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        vars = dict([(row[0], row[1]) for row in mysql._compat_fetchall(rs)])
+        self.assert_(vars['character_set_client'] == 'utf8')
+        self.assert_(vars['character_set_connection'] == 'utf8')
+
+        helper.charset = 'latin1'
+        e.pool.dispose()
+        rs = e.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        vars = dict([(row[0], row[1]) for row in mysql._compat_fetchall(rs)])
+        self.assert_(vars['character_set_client'] == 'latin1')
+        self.assert_(vars['character_set_connection'] == 'latin1')
+
+        helper.charset = 'utf8'
+        helper.collation = 'utf8_bin'
+        e.pool.dispose()
+        rs = e.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        vars = dict([(row[0], row[1]) for row in mysql._compat_fetchall(rs)])
+        self.assert_(vars['character_set_client'] == 'utf8')
+        self.assert_(vars['character_set_connection'] == 'utf8')
+        rs = e.execute("SHOW VARIABLES LIKE 'collation%%'")
+        vars = dict([(row[0], row[1]) for row in mysql._compat_fetchall(rs)])
+        self.assert_(vars['collation_connection'] == 'utf8_bin')
+
 if __name__ == "__main__":
     testbase.main()