]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Pass along other keyword args in _EventsHold.populate
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Dec 2020 21:03:35 +0000 (16:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Dec 2020 21:04:06 +0000 (16:04 -0500)
Fixed bug involving the ``restore_load_context`` option of ORM events such
as :meth:`_orm.InstanceEvents.load` such that the flag would not be carried
along to subclasses which were mapped after the event handler were first
established.

Fixes: #5737
Change-Id: Ie65fbeac7ab223d24993cff0b34094d4928ff113

doc/build/changelog/unreleased_13/5737.rst [new file with mode: 0644]
lib/sqlalchemy/orm/events.py
test/orm/test_events.py

diff --git a/doc/build/changelog/unreleased_13/5737.rst b/doc/build/changelog/unreleased_13/5737.rst
new file mode 100644 (file)
index 0000000..7a1c3b5
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 5737
+    :versions: 1.4.0b2
+
+    Fixed bug involving the ``restore_load_context`` option of ORM events such
+    as :meth:`_orm.InstanceEvents.load` such that the flag would not be carried
+    along to subclasses which were mapped after the event handler were first
+    established.
+
+
index a0e100a816867d2d1cf35e5f0e860c2ae7c9d1c5..9f9ebd4612ea03e85a9ecda2e876fe076435f6d8 100644 (file)
@@ -550,7 +550,13 @@ class _EventsHold(event.RefCollection):
                 collection = target.all_holds[target.class_] = {}
 
             event.registry._stored_in_collection(event_key, target)
-            collection[event_key._key] = (event_key, raw, propagate, retval)
+            collection[event_key._key] = (
+                event_key,
+                raw,
+                propagate,
+                retval,
+                kw,
+            )
 
             if propagate:
                 stack = list(target.class_.__subclasses__())
@@ -577,7 +583,13 @@ class _EventsHold(event.RefCollection):
         for subclass in class_.__mro__:
             if subclass in cls.all_holds:
                 collection = cls.all_holds[subclass]
-                for event_key, raw, propagate, retval in collection.values():
+                for (
+                    event_key,
+                    raw,
+                    propagate,
+                    retval,
+                    kw,
+                ) in collection.values():
                     if propagate or subclass is class_:
                         # since we can't be sure in what order different
                         # classes in a hierarchy are triggered with
@@ -585,7 +597,7 @@ class _EventsHold(event.RefCollection):
                         # assignment, instead of using the generic propagate
                         # flag.
                         event_key.with_dispatch_target(subject).listen(
-                            raw=raw, propagate=False, retval=retval
+                            raw=raw, propagate=False, retval=retval, **kw
                         )
 
 
index 2851622414e410931e511f97c855b5c77c5caa9b..f8600894fdf6e5d2e98130978fb81e054b4ff10c 100644 (file)
@@ -900,6 +900,32 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
         s.query(A).all()
         s.close()
 
+    @testing.combinations(
+        ("load", lambda instance, context: instance.unloaded),
+        (
+            "refresh",
+            lambda instance, context, attrs: instance.unloaded,
+        ),
+    )
+    def test_flag_resolves_existing_for_subclass(self, event_name, fn):
+        Base = declarative_base()
+
+        event.listen(
+            Base, event_name, fn, propagate=True, restore_load_context=True
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            unloaded = deferred(Column(String(50)))
+
+        s = Session(testing.db)
+
+        a1 = s.query(A).all()[0]
+        if event_name == "refresh":
+            s.refresh(a1)
+        s.close()
+
     @_combinations
     def test_flag_resolves(self, target, event_name, fn):
         A = self.classes.A