From e1a52eb7dfb19edf3baeff6d2878b6b0afb9a04d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 13 Dec 2006 21:06:38 +0000 Subject: [PATCH] - patch that makes MySQL rowcount work correctly! [ticket:396] --- CHANGES | 1 + lib/sqlalchemy/databases/mysql.py | 8 +++++++- lib/sqlalchemy/orm/mapper.py | 2 +- test/orm/unitofwork.py | 2 +- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CHANGES b/CHANGES index fd03d1a93d..6c95164383 100644 --- a/CHANGES +++ b/CHANGES @@ -5,6 +5,7 @@ - fixed QueuePool bug whereby its better able to reconnect to a database that was not reachable (thanks to Sébastien Lelong), also fixed dispose() method +- patch that makes MySQL rowcount work correctly! [ticket:396] 0.3.2 - major connection pool bug fixed. fixes MySQL out of sync diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index c795ae7d4c..19dedd8267 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -13,6 +13,7 @@ import sqlalchemy.exceptions as exceptions try: import MySQLdb as mysql + import MySQLdb.constants.CLIENT as CLIENT_FLAGS except: mysql = None @@ -270,6 +271,11 @@ class MySQLDialect(ansisql.ANSIDialect): coercetype('use_unicode', bool) # this could break SA Unicode type coercetype('charset', str) # this could break SA Unicode type # TODO: what about options like "ssl", "cursorclass" and "conv" ? + + client_flag = opts.get('client_flag', 0) + client_flag |= CLIENT_FLAGS.FOUND_ROWS + opts['client_flag'] = client_flag + return [[], opts] def create_execution_context(self): @@ -279,7 +285,7 @@ class MySQLDialect(ansisql.ANSIDialect): return sqltypes.adapt_type(typeobj, colspecs) def supports_sane_rowcount(self): - return False + return True def compiler(self, statement, bindparams, **kwargs): return MySQLCompiler(self, statement, bindparams, **kwargs) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index d7449c7cab..78fd3a1cc2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -960,7 +960,7 @@ class Mapper(object): mapper._postfetch(connection, table, obj, c, c.last_updated_params()) updated_objects.add(obj) - rows += c.cursor.rowcount + rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 700597e08d..ed3b9fca1d 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -100,7 +100,7 @@ class VersioningTest(UnitOfWorkTest): version_table.delete().execute() UnitOfWorkTest.tearDown(self) - @testbase.unsupported('mysql', 'mssql') + @testbase.unsupported('mssql') def testbasic(self): s = create_session() class Foo(object):pass -- 2.47.2