]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Don't load expired objects from evaluator
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Oct 2020 19:01:03 +0000 (15:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Oct 2020 19:22:52 +0000 (15:22 -0400)
part 2 of e3dc20ff27fa75e571fb2631e64737ea8f25f7c5, the
pre-evaluate step was also emitting SELECT for expired objects.

Fixes: #5664
Change-Id: I9f5de2a5d480eafeb290ec0c45ce2a82ec93475b

lib/sqlalchemy/orm/persistence.py
test/orm/test_evaluator.py
test/orm/test_update_delete.py

index 1794cc2ce2bfcf30b5d6af31dec8f50cecee0a17..cfb6d926587ee1a27104d96dd22d242a93c70d12 100644 (file)
@@ -1978,6 +1978,7 @@ class BulkUDCompileState(CompileState):
             state.obj()
             for state in session.identity_map.all_states()
             if state.mapper.isa(mapper)
+            and not state.expired
             and eval_condition(state.obj())
             and (
                 update_options._refresh_identity_token is None
@@ -2209,9 +2210,8 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
             # only evaluate unmodified attributes
             to_evaluate = state.unmodified.intersection(evaluated_keys)
             for key in to_evaluate:
-                if key not in dict_:
-                    continue
-                dict_[key] = update_options._value_evaluators[key](obj)
+                if key in dict_:
+                    dict_[key] = update_options._value_evaluators[key](obj)
 
             state.manager.dispatch.refresh(state, None, to_evaluate)
 
index 955e5134fc00af03ecf3a43be2130113ced90cbc..ec843d1c5cc4e230fde3868512b371ffd84199dc 100644 (file)
@@ -10,9 +10,11 @@ from sqlalchemy import or_
 from sqlalchemy import String
 from sqlalchemy import tuple_
 from sqlalchemy.orm import evaluator
+from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
@@ -285,10 +287,10 @@ class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
             name = Column(String(50), primary_key=True)
             parent = relationship(Parent)
 
-    def test_delete(self):
+    def test_delete_not_expired(self):
         Parent, Child = self.classes("Parent", "Child")
 
-        session = Session()
+        session = Session(expire_on_commit=False)
 
         p = Parent(id=1)
         session.add(p)
@@ -301,3 +303,24 @@ class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
         session.query(Child).filter(Child.parent == p).delete("evaluate")
 
         is_(inspect(c).deleted, True)
+
+    def test_delete_expired(self):
+        Parent, Child = self.classes("Parent", "Child")
+
+        session = Session()
+
+        p = Parent(id=1)
+        session.add(p)
+        session.commit()
+
+        c = Child(name="foo", parent=p)
+        session.add(c)
+        session.commit()
+
+        session.query(Child).filter(Child.parent == p).delete("evaluate")
+
+        # because it's expired
+        is_(inspect(c).deleted, False)
+
+        # but it's gone
+        assert_raises(orm_exc.ObjectDeletedError, lambda: c.name)
index 00bc344bce9c43db15d35376b9a2b3f9753cfeb8..0b0c9cea76fd18f29a244be162c1c1c2a01a1be5 100644 (file)
@@ -276,7 +276,15 @@ class UpdateDeleteTest(fixtures.MappedTest):
         )
         eq_(jill.ufoo, "moonbeam")
 
-    def test_evaluate_dont_refresh_expired_objects(self):
+    @testing.combinations(
+        (False, False),
+        (False, True),
+        (True, False),
+        (True, True),
+    )
+    def test_evaluate_dont_refresh_expired_objects(
+        self, expire_jane_age, add_filter_criteria
+    ):
         User = self.classes.User
 
         sess = Session()
@@ -285,28 +293,74 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         sess.expire(john)
         sess.expire(jill)
-        sess.expire(jane, ["name"])
+
+        if expire_jane_age:
+            sess.expire(jane, ["name", "age"])
+        else:
+            sess.expire(jane, ["name"])
 
         with self.sql_execution_asserter() as asserter:
             # using 1.x style for easier backport
-            sess.query(User).update(
-                {"age": User.age + 10}, synchronize_session="evaluate"
-            )
+            if add_filter_criteria:
+                sess.query(User).filter(User.name != None).update(
+                    {"age": User.age + 10}, synchronize_session="evaluate"
+                )
+            else:
+                sess.query(User).update(
+                    {"age": User.age + 10}, synchronize_session="evaluate"
+                )
 
-        asserter.assert_(
-            CompiledSQL(
-                "UPDATE users SET age_int=(users.age_int + :age_int_1)",
-                [{"age_int_1": 10}],
-            ),
-        )
+        if add_filter_criteria:
+            if expire_jane_age:
+                asserter.assert_(
+                    # it has to unexpire jane.name, because jane is not fully
+                    # expired and the critiera needs to look at this particular
+                    # key
+                    CompiledSQL(
+                        "SELECT users.age_int AS users_age_int, "
+                        "users.name AS users_name FROM users "
+                        "WHERE users.id = :param_1",
+                        [{"param_1": 4}],
+                    ),
+                    CompiledSQL(
+                        "UPDATE users "
+                        "SET age_int=(users.age_int + :age_int_1) "
+                        "WHERE users.name IS NOT NULL",
+                        [{"age_int_1": 10}],
+                    ),
+                )
+            else:
+                asserter.assert_(
+                    # it has to unexpire jane.name, because jane is not fully
+                    # expired and the critiera needs to look at this particular
+                    # key
+                    CompiledSQL(
+                        "SELECT users.name AS users_name FROM users "
+                        "WHERE users.id = :param_1",
+                        [{"param_1": 4}],
+                    ),
+                    CompiledSQL(
+                        "UPDATE users SET "
+                        "age_int=(users.age_int + :age_int_1) "
+                        "WHERE users.name IS NOT NULL",
+                        [{"age_int_1": 10}],
+                    ),
+                )
+        else:
+            asserter.assert_(
+                CompiledSQL(
+                    "UPDATE users SET age_int=(users.age_int + :age_int_1)",
+                    [{"age_int_1": 10}],
+                ),
+            )
 
         with self.sql_execution_asserter() as asserter:
             eq_(john.age, 35)  # needs refresh
             eq_(jack.age, 57)  # no SQL needed
             eq_(jill.age, 39)  # needs refresh
-            eq_(jane.age, 47)  # no SQL needed
+            eq_(jane.age, 47)  # needs refresh
 
-        asserter.assert_(
+        to_assert = [
             # refresh john
             CompiledSQL(
                 "SELECT users.age_int AS users_age_int, "
@@ -321,7 +375,19 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 "WHERE users.id = :param_1",
                 [{"param_1": 3}],
             ),
-        )
+        ]
+
+        if expire_jane_age and not add_filter_criteria:
+            to_assert.append(
+                # refresh jane
+                CompiledSQL(
+                    "SELECT users.age_int AS users_age_int, "
+                    "users.name AS users_name FROM users "
+                    "WHERE users.id = :param_1",
+                    [{"param_1": 4}],
+                )
+            )
+        asserter.assert_(*to_assert)
 
     def test_fetch_dont_refresh_expired_objects(self):
         User = self.classes.User
@@ -336,7 +402,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         with self.sql_execution_asserter() as asserter:
             # using 1.x style for easier backport
-            sess.query(User).update(
+            sess.query(User).filter(User.name != None).update(
                 {"age": User.age + 10}, synchronize_session="fetch"
             )
 
@@ -344,6 +410,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
             asserter.assert_(
                 CompiledSQL(
                     "UPDATE users SET age_int=(users.age_int + %(age_int_1)s) "
+                    "WHERE users.name IS NOT NULL "
                     "RETURNING users.id",
                     [{"age_int_1": 10}],
                     dialect="postgresql",
@@ -351,9 +418,13 @@ class UpdateDeleteTest(fixtures.MappedTest):
             )
         else:
             asserter.assert_(
-                CompiledSQL("SELECT users.id FROM users"),
                 CompiledSQL(
-                    "UPDATE users SET age_int=(users.age_int + :age_int_1)",
+                    "SELECT users.id FROM users "
+                    "WHERE users.name IS NOT NULL"
+                ),
+                CompiledSQL(
+                    "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+                    "WHERE users.name IS NOT NULL",
                     [{"age_int_1": 10}],
                 ),
             )