params[col.key] = value
hasdata = True
if hasdata:
- yield params, connection
+ yield state, state_dict, mapper, connection, params
def _collect_delete_commands(base_mapper, uowtransaction, table,
# to support executemany().
for key, grouper in groupby(
update, lambda rec: (
- rec[1], # connection
- set(rec[0]) # parameter keys
+ rec[3], # connection
+ set(rec[4]), # parameter keys
)
):
+ grouper = list(grouper)
connection = key[0]
- multiparams = [params for params, conn in grouper]
- cached_connections[connection].\
+ multiparams = [
+ params for state, state_dict, mapper_rec, conn, params in grouper]
+ c = cached_connections[connection].\
execute(statement, multiparams)
+ for state, state_dict, mapper_rec, connection, params in grouper:
+ _postfetch_post_update(
+ mapper, uowtransaction, state, state_dict,
+ c, c.context.compiled_parameters[0])
+
def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
mapper, table, delete):
"Instance does not contain a non-NULL version value")
+def _postfetch_post_update(mapper, uowtransaction,
+ state, dict_, result, params):
+ prefetch_cols = result.context.compiled.prefetch
+ postfetch_cols = result.context.compiled.postfetch
+
+ refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
+ if refresh_flush:
+ load_evt_attrs = []
+
+ for c in prefetch_cols:
+ if c.key in params and c in mapper._columntoproperty:
+ dict_[mapper._columntoproperty[c].key] = params[c.key]
+ if refresh_flush:
+ load_evt_attrs.append(mapper._columntoproperty[c].key)
+
+ if refresh_flush and load_evt_attrs:
+ mapper.class_manager.dispatch.refresh_flush(
+ state, uowtransaction, load_evt_attrs)
+
+ if postfetch_cols:
+ state._expire_attributes(state.dict,
+ [mapper._columntoproperty[c].key
+ for c in postfetch_cols if c in
+ mapper._columntoproperty]
+ )
+
+
def _postfetch(mapper, uowtransaction, table,
state, dict_, result, params, value_params):
"""Expire attributes in need of newly persisted database state,
"""
from sqlalchemy import testing
+from sqlalchemy import event
+from sqlalchemy.testing import mock
from sqlalchemy import Integer, String, ForeignKey
from sqlalchemy.testing.schema import Table, Column
from sqlalchemy.orm import mapper, relationship, backref, \
- create_session, sessionmaker
+ create_session, sessionmaker, Session
from sqlalchemy.testing import eq_, is_
from sqlalchemy.testing.assertsql import RegexSQL, CompiledSQL, AllOf
from sqlalchemy.testing import fixtures
+from itertools import count
class SelfReferentialTest(fixtures.MappedTest):
'c1_id': None, 'c3_id': None}
)
)
+
+
+class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest):
+
+ counter = count()
+
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column(Integer, primary_key=True)
+ favorite_b_id = Column(ForeignKey('b.id', name="favorite_b_fk"))
+ bs = relationship("B", primaryjoin="A.id == B.a_id")
+ favorite_b = relationship(
+ "B", primaryjoin="A.favorite_b_id == B.id", post_update=True)
+ updated = Column(Integer, onupdate=lambda: next(cls.counter))
+
+ class B(Base):
+ __tablename__ = 'b'
+ id = Column(Integer, primary_key=True)
+ a_id = Column(ForeignKey('a.id', name="a_fk"))
+
+ def setup(self):
+ super(PostUpdateOnUpdateTest, self).setup()
+ PostUpdateOnUpdateTest.counter = count()
+
+ def test_update_defaults(self):
+ A, B = self.classes("A", "B")
+
+ s = Session()
+ a1 = A()
+ b1 = B()
+
+ a1.bs.append(b1)
+ a1.favorite_b = b1
+ s.add(a1)
+ s.flush()
+
+ eq_(a1.updated, 0)
+
+ def test_update_defaults_refresh_flush_event(self):
+ A, B = self.classes("A", "B")
+
+ canary = mock.Mock()
+ event.listen(A, "refresh_flush", canary)
+
+ s = Session()
+ a1 = A()
+ b1 = B()
+
+ a1.bs.append(b1)
+ a1.favorite_b = b1
+ s.add(a1)
+ s.flush()
+
+ eq_(a1.updated, 0)
+ eq_(
+ canary.mock_calls,
+ [
+ mock.call(a1, mock.ANY, ['updated'])
+ ]
+ )