]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added support for FOUND_ROWS to mysqlconnector.
authorMichael Trier <mtrier@gmail.com>
Mon, 1 Mar 2010 03:01:55 +0000 (03:01 +0000)
committerMichael Trier <mtrier@gmail.com>
Mon, 1 Mar 2010 03:01:55 +0000 (03:01 +0000)
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
test/orm/test_naturalpks.py

index 9e7caae5659c51d33f59ba43f54686cf33b763f4..6b3888cb07647df94b530dd27369bc1c147219f3 100644 (file)
@@ -8,7 +8,7 @@ import re
 
 from sqlalchemy.dialects.mysql.base import (MySQLDialect,
     MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
-    BIT, NUMERIC)
+    BIT, NUMERIC, _NumericType)
 
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
@@ -28,6 +28,14 @@ class MySQL_mysqlconnectorCompiler(MySQLCompiler):
     def post_process_text(self, text):
         return text.replace('%', '%%')
 
+class _DecimalType(_NumericType):
+    def result_processor(self, dialect, coltype):
+        if self.asdecimal:
+            return None
+        return processors.to_float
+
+class _myconnpyNumeric(_DecimalType, NUMERIC):
+    pass
 
 class MySQL_mysqlconnectorIdentifierPreparer(MySQLIdentifierPreparer):
 
@@ -35,12 +43,6 @@ class MySQL_mysqlconnectorIdentifierPreparer(MySQLIdentifierPreparer):
         value = value.replace(self.escape_quote, self.escape_to_quote)
         return value.replace("%", "%%")
 
-class _myconnpyNumeric(NUMERIC):
-    def result_processor(self, dialect, coltype):
-        if self.asdecimal:
-            return None
-        return processors.to_float
-
 class _myconnpyBIT(BIT):
     def result_processor(self, dialect, coltype):
         """MySQL-connector already converts mysql bits, so."""
@@ -82,6 +84,16 @@ class MySQL_mysqlconnector(MySQLDialect):
         opts['buffered'] = True
         opts['raise_on_warnings'] = True
 
+        # FOUND_ROWS must be set in ClientFlag to enable
+        # supports_sane_rowcount.
+        if self.dbapi is not None:
+            try:
+                from mysql.connector.constants import ClientFlag
+                client_flags = opts.get('client_flags', ClientFlag.get_default())
+                client_flags |= ClientFlag.FOUND_ROWS
+                opts['client_flags'] = client_flags
+            except:
+                pass
         return [[], opts]
 
     def _get_server_version_info(self, connection):
index 70adb1a8bc0f822050c57f87aa815b4e48ee16a5..fed8b426fd13590baca57da39fb1e79b30710cad 100644 (file)
@@ -594,7 +594,7 @@ class CascadeToFKPKTest(_base.MappedTest, testing.AssertsCompiledSQL):
         self._test_onetomany(True)
 
     # PG etc. need passive=True to allow PK->PK cascade
-    @testing.fails_on_everything_except('sqlite', 'mysql+mysqlconnector')
+    @testing.fails_on_everything_except('sqlite')
     def test_onetomany_nonpassive(self):
         self._test_onetomany(False)
         
@@ -737,7 +737,7 @@ class JoinedInheritanceTest(_base.MappedTest):
         self._test_pk(True)
 
     # PG etc. need passive=True to allow PK->PK cascade
-    @testing.fails_on_everything_except('sqlite', 'mysql+mysqlconnector')
+    @testing.fails_on_everything_except('sqlite')
     def test_pk_nonpassive(self):
         self._test_pk(False)
         
@@ -747,7 +747,7 @@ class JoinedInheritanceTest(_base.MappedTest):
         self._test_fk(True)
         
     # PG etc. need passive=True to allow PK->PK cascade
-    @testing.fails_on_everything_except('sqlite', 'mysql+mysqlconnector')
+    @testing.fails_on_everything_except('sqlite')
     def test_fk_nonpassive(self):
         self._test_fk(False)