From: Michael Trier Date: Mon, 1 Mar 2010 03:01:55 +0000 (+0000) Subject: Added support for FOUND_ROWS to mysqlconnector. X-Git-Tag: rel_0_6beta2~85 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=d0eba53117ca8ba6a4e1b96d671c89947c761b05;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added support for FOUND_ROWS to mysqlconnector. --- diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 9e7caae565..6b3888cb07 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -8,7 +8,7 @@ import re from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer, - BIT, NUMERIC) + BIT, NUMERIC, _NumericType) from sqlalchemy.engine import base as engine_base, default from sqlalchemy.sql import operators as sql_operators @@ -28,6 +28,14 @@ class MySQL_mysqlconnectorCompiler(MySQLCompiler): def post_process_text(self, text): return text.replace('%', '%%') +class _DecimalType(_NumericType): + def result_processor(self, dialect, coltype): + if self.asdecimal: + return None + return processors.to_float + +class _myconnpyNumeric(_DecimalType, NUMERIC): + pass class MySQL_mysqlconnectorIdentifierPreparer(MySQLIdentifierPreparer): @@ -35,12 +43,6 @@ class MySQL_mysqlconnectorIdentifierPreparer(MySQLIdentifierPreparer): value = value.replace(self.escape_quote, self.escape_to_quote) return value.replace("%", "%%") -class _myconnpyNumeric(NUMERIC): - def result_processor(self, dialect, coltype): - if self.asdecimal: - return None - return processors.to_float - class _myconnpyBIT(BIT): def result_processor(self, dialect, coltype): """MySQL-connector already converts mysql bits, so.""" @@ -82,6 +84,16 @@ class MySQL_mysqlconnector(MySQLDialect): opts['buffered'] = True opts['raise_on_warnings'] = True + # FOUND_ROWS must be set in ClientFlag to enable + # supports_sane_rowcount. + if self.dbapi is not None: + try: + from mysql.connector.constants import ClientFlag + client_flags = opts.get('client_flags', ClientFlag.get_default()) + client_flags |= ClientFlag.FOUND_ROWS + opts['client_flags'] = client_flags + except: + pass return [[], opts] def _get_server_version_info(self, connection): diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 70adb1a8bc..fed8b426fd 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -594,7 +594,7 @@ class CascadeToFKPKTest(_base.MappedTest, testing.AssertsCompiledSQL): self._test_onetomany(True) # PG etc. need passive=True to allow PK->PK cascade - @testing.fails_on_everything_except('sqlite', 'mysql+mysqlconnector') + @testing.fails_on_everything_except('sqlite') def test_onetomany_nonpassive(self): self._test_onetomany(False) @@ -737,7 +737,7 @@ class JoinedInheritanceTest(_base.MappedTest): self._test_pk(True) # PG etc. need passive=True to allow PK->PK cascade - @testing.fails_on_everything_except('sqlite', 'mysql+mysqlconnector') + @testing.fails_on_everything_except('sqlite') def test_pk_nonpassive(self): self._test_pk(False) @@ -747,7 +747,7 @@ class JoinedInheritanceTest(_base.MappedTest): self._test_fk(True) # PG etc. need passive=True to allow PK->PK cascade - @testing.fails_on_everything_except('sqlite', 'mysql+mysqlconnector') + @testing.fails_on_everything_except('sqlite') def test_fk_nonpassive(self): self._test_fk(False)