From: Mike Bayer Date: Fri, 15 Dec 2006 01:07:05 +0000 (+0000) Subject: fix to the fix for [ticket:396] plus a unit test X-Git-Tag: rel_0_3_3~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=422558bc5d58557a758b56f4d592d08dd6f86309;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fix to the fix for [ticket:396] plus a unit test --- diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 19dedd8267..a452a696e2 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -299,9 +299,11 @@ class MySQLDialect(ansisql.ANSIDialect): def preparer(self): return MySQLIdentifierPreparer(self) - def do_executemany(self, cursor, statement, parameters, **kwargs): + def do_executemany(self, cursor, statement, parameters, context=None, **kwargs): try: - cursor.executemany(statement, parameters) + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount except mysql.OperationalError, o: if o.args[0] == 2006 or o.args[0] == 2014: cursor.invalidate() @@ -316,7 +318,7 @@ class MySQLDialect(ansisql.ANSIDialect): def do_rollback(self, connection): - # some versions of MySQL just dont support rollback() at all.... + # MySQL without InnoDB doesnt support rollback() try: connection.rollback() except: diff --git a/test/sql/alltests.py b/test/sql/alltests.py index c79d7b67e8..98dbfa6a07 100644 --- a/test/sql/alltests.py +++ b/test/sql/alltests.py @@ -15,6 +15,7 @@ def suite(): # assorted round-trip tests 'sql.query', 'sql.quote', + 'sql.rowcount', # defaults, sequences (postgres/oracle) 'sql.defaults', diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py new file mode 100644 index 0000000000..05d0f21105 --- /dev/null +++ b/test/sql/rowcount.py @@ -0,0 +1,67 @@ +from sqlalchemy import * +import testbase + +class FoundRowsTest(testbase.AssertMixin): + """tests rowcount functionality""" + def setUpAll(self): + metadata = BoundMetaData(testbase.db) + + global employees_table + + employees_table = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('department', String(1)), + ) + employees_table.create() + + def setUp(self): + global data + data = [ ('Angela', 'A'), + ('Andrew', 'A'), + ('Anand', 'A'), + ('Bob', 'B'), + ('Bobette', 'B'), + ('Buffy', 'B'), + ('Charlie', 'C'), + ('Cynthia', 'C'), + ('Chris', 'C') ] + + i = employees_table.insert() + i.execute(*[{'name':n, 'department':d} for n, d in data]) + def tearDown(self): + employees_table.delete().execute() + + def tearDownAll(self): + employees_table.drop() + + def testbasic(self): + s = employees_table.select() + r = s.execute().fetchall() + + assert len(r) == len(data) + + def test_update_rowcount1(self): + # WHERE matches 3, 3 rows changed + department = employees_table.c.department + r = employees_table.update(department=='C').execute(department='Z') + assert r.rowcount == 3 + + def test_update_rowcount2(self): + # WHERE matches 3, 0 rows changed + department = employees_table.c.department + r = employees_table.update(department=='C').execute(department='C') + assert r.rowcount == 3 + + def test_delete_rowcount(self): + # WHERE matches 3, 3 rows deleted + department = employees_table.c.department + r = employees_table.delete(department=='C').execute() + assert r.rowcount == 3 + +if __name__ == '__main__': + testbase.main() + + + +