]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Further fixes for the mysql-connector dialect. [ticket:1668]
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 Feb 2010 20:24:15 +0000 (20:24 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 Feb 2010 20:24:15 +0000 (20:24 +0000)
CHANGES
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
test/dialect/test_mysql.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index bef45944a2a5d82ae3b900e5d2e032c212824a15..81ba79de106f0d9444ecc3224426f1a92337b237 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -130,6 +130,8 @@ CHANGES
   - Fixed reflection of TINYINT(1) "boolean" columns defined with
     integer flags like UNSIGNED.
 
+  - Further fixes for the mysql-connector dialect.  [ticket:1668]
+  
 - mssql
   - Re-established initial support for pymssql.
  
index 82a4af941fdb021530125dff96893bda4803562c..686e3da6271752138d13e72b221d1587e1e63f56 100644 (file)
@@ -554,7 +554,13 @@ class BIT(sqltypes.TypeEngine):
         self.length = length
 
     def result_processor(self, dialect, coltype):
-        """Convert a MySQL's 64 bit, variable length binary string to a long."""
+        """Convert a MySQL's 64 bit, variable length binary string to a long.
+        
+        TODO: this is MySQL-db, pyodbc specific.  OurSQL and mysqlconnector
+        already do this, so this logic should be moved to those dialects.
+        
+        """
+        
         def process(value):
             if value is not None:
                 v = 0L
index 3ac207109f90fcc72f3a666d109569235e3e9727..165c7b73e3f2b130335aa24a04c3165211cbdbf1 100644 (file)
@@ -6,12 +6,14 @@
 
 import re
 
-from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext,\
-                                            MySQLCompiler, MySQLIdentifierPreparer
+from sqlalchemy.dialects.mysql.base import (MySQLDialect,
+    MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
+    BIT, NUMERIC)
                                             
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
+from sqlalchemy import processors
 
 class MySQL_mysqlconnectorExecutionContext(MySQLExecutionContext):
     
@@ -33,10 +35,23 @@ class MySQL_mysqlconnectorIdentifierPreparer(MySQLIdentifierPreparer):
         value = value.replace(self.escape_quote, self.escape_to_quote)
         return value.replace("%", "%%")
 
+class _myconnpyNumeric(NUMERIC):
+    def result_processor(self, dialect, coltype):
+        if self.asdecimal:
+            return None
+        return processors.to_float
+
+class _myconnpyBIT(BIT):
+    def result_processor(self, dialect, coltype):
+        """MySQL-connector already converts mysql bits, so."""
+
+        return None
+
 class MySQL_mysqlconnector(MySQLDialect):
     driver = 'mysqlconnector'
     supports_unicode_statements = False
-    supports_sane_rowcount = True
+    supports_unicode_binds = True
+    supports_sane_rowcount = False
     supports_sane_multi_rowcount = True
 
     default_paramstyle = 'format'
@@ -45,6 +60,14 @@ class MySQL_mysqlconnector(MySQLDialect):
     
     preparer = MySQL_mysqlconnectorIdentifierPreparer
     
+    colspecs = util.update_copy(
+        MySQLDialect.colspecs,
+        {
+            sqltypes.Numeric: _myconnpyNumeric,
+            BIT: _myconnpyBIT,
+        }
+    )
+    
     @classmethod
     def dbapi(cls):
         from mysql import connector
@@ -53,24 +76,44 @@ class MySQL_mysqlconnector(MySQLDialect):
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         opts.update(url.query)
+        
+        util.coerce_kw_type(opts, 'buffered', bool)
+        util.coerce_kw_type(opts, 'raise_on_warnings', bool)
+        opts['buffered'] = True
+        opts['raise_on_warnings'] = True
+        
         return [[], opts]
 
     def _get_server_version_info(self, connection):
         dbapi_con = connection.connection
+        
+        from mysql.connector.constants import ClientFlag
+        dbapi_con.set_client_flag(ClientFlag.FOUND_ROWS)
+        
         version = dbapi_con.get_server_version()
         return tuple(version)
 
     def _detect_charset(self, connection):
-        """Sniff out the character set in use for connection results."""
-        
         return connection.connection.get_characterset_info()
 
     def _extract_error_code(self, exception):
-        m = re.compile(r"\(.*\)\s+(\d+)").search(str(exception))
-        c = m.group(1)
-        if c:
-            return int(c)
-        else:
+        try:
+            return exception.orig.errno
+        except AttributeError:
             return None
+    
+    def is_disconnect(self, e):
+        errnos = (2006, 2013, 2014, 2045, 2055, 2048)
+        exceptions = (self.dbapi.OperationalError,self.dbapi.InterfaceError)
+        if isinstance(e, exceptions):
+            return e.errno in errnos
+        else:
+            return False
+        
+    def _compat_fetchall(self, rp, charset=None):
+        return rp.fetchall()
 
+    def _compat_fetchone(self, rp, charset=None):
+        return rp.fetchone()
+        
 dialect = MySQL_mysqlconnector
index 85156ac3bfa0f3bb47866e28cda56f6a53cf60df..13df2243d8040675e7da0321f2a25e7af42e1bad 100644 (file)
@@ -1201,11 +1201,22 @@ class MatchTest(TestBase, AssertsCompiledSQL):
     def teardown_class(cls):
         metadata.drop_all()
 
+    @testing.fails_on('mysql+mysqlconnector', 'uses pyformat')
     def test_expression(self):
         format = testing.db.dialect.paramstyle == 'format' and '%s' or '?'
         self.assert_compile(
             matchtable.c.title.match('somstr'),
             "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)" % format)
+    
+    @testing.fails_on('mysql+mysqldb', 'uses format')
+    @testing.fails_on('mysql+oursql', 'uses format')
+    @testing.fails_on('mysql+pyodbc', 'uses format')
+    @testing.fails_on('mysql+zxjdbc', 'uses format')
+    def test_expression(self):
+        format = '%(title_1)s'
+        self.assert_compile(
+            matchtable.c.title.match('somstr'),
+            "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)" % format)
 
     def test_simple_match(self):
         results = (matchtable.select().
index fd957c6e517b4b6887bed58c9afb5957f4f57857..c4d92b46d0c3177c35685a16a678597e05dcdc1e 100644 (file)
@@ -286,6 +286,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
                 ('postgresql','zxjdbc'),  
                 ('mysql','oursql'),
                 ('mysql','zxjdbc'),
+                ('mysql','mysqlconnector'),
                 ('sqlite','pysqlite'),
                 ('oracle','zxjdbc'),
             )), \
@@ -295,7 +296,9 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
                                          testing.db.dialect.returns_unicode_strings)
         
     def test_round_trip(self):
-        unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »"
+        unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, "\
+                    u"quand une drôle de petit voix m’a réveillé. Elle "\
+                    u"disait: « S’il vous plaît… dessine-moi un mouton! »"
         
         unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata)