]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Pass along other keyword args in _EventsHold.populate
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Dec 2020 21:03:35 +0000 (16:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Dec 2020 21:06:04 +0000 (16:06 -0500)
Fixed bug involving the ``restore_load_context`` option of ORM events such
as :meth:`_orm.InstanceEvents.load` such that the flag would not be carried
along to subclasses which were mapped after the event handler were first
established.

Fixes: #5737
Change-Id: Ie65fbeac7ab223d24993cff0b34094d4928ff113
(cherry picked from commit 14c08d18885e16611b884bd76ba2811375de1731)

doc/build/changelog/unreleased_13/5737.rst [new file with mode: 0644]
lib/sqlalchemy/orm/events.py
test/orm/test_events.py

diff --git a/doc/build/changelog/unreleased_13/5737.rst b/doc/build/changelog/unreleased_13/5737.rst
new file mode 100644 (file)
index 0000000..7a1c3b5
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 5737
+    :versions: 1.4.0b2
+
+    Fixed bug involving the ``restore_load_context`` option of ORM events such
+    as :meth:`_orm.InstanceEvents.load` such that the flag would not be carried
+    along to subclasses which were mapped after the event handler were first
+    established.
+
+
index cede830644e14df2db615bc7c61f6426ba26e3d9..4cf10302273dea72454929018e95a6818cf64303 100644 (file)
@@ -540,7 +540,13 @@ class _EventsHold(event.RefCollection):
                 collection = target.all_holds[target.class_] = {}
 
             event.registry._stored_in_collection(event_key, target)
-            collection[event_key._key] = (event_key, raw, propagate, retval)
+            collection[event_key._key] = (
+                event_key,
+                raw,
+                propagate,
+                retval,
+                kw,
+            )
 
             if propagate:
                 stack = list(target.class_.__subclasses__())
@@ -567,7 +573,13 @@ class _EventsHold(event.RefCollection):
         for subclass in class_.__mro__:
             if subclass in cls.all_holds:
                 collection = cls.all_holds[subclass]
-                for event_key, raw, propagate, retval in collection.values():
+                for (
+                    event_key,
+                    raw,
+                    propagate,
+                    retval,
+                    kw,
+                ) in collection.values():
                     if propagate or subclass is class_:
                         # since we can't be sure in what order different
                         # classes in a hierarchy are triggered with
@@ -575,7 +587,7 @@ class _EventsHold(event.RefCollection):
                         # assignment, instead of using the generic propagate
                         # flag.
                         event_key.with_dispatch_target(subject).listen(
-                            raw=raw, propagate=False, retval=retval
+                            raw=raw, propagate=False, retval=retval, **kw
                         )
 
 
index e76cba03bad9552d6ea51ac30911f2c0b917c271..00761e4465ea048ae483d257524971c2348936fd 100644 (file)
@@ -716,6 +716,32 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
         event.listen(target, event_name, fn, restore_load_context=True)
         s.query(A).all()
 
+    @testing.combinations(
+        ("load", lambda instance, context: instance.unloaded),
+        (
+            "refresh",
+            lambda instance, context, attrs: instance.unloaded,
+        ),
+    )
+    def test_flag_resolves_existing_for_subclass(self, event_name, fn):
+        Base = declarative_base()
+
+        event.listen(
+            Base, event_name, fn, propagate=True, restore_load_context=True
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            unloaded = deferred(Column(String(50)))
+
+        s = Session(testing.db)
+
+        a1 = s.query(A).all()[0]
+        if event_name == "refresh":
+            s.refresh(a1)
+        s.close()
+
     @_combinations
     def test_flag_resolves(self, target, event_name, fn):
         A = self.classes.A