]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- patch that makes MySQL rowcount work correctly! [ticket:396]
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Dec 2006 21:06:38 +0000 (21:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Dec 2006 21:06:38 +0000 (21:06 +0000)
CHANGES
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/orm/mapper.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index fd03d1a93da65993d70cc90056c3f205260b6f37..6c95164383538a0ebed0459551206858370a66b2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -5,6 +5,7 @@
 - fixed QueuePool bug whereby its better able to reconnect to a database
 that was not reachable (thanks to Sébastien Lelong), also fixed dispose()
 method
+- patch that makes MySQL rowcount work correctly! [ticket:396]
  
 0.3.2
 - major connection pool bug fixed.  fixes MySQL out of sync
index c795ae7d4c05424f943cf4919e1eba00a0adcf82..19dedd8267ed06a0515347eac43a1a2ae869886c 100644 (file)
@@ -13,6 +13,7 @@ import sqlalchemy.exceptions as exceptions
 
 try:
     import MySQLdb as mysql
+    import MySQLdb.constants.CLIENT as CLIENT_FLAGS
 except:
     mysql = None
 
@@ -270,6 +271,11 @@ class MySQLDialect(ansisql.ANSIDialect):
         coercetype('use_unicode', bool)   # this could break SA Unicode type
         coercetype('charset', str)        # this could break SA Unicode type
         # TODO: what about options like "ssl", "cursorclass" and "conv" ?
+
+        client_flag = opts.get('client_flag', 0)
+        client_flag |= CLIENT_FLAGS.FOUND_ROWS
+        opts['client_flag'] = client_flag
+
         return [[], opts]
 
     def create_execution_context(self):
@@ -279,7 +285,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         return sqltypes.adapt_type(typeobj, colspecs)
 
     def supports_sane_rowcount(self):
-        return False
+        return True
 
     def compiler(self, statement, bindparams, **kwargs):
         return MySQLCompiler(self, statement, bindparams, **kwargs)
index d7449c7cab1dc4c6f186dc08bd49f721255fc941..78fd3a1cc2f3e388944e9ad63b8d2fc4303672fd 100644 (file)
@@ -960,7 +960,7 @@ class Mapper(object):
                     mapper._postfetch(connection, table, obj, c, c.last_updated_params())
 
                     updated_objects.add(obj)
-                    rows += c.cursor.rowcount
+                    rows += c.rowcount
 
                 if c.supports_sane_rowcount() and rows != len(update):
                     raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
index 700597e08d3913036abb356e8ba0b56a31f246cc..ed3b9fca1db016319cf3899797d95682767b0e22 100644 (file)
@@ -100,7 +100,7 @@ class VersioningTest(UnitOfWorkTest):
         version_table.delete().execute()
         UnitOfWorkTest.tearDown(self)
     
-    @testbase.unsupported('mysql', 'mssql')
+    @testbase.unsupported('mssql')
     def testbasic(self):
         s = create_session()
         class Foo(object):pass