From: Mike Bayer Date: Sun, 18 Jan 2015 02:36:52 +0000 (-0500) Subject: - fix a regression from ref #3178, where dialects that don't actually support X-Git-Tag: rel_1_0_0b1~109 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=f49c367ef712d080e630ba722f96903922d7de7b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fix a regression from ref #3178, where dialects that don't actually support sane multi rowcount (e.g. pyodbc) would fail on multirow update. add a test that mocks this breakage into plain dialects --- diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index f477e1dd7b..dbf1d3eb41 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -617,6 +617,14 @@ def _emit_update_statements(base_mapper, uowtransaction, rows = 0 records = list(records) + # TODO: would be super-nice to not have to determine this boolean + # inside the loop here, in the 99.9999% of the time there's only + # one connection in use + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = not needs_version_id or assert_multirow + if hasvalue: for state, state_dict, params, mapper, \ connection, value_params in records: @@ -635,9 +643,7 @@ def _emit_update_statements(base_mapper, uowtransaction, value_params) rows += c.rowcount else: - if needs_version_id and \ - not connection.dialect.supports_sane_multi_rowcount and \ - connection.dialect.supports_sane_rowcount: + if not allow_multirow: for state, state_dict, params, mapper, \ connection, value_params in records: c = cached_connections[connection].\ @@ -654,6 +660,7 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount else: multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ execute(statement, multiparams) @@ -670,7 +677,8 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) - if connection.dialect.supports_sane_rowcount: + if assert_multirow or assert_singlerow and \ + len(multiparams) == 1: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 374a77237c..681b104cff 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -3,13 +3,13 @@ from sqlalchemy import testing from sqlalchemy.testing import engines from sqlalchemy.testing.schema import Table, Column from test.orm import _fixtures -from sqlalchemy import exc -from sqlalchemy.testing import fixtures +from sqlalchemy import exc, util +from sqlalchemy.testing import fixtures, config from sqlalchemy import Integer, String, ForeignKey, func from sqlalchemy.orm import mapper, relationship, backref, \ create_session, unitofwork, attributes,\ Session, exc as orm_exc -from sqlalchemy.testing.mock import Mock +from sqlalchemy.testing.mock import Mock, patch from sqlalchemy.testing.assertsql import AllOf, CompiledSQL from sqlalchemy import event @@ -1473,6 +1473,67 @@ class BasicStaleChecksTest(fixtures.MappedTest): sess.flush ) + def test_update_single_missing_broken_multi_rowcount(self): + @util.memoized_property + def rowcount(self): + if len(self.context.compiled_parameters) > 1: + return -1 + else: + return self.context.rowcount + + with patch.object( + config.db.dialect, "supports_sane_multi_rowcount", False): + with patch( + "sqlalchemy.engine.result.ResultProxy.rowcount", + rowcount): + Parent, Child = self._fixture() + sess = Session() + p1 = Parent(id=1, data=2) + sess.add(p1) + sess.flush() + + sess.execute(self.tables.parent.delete()) + + p1.data = 3 + assert_raises_message( + orm_exc.StaleDataError, + "UPDATE statement on table 'parent' expected to " + "update 1 row\(s\); 0 were matched.", + sess.flush + ) + + def test_update_multi_missing_broken_multi_rowcount(self): + @util.memoized_property + def rowcount(self): + if len(self.context.compiled_parameters) > 1: + return -1 + else: + return self.context.rowcount + + with patch.object( + config.db.dialect, "supports_sane_multi_rowcount", False): + with patch( + "sqlalchemy.engine.result.ResultProxy.rowcount", + rowcount): + Parent, Child = self._fixture() + sess = Session() + p1 = Parent(id=1, data=2) + p2 = Parent(id=2, data=3) + sess.add_all([p1, p2]) + sess.flush() + + sess.execute(self.tables.parent.delete().where(Parent.id == 1)) + + p1.data = 3 + p2.data = 4 + sess.flush() # no exception + + # update occurred for remaining row + eq_( + sess.query(Parent.id, Parent.data).all(), + [(2, 4)] + ) + @testing.requires.sane_multi_rowcount def test_delete_multi_missing_warning(self): Parent, Child = self._fixture() @@ -1544,6 +1605,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): T(id=10, data='t10', def_='def3'), T(id=11, data='t11'), ]) + self.assert_sql_execution( testing.db, sess.flush,