]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add all versioning logic to _post_update()
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Apr 2016 19:57:20 +0000 (15:57 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Jun 2017 22:58:29 +0000 (18:58 -0400)
An UPDATE emitted as a result of the
:paramref:`.relationship.post_update` feature will now integrate with
the versioning feature to both bump the version id of the row as well
as assert that the existing version number was matched.

Fixes: #3496
Change-Id: I865405dd6069f1c1e3b0d27a4980e9374e059f97

doc/build/changelog/changelog_12.rst
doc/build/changelog/migration_12.rst
doc/build/faq/sessions.rst
lib/sqlalchemy/orm/persistence.py
test/orm/test_versioning.py

index 7da6bc8eae74aa322558dc15266b4c17cbbd7f4b..7e081f64de0a47978789d19595c34fd380b250b5 100644 (file)
 .. changelog::
     :version: 1.2.0b1
 
+    .. change:: 3496
+        :tags: bug, orm
+        :tickets: 3496
+
+        An UPDATE emitted as a result of the
+        :paramref:`.relationship.post_update` feature will now integrate with
+        the versioning feature to both bump the version id of the row as well
+        as assert that the existing version number was matched.
+
+        .. seealso::
+
+            :ref:`change_3496`
+
     .. change:: 3769
         :tags: bug, ext
         :tickets: 3769
index 8dee4489507688f3198ca40a23729c4aee1e4549..6cc955a66f5c1c85a57deb37bb4ecc9292eee791 100644 (file)
@@ -1061,6 +1061,68 @@ within flush occurs in this case.
 
 :ticket:`3472`
 
+.. _change_3496:
+
+post_update integrates with ORM versioning
+------------------------------------------
+
+The post_update feature, documented at :ref:`post_update`, involves that an
+UPDATE statement is emitted in response to changes to a particular
+relationship-bound foreign key, in addition to the INSERT/UPDATE/DELETE that
+would normally be emitted for the target row.  This UPDATE statement
+now participates in the versioning feature, documented at
+:ref:`mapper_version_counter`.
+
+Given a mapping::
+
+    class Node(Base):
+        __tablename__ = 'node'
+        id = Column(Integer, primary_key=True)
+        version_id = Column(Integer, default=0)
+        parent_id = Column(ForeignKey('node.id'))
+        favorite_node_id = Column(ForeignKey('node.id'))
+
+        nodes = relationship("Node", primaryjoin=remote(parent_id) == id)
+        favorite_node = relationship(
+            "Node", primaryjoin=favorite_node_id == remote(id),
+            post_update=True
+        )
+
+        __mapper_args__ = {
+            'version_id_col': version_id
+        }
+
+An UPDATE of a node that associates another node as "favorite" will
+now increment the version counter as well as match the current version::
+
+    node = Node()
+    session.add(node)
+    session.commit()  # node is now version #1
+
+    node = session.query(Node).get(node.id)
+    node.favorite_node = Node()
+    session.commit()  # node is now version #2
+
+Note that this means an object that receives an UPDATE in response to
+other attributes changing, and a second UPDATE due to a post_update
+relationship change, will now receive
+**two version counter updates for one flush**.   However, if the object
+is subject to an INSERT within the current flush, the version counter
+**will not** be incremented an additional time, unless a server-side
+versioning scheme is in place.
+
+The reason post_update emits an UPDATE even for an UPDATE is now discussed at
+:ref:`faq_post_update_update`.
+
+.. seealso::
+
+    :ref:`post_update`
+
+    :ref:`faq_post_update_update`
+
+
+:ticket:`3496`
+
 Key Behavioral Changes - Core
 =============================
 
index 2daba29698b5f70a61720df544076c30951404e2..bbc16ded8e74f81f52da26a0561e4510e9a154df 100644 (file)
@@ -497,4 +497,32 @@ Which is somewhat inconvenient.
 
 This `UniqueObject <http://www.sqlalchemy.org/trac/wiki/UsageRecipes/UniqueObject>`_ recipe was created to address this issue.
 
-
+.. _faq_post_update_update:
+
+Why does post_update emit UPDATE in addition to the first UPDATE?
+-----------------------------------------------------------------
+
+The post_update feature, documented at :ref:`post_update`, involves that an
+UPDATE statement is emitted in response to changes to a particular
+relationship-bound foreign key, in addition to the INSERT/UPDATE/DELETE that
+would normally be emitted for the target row.  While the primary purpose of this
+UPDATE statement is that it pairs up with an INSERT or DELETE of that row, so
+that it can post-set or pre-unset a foreign key reference in order to break a
+cycle with a mutually dependent foreign key, it currently is also bundled as a
+second UPDATE that emits when the target row itself is subject to an UPDATE.
+In this case, the UPDATE emitted by post_update is *usually* unnecessary
+and will often appear wasteful.
+
+However, some research into trying to remove this "UPDATE / UPDATE" behavior
+reveals that major changes to the unit of work process would need to occur  not
+just throughout the post_update implementation, but also in areas that aren't
+related to post_update for this to work, in that the order of operations would
+need to be reversed on the non-post_update side in some cases, which in turn
+can impact other cases, such as correctly handling an UPDATE of a referenced
+primary key value (see :ticket:`1063` for a proof of concept).
+
+The answer is that "post_update" is used to break a cycle between two
+mutually dependent foreign keys, and to have this cycle breaking be limited
+to just INSERT/DELETE of the target table implies that the ordering of UPDATE
+statements elsewhere would need to be liberalized, leading to breakage
+in other edge cases.
index 0de64011a0f3540849ed9340395352d2e99bacb2..924b9e1c94fd0ae1954e8abec921472481765b4c 100644 (file)
@@ -212,15 +212,22 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
             continue
 
         update = (
-            (state, state_dict, sub_mapper, connection)
+            (
+                state, state_dict, sub_mapper, connection,
+                mapper._get_committed_state_attr_by_column(
+                    state, state_dict, mapper.version_id_col
+                ) if mapper.version_id_col is not None else None
+            )
             for
             state, state_dict, sub_mapper, connection in states_to_update
             if table in sub_mapper._pks_by_table
         )
 
-        update = _collect_post_update_commands(base_mapper, uowtransaction,
-                                               table, update,
-                                               post_update_cols)
+        update = _collect_post_update_commands(
+            base_mapper, uowtransaction,
+            table, update,
+            post_update_cols
+        )
 
         _emit_post_update_statements(base_mapper, uowtransaction,
                                      cached_connections,
@@ -576,7 +583,8 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
 
     """
 
-    for state, state_dict, mapper, connection in states_to_update:
+    for state, state_dict, mapper, connection, \
+            update_version_id in states_to_update:
 
         # assert table in mapper._pks_by_table
 
@@ -601,6 +609,16 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
                     params[col.key] = value
                     hasdata = True
         if hasdata:
+            if update_version_id is not None and \
+                    mapper.version_id_col in mapper._cols_by_table[table]:
+
+                col = mapper.version_id_col
+                params[col._label] = update_version_id
+
+                if bool(state.key) and col.key not in params and \
+                        mapper.version_id_generator is not False:
+                    val = mapper.version_id_generator(update_version_id)
+                    params[col.key] = val
             yield state, state_dict, mapper, connection, params
 
 
@@ -870,6 +888,9 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
     """Emit UPDATE statements corresponding to value lists collected
     by _collect_post_update_commands()."""
 
+    needs_version_id = mapper.version_id_col is not None and \
+        mapper.version_id_col in mapper._cols_by_table[table]
+
     def update_stmt():
         clause = sql.and_()
 
@@ -877,7 +898,18 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
             clause.clauses.append(col == sql.bindparam(col._label,
                                                        type_=col.type))
 
-        return table.update(clause)
+        if needs_version_id:
+            clause.clauses.append(
+                mapper.version_id_col == sql.bindparam(
+                    mapper.version_id_col._label,
+                    type_=mapper.version_id_col.type))
+
+        stmt = table.update(clause)
+
+        if mapper.version_id_col is not None:
+            stmt = stmt.return_defaults(mapper.version_id_col)
+
+        return stmt
 
     statement = base_mapper._memo(('post_update', table), update_stmt)
 
@@ -885,23 +917,63 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
     # list of states to guarantee row access order, but
     # also group them into common (connection, cols) sets
     # to support executemany().
-    for key, grouper in groupby(
+    for key, records in groupby(
         update, lambda rec: (
             rec[3],  # connection
             set(rec[4]),  # parameter keys
         )
     ):
-        grouper = list(grouper)
+        rows = 0
+
+        records = list(records)
         connection = key[0]
-        multiparams = [
-            params for state, state_dict, mapper_rec, conn, params in grouper]
-        c = cached_connections[connection].\
-            execute(statement, multiparams)
 
-        for state, state_dict, mapper_rec, connection, params in grouper:
-            _postfetch_post_update(
-                mapper, uowtransaction, state, state_dict,
-                c, c.context.compiled_parameters[0])
+        assert_singlerow = connection.dialect.supports_sane_rowcount
+        assert_multirow = assert_singlerow and \
+            connection.dialect.supports_sane_multi_rowcount
+        allow_multirow = not needs_version_id or assert_multirow
+
+        if not allow_multirow:
+            check_rowcount = assert_singlerow
+            for state, state_dict, mapper_rec, \
+                    connection, params in records:
+                c = cached_connections[connection].\
+                    execute(statement, params)
+                _postfetch_post_update(
+                    mapper_rec, uowtransaction, table, state, state_dict,
+                    c, c.context.compiled_parameters[0])
+                rows += c.rowcount
+        else:
+            multiparams = [
+                params for
+                state, state_dict, mapper_rec, conn, params in records]
+
+            check_rowcount = assert_multirow or (
+                assert_singlerow and
+                len(multiparams) == 1
+            )
+
+            c = cached_connections[connection].\
+                execute(statement, multiparams)
+
+            rows += c.rowcount
+            for state, state_dict, mapper_rec, \
+                    connection, params in records:
+                _postfetch_post_update(
+                    mapper_rec, uowtransaction, table, state, state_dict,
+                    c, c.context.compiled_parameters[0])
+
+        if check_rowcount:
+            if rows != len(records):
+                raise orm_exc.StaleDataError(
+                    "UPDATE statement on table '%s' expected to "
+                    "update %d row(s); %d were matched." %
+                    (table.description, len(records), rows))
+
+        elif needs_version_id:
+            util.warn("Dialect %s does not support updated rowcount "
+                      "- versioning cannot be verified." %
+                      c.dialect.dialect_description)
 
 
 def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
@@ -1045,11 +1117,15 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
                     "Instance does not contain a non-NULL version value")
 
 
-def _postfetch_post_update(mapper, uowtransaction,
+def _postfetch_post_update(mapper, uowtransaction, table,
                            state, dict_, result, params):
     prefetch_cols = result.context.compiled.prefetch
     postfetch_cols = result.context.compiled.postfetch
 
+    if mapper.version_id_col is not None and \
+            mapper.version_id_col in mapper._cols_by_table[table]:
+        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
     refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
     if refresh_flush:
         load_evt_attrs = []
index 10d31321939f77b97ab23a6d245aa21adafe0019..4d9d6883acb1e056d887972cbfcc00b86b7ef8af 100644 (file)
@@ -549,6 +549,142 @@ class VersioningTest(fixtures.MappedTest):
         )
 
 
+class VersionOnPostUpdateTest(fixtures.MappedTest):
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            'node', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('version_id', Integer),
+            Column('parent_id', ForeignKey('node.id'))
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Node(cls.Basic):
+            pass
+
+    def _fixture(self, o2m, post_update, insert=True):
+        Node = self.classes.Node
+        node = self.tables.node
+
+        mapper(Node, node, properties={
+            'related': relationship(
+                Node,
+                remote_side=node.c.id if not o2m else node.c.parent_id,
+                post_update=post_update
+            )
+        }, version_id_col=node.c.version_id)
+
+        s = Session()
+        n1 = Node(id=1)
+        n2 = Node(id=2)
+
+        if insert:
+            s.add_all([n1, n2])
+            s.flush()
+        return s, n1, n2
+
+    def test_o2m_plain(self):
+        s, n1, n2 = self._fixture(o2m=True, post_update=False)
+
+        n1.related.append(n2)
+        s.flush()
+
+        eq_(n1.version_id, 1)
+        eq_(n2.version_id, 2)
+
+    def test_m2o_plain(self):
+        s, n1, n2 = self._fixture(o2m=False, post_update=False)
+
+        n1.related = n2
+        s.flush()
+
+        eq_(n1.version_id, 2)
+        eq_(n2.version_id, 1)
+
+    def test_o2m_post_update(self):
+        s, n1, n2 = self._fixture(o2m=True, post_update=True)
+
+        n1.related.append(n2)
+        s.flush()
+
+        eq_(n1.version_id, 1)
+        eq_(n2.version_id, 2)
+
+    def test_m2o_post_update(self):
+        s, n1, n2 = self._fixture(o2m=False, post_update=True)
+
+        n1.related = n2
+        s.flush()
+
+        eq_(n1.version_id, 2)
+        eq_(n2.version_id, 1)
+
+    def test_o2m_post_update_not_assoc_w_insert(self):
+        s, n1, n2 = self._fixture(o2m=True, post_update=True, insert=False)
+
+        n1.related.append(n2)
+        s.add_all([n1, n2])
+        s.flush()
+
+        eq_(n1.version_id, 1)
+        eq_(n2.version_id, 1)
+
+    def test_m2o_post_update_not_assoc_w_insert(self):
+        s, n1, n2 = self._fixture(o2m=False, post_update=True, insert=False)
+
+        n1.related = n2
+        s.add_all([n1, n2])
+        s.flush()
+
+        eq_(n1.version_id, 1)
+        eq_(n2.version_id, 1)
+
+    def test_o2m_post_update_version_assert(self):
+        Node = self.classes.Node
+        s, n1, n2 = self._fixture(o2m=True, post_update=True)
+
+        n1.related.append(n2)
+
+        # outwit the database transaction isolation and SQLA's
+        # expiration at the same time by using different Session on
+        # same transaction
+        s2 = Session(bind=s.connection(Node))
+        s2.query(Node).filter(Node.id == n2.id).update({"version_id": 3})
+        s2.commit()
+
+        assert_raises_message(
+            orm_exc.StaleDataError,
+            "UPDATE statement on table 'node' expected to "
+            r"update 1 row\(s\); 0 were matched.",
+            s.flush
+        )
+
+    def test_m2o_post_update_version_assert(self):
+        Node = self.classes.Node
+
+        s, n1, n2 = self._fixture(o2m=False, post_update=True)
+
+        n1.related = n2
+
+        # outwit the database transaction isolation and SQLA's
+        # expiration at the same time by using different Session on
+        # same transaction
+        s2 = Session(bind=s.connection(Node))
+        s2.query(Node).filter(Node.id == n1.id).update({"version_id": 3})
+        s2.commit()
+
+        assert_raises_message(
+            orm_exc.StaleDataError,
+            "UPDATE statement on table 'node' expected to "
+            r"update 1 row\(s\); 0 were matched.",
+            s.flush
+        )
+
+
 class NoBumpOnRelationshipTest(fixtures.MappedTest):
     __backend__ = True