rows = 0
records = list(records)
+ # TODO: would be super-nice to not have to determine this boolean
+ # inside the loop here, in the 99.9999% of the time there's only
+ # one connection in use
+ assert_singlerow = connection.dialect.supports_sane_rowcount
+ assert_multirow = assert_singlerow and \
+ connection.dialect.supports_sane_multi_rowcount
+ allow_multirow = not needs_version_id or assert_multirow
+
if hasvalue:
for state, state_dict, params, mapper, \
connection, value_params in records:
value_params)
rows += c.rowcount
else:
- if needs_version_id and \
- not connection.dialect.supports_sane_multi_rowcount and \
- connection.dialect.supports_sane_rowcount:
+ if not allow_multirow:
for state, state_dict, params, mapper, \
connection, value_params in records:
c = cached_connections[connection].\
rows += c.rowcount
else:
multiparams = [rec[2] for rec in records]
+
c = cached_connections[connection].\
execute(statement, multiparams)
c.context.compiled_parameters[0],
value_params)
- if connection.dialect.supports_sane_rowcount:
+ if assert_multirow or assert_singlerow and \
+ len(multiparams) == 1:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
from sqlalchemy.testing import engines
from sqlalchemy.testing.schema import Table, Column
from test.orm import _fixtures
-from sqlalchemy import exc
-from sqlalchemy.testing import fixtures
+from sqlalchemy import exc, util
+from sqlalchemy.testing import fixtures, config
from sqlalchemy import Integer, String, ForeignKey, func
from sqlalchemy.orm import mapper, relationship, backref, \
create_session, unitofwork, attributes,\
Session, exc as orm_exc
-from sqlalchemy.testing.mock import Mock
+from sqlalchemy.testing.mock import Mock, patch
from sqlalchemy.testing.assertsql import AllOf, CompiledSQL
from sqlalchemy import event
sess.flush
)
+ def test_update_single_missing_broken_multi_rowcount(self):
+ @util.memoized_property
+ def rowcount(self):
+ if len(self.context.compiled_parameters) > 1:
+ return -1
+ else:
+ return self.context.rowcount
+
+ with patch.object(
+ config.db.dialect, "supports_sane_multi_rowcount", False):
+ with patch(
+ "sqlalchemy.engine.result.ResultProxy.rowcount",
+ rowcount):
+ Parent, Child = self._fixture()
+ sess = Session()
+ p1 = Parent(id=1, data=2)
+ sess.add(p1)
+ sess.flush()
+
+ sess.execute(self.tables.parent.delete())
+
+ p1.data = 3
+ assert_raises_message(
+ orm_exc.StaleDataError,
+ "UPDATE statement on table 'parent' expected to "
+ "update 1 row\(s\); 0 were matched.",
+ sess.flush
+ )
+
+ def test_update_multi_missing_broken_multi_rowcount(self):
+ @util.memoized_property
+ def rowcount(self):
+ if len(self.context.compiled_parameters) > 1:
+ return -1
+ else:
+ return self.context.rowcount
+
+ with patch.object(
+ config.db.dialect, "supports_sane_multi_rowcount", False):
+ with patch(
+ "sqlalchemy.engine.result.ResultProxy.rowcount",
+ rowcount):
+ Parent, Child = self._fixture()
+ sess = Session()
+ p1 = Parent(id=1, data=2)
+ p2 = Parent(id=2, data=3)
+ sess.add_all([p1, p2])
+ sess.flush()
+
+ sess.execute(self.tables.parent.delete().where(Parent.id == 1))
+
+ p1.data = 3
+ p2.data = 4
+ sess.flush() # no exception
+
+ # update occurred for remaining row
+ eq_(
+ sess.query(Parent.id, Parent.data).all(),
+ [(2, 4)]
+ )
+
@testing.requires.sane_multi_rowcount
def test_delete_multi_missing_warning(self):
Parent, Child = self._fixture()
T(id=10, data='t10', def_='def3'),
T(id=11, data='t11'),
])
+
self.assert_sql_execution(
testing.db,
sess.flush,