From: Mike Bayer Date: Tue, 12 Apr 2016 19:57:20 +0000 (-0400) Subject: Add all versioning logic to _post_update() X-Git-Tag: rel_1_2_0b1~20^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=64b0760faa45a26c727a76b9fda97f2b4ea15417;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add all versioning logic to _post_update() An UPDATE emitted as a result of the :paramref:`.relationship.post_update` feature will now integrate with the versioning feature to both bump the version id of the row as well as assert that the existing version number was matched. Fixes: #3496 Change-Id: I865405dd6069f1c1e3b0d27a4980e9374e059f97 --- diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst index 7da6bc8eae..7e081f64de 100644 --- a/doc/build/changelog/changelog_12.rst +++ b/doc/build/changelog/changelog_12.rst @@ -13,6 +13,19 @@ .. changelog:: :version: 1.2.0b1 + .. change:: 3496 + :tags: bug, orm + :tickets: 3496 + + An UPDATE emitted as a result of the + :paramref:`.relationship.post_update` feature will now integrate with + the versioning feature to both bump the version id of the row as well + as assert that the existing version number was matched. + + .. seealso:: + + :ref:`change_3496` + .. change:: 3769 :tags: bug, ext :tickets: 3769 diff --git a/doc/build/changelog/migration_12.rst b/doc/build/changelog/migration_12.rst index 8dee448950..6cc955a66f 100644 --- a/doc/build/changelog/migration_12.rst +++ b/doc/build/changelog/migration_12.rst @@ -1061,6 +1061,68 @@ within flush occurs in this case. :ticket:`3472` +.. _change_3496: + +post_update integrates with ORM versioning +------------------------------------------ + +The post_update feature, documented at :ref:`post_update`, means that an +UPDATE statement is emitted in response to changes to a particular +relationship-bound foreign key, in addition to the INSERT/UPDATE/DELETE that +would normally be emitted for the target row. 
This UPDATE statement +now participates in the versioning feature, documented at +:ref:`mapper_version_counter`. + +Given a mapping:: + + class Node(Base): + __tablename__ = 'node' + id = Column(Integer, primary_key=True) + version_id = Column(Integer, default=0) + parent_id = Column(ForeignKey('node.id')) + favorite_node_id = Column(ForeignKey('node.id')) + + nodes = relationship("Node", primaryjoin=remote(parent_id) == id) + favorite_node = relationship( + "Node", primaryjoin=favorite_node_id == remote(id), + post_update=True + ) + + __mapper_args__ = { + 'version_id_col': version_id + } + +An UPDATE of a node that associates another node as "favorite" will +now increment the version counter as well as match the current version:: + + node = Node() + session.add(node) + session.commit() # node is now version #1 + + node = session.query(Node).get(node.id) + node.favorite_node = Node() + session.commit() # node is now version #2 + +Note that this means an object that receives an UPDATE in response to +other attributes changing, and a second UPDATE due to a post_update +relationship change, will now receive +**two version counter updates for one flush**. However, if the object +is subject to an INSERT within the current flush, the version counter +**will not** be incremented an additional time, unless a server-side +versioning scheme is in place. + +The reason post_update emits an UPDATE even for an UPDATE is now discussed at +:ref:`faq_post_update_update`. + +.. seealso:: + + :ref:`post_update` + + :ref:`faq_post_update_update` + + +:ticket:`3496` + Key Behavioral Changes - Core ============================= diff --git a/doc/build/faq/sessions.rst b/doc/build/faq/sessions.rst index 2daba29698..bbc16ded8e 100644 --- a/doc/build/faq/sessions.rst +++ b/doc/build/faq/sessions.rst @@ -497,4 +497,32 @@ Which is somewhat inconvenient. This `UniqueObject `_ recipe was created to address this issue. - +.. 
_faq_post_update_update: + +Why does post_update emit UPDATE in addition to the first UPDATE? +----------------------------------------------------------------- + +The post_update feature, documented at :ref:`post_update`, means that an +UPDATE statement is emitted in response to changes to a particular +relationship-bound foreign key, in addition to the INSERT/UPDATE/DELETE that +would normally be emitted for the target row. While the primary purpose of this +UPDATE statement is that it pairs up with an INSERT or DELETE of that row, so +that it can post-set or pre-unset a foreign key reference in order to break a +cycle with a mutually dependent foreign key, it currently is also bundled as a +second UPDATE that is emitted when the target row itself is subject to an UPDATE. +In this case, the UPDATE emitted by post_update is *usually* unnecessary +and will often appear wasteful. + +However, some research into trying to remove this "UPDATE / UPDATE" behavior +reveals that major changes to the unit of work process would need to occur not +just throughout the post_update implementation, but also in areas that aren't +related to post_update for this to work, in that the order of operations would +need to be reversed on the non-post_update side in some cases, which in turn +can impact other cases, such as correctly handling an UPDATE of a referenced +primary key value (see :ticket:`1063` for a proof of concept). + +The answer is that "post_update" is used to break a cycle between two +mutually dependent foreign keys, and to have this cycle breaking be limited +to just INSERT/DELETE of the target table implies that the ordering of UPDATE +statements elsewhere would need to be liberalized, leading to breakage +in other edge cases. 
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0de64011a0..924b9e1c94 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -212,15 +212,22 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): continue update = ( - (state, state_dict, sub_mapper, connection) + ( + state, state_dict, sub_mapper, connection, + mapper._get_committed_state_attr_by_column( + state, state_dict, mapper.version_id_col + ) if mapper.version_id_col is not None else None + ) for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table ) - update = _collect_post_update_commands(base_mapper, uowtransaction, - table, update, - post_update_cols) + update = _collect_post_update_commands( + base_mapper, uowtransaction, + table, update, + post_update_cols + ) _emit_post_update_statements(base_mapper, uowtransaction, cached_connections, @@ -576,7 +583,8 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, """ - for state, state_dict, mapper, connection in states_to_update: + for state, state_dict, mapper, connection, \ + update_version_id in states_to_update: # assert table in mapper._pks_by_table @@ -601,6 +609,16 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, params[col.key] = value hasdata = True if hasdata: + if update_version_id is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: + + col = mapper.version_id_col + params[col._label] = update_version_id + + if bool(state.key) and col.key not in params and \ + mapper.version_id_generator is not False: + val = mapper.version_id_generator(update_version_id) + params[col.key] = val yield state, state_dict, mapper, connection, params @@ -870,6 +888,9 @@ def _emit_post_update_statements(base_mapper, uowtransaction, """Emit UPDATE statements corresponding to value lists collected by _collect_post_update_commands().""" + needs_version_id = 
mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table] + def update_stmt(): clause = sql.and_() @@ -877,7 +898,18 @@ def _emit_post_update_statements(base_mapper, uowtransaction, clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) - return table.update(clause) + if needs_version_id: + clause.clauses.append( + mapper.version_id_col == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type)) + + stmt = table.update(clause) + + if mapper.version_id_col is not None: + stmt = stmt.return_defaults(mapper.version_id_col) + + return stmt statement = base_mapper._memo(('post_update', table), update_stmt) @@ -885,23 +917,63 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # list of states to guarantee row access order, but # also group them into common (connection, cols) sets # to support executemany(). - for key, grouper in groupby( + for key, records in groupby( update, lambda rec: ( rec[3], # connection set(rec[4]), # parameter keys ) ): - grouper = list(grouper) + rows = 0 + + records = list(records) connection = key[0] - multiparams = [ - params for state, state_dict, mapper_rec, conn, params in grouper] - c = cached_connections[connection].\ - execute(statement, multiparams) - for state, state_dict, mapper_rec, connection, params in grouper: - _postfetch_post_update( - mapper, uowtransaction, state, state_dict, - c, c.context.compiled_parameters[0]) + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = not needs_version_id or assert_multirow + + if not allow_multirow: + check_rowcount = assert_singlerow + for state, state_dict, mapper_rec, \ + connection, params in records: + c = cached_connections[connection].\ + execute(statement, params) + _postfetch_post_update( + mapper_rec, uowtransaction, table, state, state_dict, + c, 
c.context.compiled_parameters[0]) + rows += c.rowcount + else: + multiparams = [ + params for + state, state_dict, mapper_rec, conn, params in records] + + check_rowcount = assert_multirow or ( + assert_singlerow and + len(multiparams) == 1 + ) + + c = cached_connections[connection].\ + execute(statement, multiparams) + + rows += c.rowcount + for state, state_dict, mapper_rec, \ + connection, params in records: + _postfetch_post_update( + mapper_rec, uowtransaction, table, state, state_dict, + c, c.context.compiled_parameters[0]) + + if check_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(records), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description) def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, @@ -1045,11 +1117,15 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): "Instance does not contain a non-NULL version value") -def _postfetch_post_update(mapper, uowtransaction, +def _postfetch_post_update(mapper, uowtransaction, table, state, dict_, result, params): prefetch_cols = result.context.compiled.prefetch postfetch_cols = result.context.compiled.postfetch + if mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) if refresh_flush: load_evt_attrs = [] diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 10d3132193..4d9d6883ac 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -549,6 +549,142 @@ class VersioningTest(fixtures.MappedTest): ) +class VersionOnPostUpdateTest(fixtures.MappedTest): + __backend__ = True + + @classmethod + def 
define_tables(cls, metadata): + Table( + 'node', metadata, + Column('id', Integer, primary_key=True), + Column('version_id', Integer), + Column('parent_id', ForeignKey('node.id')) + ) + + @classmethod + def setup_classes(cls): + class Node(cls.Basic): + pass + + def _fixture(self, o2m, post_update, insert=True): + Node = self.classes.Node + node = self.tables.node + + mapper(Node, node, properties={ + 'related': relationship( + Node, + remote_side=node.c.id if not o2m else node.c.parent_id, + post_update=post_update + ) + }, version_id_col=node.c.version_id) + + s = Session() + n1 = Node(id=1) + n2 = Node(id=2) + + if insert: + s.add_all([n1, n2]) + s.flush() + return s, n1, n2 + + def test_o2m_plain(self): + s, n1, n2 = self._fixture(o2m=True, post_update=False) + + n1.related.append(n2) + s.flush() + + eq_(n1.version_id, 1) + eq_(n2.version_id, 2) + + def test_m2o_plain(self): + s, n1, n2 = self._fixture(o2m=False, post_update=False) + + n1.related = n2 + s.flush() + + eq_(n1.version_id, 2) + eq_(n2.version_id, 1) + + def test_o2m_post_update(self): + s, n1, n2 = self._fixture(o2m=True, post_update=True) + + n1.related.append(n2) + s.flush() + + eq_(n1.version_id, 1) + eq_(n2.version_id, 2) + + def test_m2o_post_update(self): + s, n1, n2 = self._fixture(o2m=False, post_update=True) + + n1.related = n2 + s.flush() + + eq_(n1.version_id, 2) + eq_(n2.version_id, 1) + + def test_o2m_post_update_not_assoc_w_insert(self): + s, n1, n2 = self._fixture(o2m=True, post_update=True, insert=False) + + n1.related.append(n2) + s.add_all([n1, n2]) + s.flush() + + eq_(n1.version_id, 1) + eq_(n2.version_id, 1) + + def test_m2o_post_update_not_assoc_w_insert(self): + s, n1, n2 = self._fixture(o2m=False, post_update=True, insert=False) + + n1.related = n2 + s.add_all([n1, n2]) + s.flush() + + eq_(n1.version_id, 1) + eq_(n2.version_id, 1) + + def test_o2m_post_update_version_assert(self): + Node = self.classes.Node + s, n1, n2 = self._fixture(o2m=True, post_update=True) + + 
n1.related.append(n2) + + # outwit the database transaction isolation and SQLA's + # expiration at the same time by using different Session on + # same transaction + s2 = Session(bind=s.connection(Node)) + s2.query(Node).filter(Node.id == n2.id).update({"version_id": 3}) + s2.commit() + + assert_raises_message( + orm_exc.StaleDataError, + "UPDATE statement on table 'node' expected to " + r"update 1 row\(s\); 0 were matched.", + s.flush + ) + + def test_m2o_post_update_version_assert(self): + Node = self.classes.Node + + s, n1, n2 = self._fixture(o2m=False, post_update=True) + + n1.related = n2 + + # outwit the database transaction isolation and SQLA's + # expiration at the same time by using different Session on + # same transaction + s2 = Session(bind=s.connection(Node)) + s2.query(Node).filter(Node.id == n1.id).update({"version_id": 3}) + s2.commit() + + assert_raises_message( + orm_exc.StaleDataError, + "UPDATE statement on table 'node' expected to " + r"update 1 row\(s\); 0 were matched.", + s.flush + ) + + class NoBumpOnRelationshipTest(fixtures.MappedTest): __backend__ = True