]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix to the fix for [ticket:396] plus a unit test
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Dec 2006 01:07:05 +0000 (01:07 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Dec 2006 01:07:05 +0000 (01:07 +0000)
lib/sqlalchemy/databases/mysql.py
test/sql/alltests.py
test/sql/rowcount.py [new file with mode: 0644]

index 19dedd8267ed06a0515347eac43a1a2ae869886c..a452a696e2ce9a7b75e500bf5fb56439a6813509 100644 (file)
@@ -299,9 +299,11 @@ class MySQLDialect(ansisql.ANSIDialect):
     def preparer(self):
         return MySQLIdentifierPreparer(self)
 
-    def do_executemany(self, cursor, statement, parameters, **kwargs):
+    def do_executemany(self, cursor, statement, parameters, context=None, **kwargs):
         try:
-            cursor.executemany(statement, parameters)
+            rowcount = cursor.executemany(statement, parameters)
+            if context is not None:
+                context._rowcount = rowcount
         except mysql.OperationalError, o:
             if o.args[0] == 2006 or o.args[0] == 2014:
                 cursor.invalidate()
@@ -316,7 +318,7 @@ class MySQLDialect(ansisql.ANSIDialect):
             
 
     def do_rollback(self, connection):
-        # some versions of MySQL just dont support rollback() at all....
+        # MySQL without InnoDB doesnt support rollback()
         try:
             connection.rollback()
         except:
index c79d7b67e881172c39acb9cefbca4852485f44a8..98dbfa6a079b02be93c2eaa0d6c23e64431975a6 100644 (file)
@@ -15,6 +15,7 @@ def suite():
         # assorted round-trip tests
         'sql.query',
         'sql.quote',
+        'sql.rowcount',
         
         # defaults, sequences (postgres/oracle)
         'sql.defaults',
diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py
new file mode 100644 (file)
index 0000000..05d0f21
--- /dev/null
@@ -0,0 +1,67 @@
+from sqlalchemy import *
+import testbase
+
+class FoundRowsTest(testbase.AssertMixin):
+    """tests rowcount functionality"""
+    def setUpAll(self):
+        metadata = BoundMetaData(testbase.db)
+
+        global employees_table
+
+        employees_table = Table('employees', metadata,
+            Column('employee_id', Integer, primary_key=True),
+            Column('name', String(50)),
+            Column('department', String(1)),
+        )
+        employees_table.create()
+
+    def setUp(self):
+        global data
+        data = [ ('Angela', 'A'),
+                 ('Andrew', 'A'),
+                 ('Anand', 'A'),
+                 ('Bob', 'B'),
+                 ('Bobette', 'B'),
+                 ('Buffy', 'B'),
+                 ('Charlie', 'C'),
+                 ('Cynthia', 'C'),
+                 ('Chris', 'C') ]
+
+        i = employees_table.insert()
+        i.execute(*[{'name':n, 'department':d} for n, d in data])
+    def tearDown(self):
+        employees_table.delete().execute()
+        
+    def tearDownAll(self):
+        employees_table.drop()
+
+    def testbasic(self):
+        s = employees_table.select()
+        r = s.execute().fetchall()
+
+        assert len(r) == len(data)
+
+    def test_update_rowcount1(self):
+        # WHERE matches 3, 3 rows changed
+        department = employees_table.c.department
+        r = employees_table.update(department=='C').execute(department='Z')
+        assert r.rowcount == 3
+        
+    def test_update_rowcount2(self):
+        # WHERE matches 3, 0 rows changed
+        department = employees_table.c.department
+        r = employees_table.update(department=='C').execute(department='C')
+        assert r.rowcount == 3
+        
+    def test_delete_rowcount(self):
+        # WHERE matches 3, 3 rows deleted
+        department = employees_table.c.department
+        r = employees_table.delete(department=='C').execute()
+        assert r.rowcount == 3
+
+if __name__ == '__main__':
+    testbase.main()
+    
+
+
+