]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Propagate key for collection events
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Aug 2022 14:53:11 +0000 (10:53 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Aug 2022 00:05:32 +0000 (20:05 -0400)
Added new parameter :paramref:`_orm.AttributeEvents.include_key`, which
will include the dictionary or list key for operations such as
``__setitem__()`` (e.g. ``obj[key] = value``) and ``__delitem__()`` (e.g.
``del obj[key]``), using a new keyword parameter "key" or "keys", depending
on event, e.g. :paramref:`_orm.AttributeEvents.append.key`,
:paramref:`_orm.AttributeEvents.bulk_replace.keys`. This allows event
handlers to take into account the key that was passed to the operation and
is of particular importance for dictionary operations working with
:class:`_orm.MappedCollection`.

Fixes: #8375
Change-Id: Icc472f7c28848f94e15c94a399cc13a88782e1e4

doc/build/changelog/unreleased_20/8375.rst [new file with mode: 0644]
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_attributes.py

diff --git a/doc/build/changelog/unreleased_20/8375.rst b/doc/build/changelog/unreleased_20/8375.rst
new file mode 100644 (file)
index 0000000..0fb0327
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: feature, orm
+    :tickets: 8375
+
+    Added new parameter :paramref:`_orm.AttributeEvents.include_key`, which
+    will include the dictionary or list key for operations such as
+    ``__setitem__()`` (e.g. ``obj[key] = value``) and ``__delitem__()`` (e.g.
+    ``del obj[key]``), using a new keyword parameter "key" or "keys", depending
+    on event, e.g. :paramref:`_orm.AttributeEvents.append.key`,
+    :paramref:`_orm.AttributeEvents.bulk_replace.keys`. This allows event
+    handlers to take into account the key that was passed to the operation and
+    is of particular importance for dictionary operations working with
+    :class:`_orm.MappedCollection`.
+
index cda58d6a5fa669b93ca54e3deacac32bde024d31..3a0f425fc4a41703dafe0ff277ffed86b149632d 100644 (file)
@@ -82,6 +82,7 @@ from .interfaces import InspectionAttrInfo as InspectionAttrInfo
 from .interfaces import MANYTOMANY as MANYTOMANY
 from .interfaces import MANYTOONE as MANYTOONE
 from .interfaces import MapperProperty as MapperProperty
+from .interfaces import NO_KEY as NO_KEY
 from .interfaces import ONETOMANY as ONETOMANY
 from .interfaces import PropComparator as PropComparator
 from .interfaces import UserDefinedOption as UserDefinedOption
index bb7eda5ac2274acfa48696fb74a03061186767ae..db86d0810a7555060055feb18557110147f08aea 100644 (file)
@@ -1717,9 +1717,10 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         dict_: _InstanceDict,
         value: _T,
         initiator: Optional[AttributeEventToken],
+        key: Optional[Any],
     ) -> _T:
         for fn in self.dispatch.append:
-            value = fn(state, value, initiator or self._append_token)
+            value = fn(state, value, initiator or self._append_token, key=key)
 
         state._modified_event(dict_, self, NO_VALUE, True)
 
@@ -1734,9 +1735,10 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         dict_: _InstanceDict,
         value: _T,
         initiator: Optional[AttributeEventToken],
+        key: Optional[Any],
     ) -> _T:
         for fn in self.dispatch.append_wo_mutation:
-            value = fn(state, value, initiator or self._append_token)
+            value = fn(state, value, initiator or self._append_token, key=key)
 
         return value
 
@@ -1745,6 +1747,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         state: InstanceState[Any],
         dict_: _InstanceDict,
         initiator: Optional[AttributeEventToken],
+        key: Optional[Any],
     ) -> None:
         """A special event used for pop() operations.
 
@@ -1762,12 +1765,13 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         dict_: _InstanceDict,
         value: Any,
         initiator: Optional[AttributeEventToken],
+        key: Optional[Any],
     ) -> None:
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), state, False)
 
         for fn in self.dispatch.remove:
-            fn(state, value, initiator or self._remove_token)
+            fn(state, value, initiator or self._remove_token, key=key)
 
         state._modified_event(dict_, self, NO_VALUE, True)
 
@@ -1825,7 +1829,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
             state, dict_, user_data=None, passive=passive
         )
         if collection is PASSIVE_NO_RESULT:
-            value = self.fire_append_event(state, dict_, value, initiator)
+            value = self.fire_append_event(
+                state, dict_, value, initiator, key=NO_KEY
+            )
             assert (
                 self.key not in dict_
             ), "Collection was loaded during event handling."
@@ -1847,7 +1853,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
             state, state.dict, user_data=None, passive=passive
         )
         if collection is PASSIVE_NO_RESULT:
-            self.fire_remove_event(state, dict_, value, initiator)
+            self.fire_remove_event(state, dict_, value, initiator, key=NO_KEY)
             assert (
                 self.key not in dict_
             ), "Collection was loaded during event handling."
@@ -1885,6 +1891,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         _adapt: bool = True,
     ) -> None:
         iterable = orig_iterable = value
+        new_keys = None
 
         # pulling a new collection first so that an adaptation exception does
         # not trigger a lazy load of the old collection.
@@ -1913,14 +1920,18 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
                 if hasattr(iterable, "_sa_iterator"):
                     iterable = iterable._sa_iterator()
                 elif setting_type is dict:
+                    new_keys = list(iterable)
                     iterable = iterable.values()
                 else:
                     iterable = iter(iterable)
+        elif util.duck_type_collection(iterable) is dict:
+            new_keys = list(value)
+
         new_values = list(iterable)
 
         evt = self._bulk_replace_token
 
-        self.dispatch.bulk_replace(state, new_values, evt)
+        self.dispatch.bulk_replace(state, new_values, evt, keys=new_keys)
 
         old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT)
         if old is PASSIVE_NO_RESULT:
@@ -2081,7 +2092,9 @@ def backref_listeners(
             )
         )
 
-    def emit_backref_from_scalar_set_event(state, child, oldchild, initiator):
+    def emit_backref_from_scalar_set_event(
+        state, child, oldchild, initiator, **kw
+    ):
         if oldchild is child:
             return child
         if (
@@ -2146,7 +2159,9 @@ def backref_listeners(
                 )
         return child
 
-    def emit_backref_from_collection_append_event(state, child, initiator):
+    def emit_backref_from_collection_append_event(
+        state, child, initiator, **kw
+    ):
         if child is None:
             return
 
@@ -2180,7 +2195,9 @@ def backref_listeners(
             )
         return child
 
-    def emit_backref_from_collection_remove_event(state, child, initiator):
+    def emit_backref_from_collection_remove_event(
+        state, child, initiator, **kw
+    ):
         if (
             child is not None
             and child is not PASSIVE_NO_RESULT
@@ -2234,6 +2251,7 @@ def backref_listeners(
             emit_backref_from_collection_append_event,
             retval=True,
             raw=True,
+            include_key=True,
         )
     else:
         event.listen(
@@ -2242,6 +2260,7 @@ def backref_listeners(
             emit_backref_from_scalar_set_event,
             retval=True,
             raw=True,
+            include_key=True,
         )
     # TODO: need coverage in test/orm/ of remove event
     event.listen(
@@ -2250,6 +2269,7 @@ def backref_listeners(
         emit_backref_from_collection_remove_event,
         retval=True,
         raw=True,
+        include_key=True,
     )
 
 
index fa653a472da649b7a3a6b8571c14f2e620937028..66b7b8c2e31d227446d55c73ac8fa6a9655d3d71 100644 (file)
@@ -191,9 +191,21 @@ class PassiveFlag(FastIntFlag):
 DEFAULT_MANAGER_ATTR = "_sa_class_manager"
 DEFAULT_STATE_ATTR = "_sa_instance_state"
 
-EXT_CONTINUE = util.symbol("EXT_CONTINUE")
-EXT_STOP = util.symbol("EXT_STOP")
-EXT_SKIP = util.symbol("EXT_SKIP")
+
+class EventConstants(Enum):
+    EXT_CONTINUE = 1
+    EXT_STOP = 2
+    EXT_SKIP = 3
+    NO_KEY = 4
+    """indicates an :class:`.AttributeEvent` event that did not have any
+    key argument.
+
+    .. versionadded:: 2.0
+
+    """
+
+
+EXT_CONTINUE, EXT_STOP, EXT_SKIP, NO_KEY = tuple(EventConstants)
 
 
 class RelationshipDirection(Enum):
index f47d00634e2432fbcf9b570ae62c83177e2b72a9..5dbd2dc30579f67c9b4c454699cd76dcd096acf9 100644 (file)
@@ -125,6 +125,7 @@ from typing import TypeVar
 from typing import Union
 import weakref
 
+from .base import NO_KEY
 from .. import exc as sa_exc
 from .. import util
 from ..util.compat import inspect_getfullargspec
@@ -614,7 +615,7 @@ class CollectionAdapter:
     def __bool__(self):
         return True
 
-    def fire_append_wo_mutation_event(self, item, initiator=None):
+    def fire_append_wo_mutation_event(self, item, initiator=None, key=NO_KEY):
         """Notify that a entity is entering the collection but is already
         present.
 
@@ -635,12 +636,12 @@ class CollectionAdapter:
                 self._reset_empty()
 
             return self.attr.fire_append_wo_mutation_event(
-                self.owner_state, self.owner_state.dict, item, initiator
+                self.owner_state, self.owner_state.dict, item, initiator, key
             )
         else:
             return item
 
-    def fire_append_event(self, item, initiator=None):
+    def fire_append_event(self, item, initiator=None, key=NO_KEY):
         """Notify that a entity has entered the collection.
 
         Initiator is a token owned by the InstrumentedAttribute that
@@ -657,12 +658,12 @@ class CollectionAdapter:
                 self._reset_empty()
 
             return self.attr.fire_append_event(
-                self.owner_state, self.owner_state.dict, item, initiator
+                self.owner_state, self.owner_state.dict, item, initiator, key
             )
         else:
             return item
 
-    def fire_remove_event(self, item, initiator=None):
+    def fire_remove_event(self, item, initiator=None, key=NO_KEY):
         """Notify that a entity has been removed from the collection.
 
         Initiator is the InstrumentedAttribute that initiated the membership
@@ -678,10 +679,10 @@ class CollectionAdapter:
                 self._reset_empty()
 
             self.attr.fire_remove_event(
-                self.owner_state, self.owner_state.dict, item, initiator
+                self.owner_state, self.owner_state.dict, item, initiator, key
             )
 
-    def fire_pre_remove_event(self, initiator=None):
+    def fire_pre_remove_event(self, initiator=None, key=NO_KEY):
         """Notify that an entity is about to be removed from the collection.
 
         Only called if the entity cannot be removed after calling
@@ -691,7 +692,10 @@ class CollectionAdapter:
         if self.invalidated:
             self._warn_invalidated()
         self.attr.fire_pre_remove_event(
-            self.owner_state, self.owner_state.dict, initiator=initiator
+            self.owner_state,
+            self.owner_state.dict,
+            initiator=initiator,
+            key=key,
         )
 
     def __getstate__(self):
@@ -1025,10 +1029,12 @@ def __set_wo_mutation(collection, item, _sa_initiator=None):
     if _sa_initiator is not False:
         executor = collection._sa_adapter
         if executor:
-            executor.fire_append_wo_mutation_event(item, _sa_initiator)
+            executor.fire_append_wo_mutation_event(
+                item, _sa_initiator, key=None
+            )
 
 
-def __set(collection, item, _sa_initiator=None):
+def __set(collection, item, _sa_initiator, key):
     """Run set events.
 
     This event always occurs before the collection is actually mutated.
@@ -1038,11 +1044,11 @@ def __set(collection, item, _sa_initiator=None):
     if _sa_initiator is not False:
         executor = collection._sa_adapter
         if executor:
-            item = executor.fire_append_event(item, _sa_initiator)
+            item = executor.fire_append_event(item, _sa_initiator, key=key)
     return item
 
 
-def __del(collection, item, _sa_initiator=None):
+def __del(collection, item, _sa_initiator, key):
     """Run del events.
 
     This event occurs before the collection is actually mutated, *except*
@@ -1054,7 +1060,7 @@ def __del(collection, item, _sa_initiator=None):
     if _sa_initiator is not False:
         executor = collection._sa_adapter
         if executor:
-            executor.fire_remove_event(item, _sa_initiator)
+            executor.fire_remove_event(item, _sa_initiator, key=key)
 
 
 def __before_pop(collection, _sa_initiator=None):
@@ -1073,7 +1079,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
 
     def append(fn):
         def append(self, item, _sa_initiator=None):
-            item = __set(self, item, _sa_initiator)
+            item = __set(self, item, _sa_initiator, NO_KEY)
             fn(self, item)
 
         _tidy(append)
@@ -1081,7 +1087,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
 
     def remove(fn):
         def remove(self, value, _sa_initiator=None):
-            __del(self, value, _sa_initiator)
+            __del(self, value, _sa_initiator, NO_KEY)
             # testlib.pragma exempt:__eq__
             fn(self, value)
 
@@ -1090,7 +1096,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
 
     def insert(fn):
         def insert(self, index, value):
-            value = __set(self, value)
+            value = __set(self, value, None, index)
             fn(self, index, value)
 
         _tidy(insert)
@@ -1101,8 +1107,8 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
             if not isinstance(index, slice):
                 existing = self[index]
                 if existing is not None:
-                    __del(self, existing)
-                value = __set(self, value)
+                    __del(self, existing, None, index)
+                value = __set(self, value, None, index)
                 fn(self, index, value)
             else:
                 # slice assignment requires __delitem__, insert, __len__
@@ -1144,14 +1150,14 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         def __delitem__(self, index):
             if not isinstance(index, slice):
                 item = self[index]
-                __del(self, item)
+                __del(self, item, None, index)
                 fn(self, index)
             else:
                 # slice deletion requires __getslice__ and a slice-groking
                 # __getitem__ for stepped deletion
                 # note: not breaking this into atomic dels
                 for item in self[index]:
-                    __del(self, item)
+                    __del(self, item, None, index)
                 fn(self, index)
 
         _tidy(__delitem__)
@@ -1180,7 +1186,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         def pop(self, index=-1):
             __before_pop(self)
             item = fn(self, index)
-            __del(self, item)
+            __del(self, item, None, index)
             return item
 
         _tidy(pop)
@@ -1189,7 +1195,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     def clear(fn):
         def clear(self, index=-1):
             for item in self:
-                __del(self, item)
+                __del(self, item, None, index)
             fn(self)
 
         _tidy(clear)
@@ -1217,8 +1223,8 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     def __setitem__(fn):
         def __setitem__(self, key, value, _sa_initiator=None):
             if key in self:
-                __del(self, self[key], _sa_initiator)
-            value = __set(self, value, _sa_initiator)
+                __del(self, self[key], _sa_initiator, key)
+            value = __set(self, value, _sa_initiator, key)
             fn(self, key, value)
 
         _tidy(__setitem__)
@@ -1227,7 +1233,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     def __delitem__(fn):
         def __delitem__(self, key, _sa_initiator=None):
             if key in self:
-                __del(self, self[key], _sa_initiator)
+                __del(self, self[key], _sa_initiator, key)
             fn(self, key)
 
         _tidy(__delitem__)
@@ -1236,7 +1242,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     def clear(fn):
         def clear(self):
             for key in self:
-                __del(self, self[key])
+                __del(self, self[key], None, key)
             fn(self)
 
         _tidy(clear)
@@ -1251,7 +1257,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
             else:
                 item = fn(self, key, default)
             if _to_del:
-                __del(self, item)
+                __del(self, item, None, key)
             return item
 
         _tidy(pop)
@@ -1261,7 +1267,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         def popitem(self):
             __before_pop(self)
             item = fn(self)
-            __del(self, item[1])
+            __del(self, item[1], None, 1)
             return item
 
         _tidy(popitem)
@@ -1341,7 +1347,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     def add(fn):
         def add(self, value, _sa_initiator=None):
             if value not in self:
-                value = __set(self, value, _sa_initiator)
+                value = __set(self, value, _sa_initiator, NO_KEY)
             else:
                 __set_wo_mutation(self, value, _sa_initiator)
             # testlib.pragma exempt:__hash__
@@ -1354,7 +1360,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         def discard(self, value, _sa_initiator=None):
             # testlib.pragma exempt:__hash__
             if value in self:
-                __del(self, value, _sa_initiator)
+                __del(self, value, _sa_initiator, NO_KEY)
                 # testlib.pragma exempt:__hash__
             fn(self, value)
 
@@ -1365,7 +1371,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         def remove(self, value, _sa_initiator=None):
             # testlib.pragma exempt:__hash__
             if value in self:
-                __del(self, value, _sa_initiator)
+                __del(self, value, _sa_initiator, NO_KEY)
             # testlib.pragma exempt:__hash__
             fn(self, value)
 
@@ -1378,7 +1384,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
             item = fn(self)
             # for set in particular, we have no way to access the item
             # that will be popped before pop is called.
-            __del(self, item)
+            __del(self, item, None, NO_KEY)
             return item
 
         _tidy(pop)
index 680e499815a040cd1cd6ad05aca24d9afab24697..c17ea1abedb29d8fa5abc2640515a5e61ad3d336 100644 (file)
@@ -22,6 +22,7 @@ from . import interfaces
 from . import mapperlib
 from .attributes import QueryableAttribute
 from .base import _mapper_or_none
+from .base import NO_KEY
 from .query import Query
 from .scoping import scoped_session
 from .session import Session
@@ -2288,6 +2289,7 @@ class AttributeEvents(event.Events):
         raw=False,
         retval=False,
         propagate=False,
+        include_key=False,
     ):
 
         target, fn = event_key.dispatch_target, event_key._listen_fn
@@ -2295,9 +2297,9 @@ class AttributeEvents(event.Events):
         if active_history:
             target.dispatch._active_history = True
 
-        if not raw or not retval:
+        if not raw or not retval or not include_key:
 
-            def wrap(target, *arg):
+            def wrap(target, *arg, **kw):
                 if not raw:
                     target = target.obj()
                 if not retval:
@@ -2305,10 +2307,16 @@ class AttributeEvents(event.Events):
                         value = arg[0]
                     else:
                         value = None
-                    fn(target, *arg)
+                    if include_key:
+                        fn(target, *arg, **kw)
+                    else:
+                        fn(target, *arg)
                     return value
                 else:
-                    return fn(target, *arg)
+                    if include_key:
+                        return fn(target, *arg, **kw)
+                    else:
+                        return fn(target, *arg)
 
             event_key = event_key.with_wrapper(wrap)
 
@@ -2324,7 +2332,7 @@ class AttributeEvents(event.Events):
                 if active_history:
                     mgr[target.key].dispatch._active_history = True
 
-    def append(self, target, value, initiator):
+    def append(self, target, value, initiator, *, key=NO_KEY):
         """Receive a collection append event.
 
         The append event is invoked for each element as it is appended
@@ -2343,6 +2351,19 @@ class AttributeEvents(event.Events):
           from its original value by backref handlers in order to control
           chained event propagation, as well as be inspected for information
           about the source of the event.
+        :param key: When the event is established using the
+         :paramref:`.AttributeEvents.include_key` parameter set to
+         True, this will be the key used in the operation, such as
+         ``collection[some_key_or_index] = value``.
+         The parameter is not passed
+         to the event at all if the the
+         :paramref:`.AttributeEvents.include_key`
+         was not used to set up the event; this is to allow backwards
+         compatibility with existing event handlers that don't include the
+         ``key`` parameter.
+
+         .. versionadded:: 2.0
+
         :return: if the event was registered with ``retval=True``,
          the given value, or a new effective value, should be returned.
 
@@ -2355,7 +2376,7 @@ class AttributeEvents(event.Events):
 
         """
 
-    def append_wo_mutation(self, target, value, initiator):
+    def append_wo_mutation(self, target, value, initiator, *, key=NO_KEY):
         """Receive a collection append event where the collection was not
         actually mutated.
 
@@ -2378,6 +2399,18 @@ class AttributeEvents(event.Events):
           from its original value by backref handlers in order to control
           chained event propagation, as well as be inspected for information
           about the source of the event.
+        :param key: When the event is established using the
+         :paramref:`.AttributeEvents.include_key` parameter set to
+         True, this will be the key used in the operation, such as
+         ``collection[some_key_or_index] = value``.
+         The parameter is not passed
+         to the event at all if the the
+         :paramref:`.AttributeEvents.include_key`
+         was not used to set up the event; this is to allow backwards
+         compatibility with existing event handlers that don't include the
+         ``key`` parameter.
+
+         .. versionadded:: 2.0
 
         :return: No return value is defined for this event.
 
@@ -2385,7 +2418,7 @@ class AttributeEvents(event.Events):
 
         """
 
-    def bulk_replace(self, target, values, initiator):
+    def bulk_replace(self, target, values, initiator, *, keys=None):
         """Receive a collection 'bulk replace' event.
 
         This event is invoked for a sequence of values as they are incoming
@@ -2428,6 +2461,17 @@ class AttributeEvents(event.Events):
           handler can modify this list in place.
         :param initiator: An instance of :class:`.attributes.Event`
           representing the initiation of the event.
+        :param keys: When the event is established using the
+         :paramref:`.AttributeEvents.include_key` parameter set to
+         True, this will be the sequence of keys used in the operation,
+         typically only for a dictionary update.  The parameter is not passed
+         to the event at all if the the
+         :paramref:`.AttributeEvents.include_key`
+         was not used to set up the event; this is to allow backwards
+         compatibility with existing event handlers that don't include the
+         ``key`` parameter.
+
+         .. versionadded:: 2.0
 
         .. seealso::
 
@@ -2437,7 +2481,7 @@ class AttributeEvents(event.Events):
 
         """
 
-    def remove(self, target, value, initiator):
+    def remove(self, target, value, initiator, *, key=NO_KEY):
         """Receive a collection remove event.
 
         :param target: the object instance receiving the event.
@@ -2453,6 +2497,17 @@ class AttributeEvents(event.Events):
              passed as a :class:`.attributes.Event` object, and may be
              modified by backref handlers within a chain of backref-linked
              events.
+        :param key: When the event is established using the
+         :paramref:`.AttributeEvents.include_key` parameter set to
+         True, this will be the key used in the operation, such as
+         ``del collection[some_key_or_index]``.  The parameter is not passed
+         to the event at all if the the
+         :paramref:`.AttributeEvents.include_key`
+         was not used to set up the event; this is to allow backwards
+         compatibility with existing event handlers that don't include the
+         ``key`` parameter.
+
+         .. versionadded:: 2.0
 
         :return: No return value is defined for this event.
 
index 16062fffa3a255c5a67dbbbcdf57593576d7a9d7..72f5c6a7b1976253216cc716c3fbb7753a066b4d 100644 (file)
@@ -49,6 +49,7 @@ from .base import InspectionAttr as InspectionAttr  # noqa: F401
 from .base import InspectionAttrInfo as InspectionAttrInfo
 from .base import MANYTOMANY as MANYTOMANY  # noqa: F401
 from .base import MANYTOONE as MANYTOONE  # noqa: F401
+from .base import NO_KEY as NO_KEY  # noqa: F401
 from .base import NotExtension as NotExtension  # noqa: F401
 from .base import ONETOMANY as ONETOMANY  # noqa: F401
 from .base import RelationshipDirection as RelationshipDirection  # noqa: F401
index c83ffdb59892a9553c85c645c54784a812a632af..5e66653a381ec92e0aebc5b66a51ebc752c1afed 100644 (file)
@@ -47,7 +47,7 @@ def track_cascade_events(descriptor, prop):
     """
     key = prop.key
 
-    def append(state, item, initiator):
+    def append(state, item, initiator, **kw):
         # process "save_update" cascade rules for when
         # an instance is appended to the list of another instance
 
@@ -70,7 +70,7 @@ def track_cascade_events(descriptor, prop):
                 sess._save_or_update_state(item_state)
         return item
 
-    def remove(state, item, initiator):
+    def remove(state, item, initiator, **kw):
         if item is None:
             return
 
@@ -104,7 +104,7 @@ def track_cascade_events(descriptor, prop):
                     # item
                     item_state._orphaned_outside_of_session = True
 
-    def set_(state, newvalue, oldvalue, initiator):
+    def set_(state, newvalue, oldvalue, initiator, **kw):
         # process "save_update" cascade rules for when an instance
         # is attached to another instance
         if oldvalue is newvalue:
@@ -141,10 +141,18 @@ def track_cascade_events(descriptor, prop):
                     sess.expunge(oldvalue)
         return newvalue
 
-    event.listen(descriptor, "append_wo_mutation", append, raw=True)
-    event.listen(descriptor, "append", append, raw=True, retval=True)
-    event.listen(descriptor, "remove", remove, raw=True, retval=True)
-    event.listen(descriptor, "set", set_, raw=True, retval=True)
+    event.listen(
+        descriptor, "append_wo_mutation", append, raw=True, include_key=True
+    )
+    event.listen(
+        descriptor, "append", append, raw=True, retval=True, include_key=True
+    )
+    event.listen(
+        descriptor, "remove", remove, raw=True, retval=True, include_key=True
+    )
+    event.listen(
+        descriptor, "set", set_, raw=True, retval=True, include_key=True
+    )
 
 
 class UOWTransaction:
index e1274a8051e6dbebe943812752055d0ce96102b1..53b306f5b1d48de6391777dea24451a7d6e5eb54 100644 (file)
@@ -8,6 +8,8 @@ from sqlalchemy import testing
 from sqlalchemy.orm import attributes
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import instrumentation
+from sqlalchemy.orm import NO_KEY
+from sqlalchemy.orm.collections import attribute_mapped_collection
 from sqlalchemy.orm.collections import collection
 from sqlalchemy.orm.state import InstanceState
 from sqlalchemy.testing import assert_raises
@@ -23,7 +25,6 @@ from sqlalchemy.testing.assertions import assert_warns
 from sqlalchemy.testing.util import all_partial_orderings
 from sqlalchemy.testing.util import gc_collect
 
-
 # global for pickling tests
 MyTest = None
 MyTest2 = None
@@ -2576,8 +2577,6 @@ class HistoryTest(fixtures.TestBase):
         class Bar(fixtures.BasicEntity):
             pass
 
-        from sqlalchemy.orm.collections import attribute_mapped_collection
-
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
         _register_attribute(
@@ -3193,6 +3192,239 @@ class LazyloadHistoryTest(fixtures.TestBase):
         )
 
 
+class CollectionKeyTest(fixtures.ORMTest):
+    @testing.fixture
+    def dict_collection(self):
+        class Foo(fixtures.BasicEntity):
+            pass
+
+        class Bar(fixtures.BasicEntity):
+            def __init__(self, name):
+                self.name = name
+
+        instrumentation.register_class(Foo)
+        instrumentation.register_class(Bar)
+        _register_attribute(
+            Foo,
+            "someattr",
+            uselist=True,
+            useobject=True,
+            typecallable=attribute_mapped_collection("name"),
+        )
+        _register_attribute(
+            Bar,
+            "name",
+            uselist=False,
+            useobject=False,
+        )
+
+        return Foo, Bar
+
+    @testing.fixture
+    def list_collection(self):
+        class Foo(fixtures.BasicEntity):
+            pass
+
+        class Bar(fixtures.BasicEntity):
+            pass
+
+        instrumentation.register_class(Foo)
+        instrumentation.register_class(Bar)
+        _register_attribute(
+            Foo,
+            "someattr",
+            uselist=True,
+            useobject=True,
+        )
+
+        return Foo, Bar
+
+    def test_listen_w_list_key(self, list_collection):
+        Foo, Bar = list_collection
+
+        m1 = Mock()
+
+        event.listen(Foo.someattr, "append", m1, include_key=True)
+        event.listen(Foo.someattr, "remove", m1, include_key=True)
+
+        f1 = Foo()
+        b1, b2, b3 = Bar(), Bar(), Bar()
+        f1.someattr.append(b1)
+        f1.someattr.append(b2)
+        f1.someattr[1] = b3
+        del f1.someattr[0]
+        append_token, remove_token = (
+            Foo.someattr.impl._append_token,
+            Foo.someattr.impl._remove_token,
+        )
+
+        eq_(
+            m1.mock_calls,
+            [
+                call(
+                    f1,
+                    b1,
+                    append_token,
+                    key=NO_KEY,
+                ),
+                call(
+                    f1,
+                    b2,
+                    append_token,
+                    key=NO_KEY,
+                ),
+                call(
+                    f1,
+                    b2,
+                    remove_token,
+                    key=1,
+                ),
+                call(
+                    f1,
+                    b3,
+                    append_token,
+                    key=1,
+                ),
+                call(
+                    f1,
+                    b1,
+                    remove_token,
+                    key=0,
+                ),
+            ],
+        )
+
+    def test_listen_w_dict_key(self, dict_collection):
+        Foo, Bar = dict_collection
+
+        m1 = Mock()
+
+        event.listen(Foo.someattr, "append", m1, include_key=True)
+        event.listen(Foo.someattr, "remove", m1, include_key=True)
+
+        f1 = Foo()
+        b1, b2, b3 = Bar("b1"), Bar("b2"), Bar("b3")
+        f1.someattr["k1"] = b1
+        f1.someattr.update({"k2": b2, "k3": b3})
+
+        del f1.someattr["k2"]
+
+        append_token, remove_token = (
+            Foo.someattr.impl._append_token,
+            Foo.someattr.impl._remove_token,
+        )
+
+        eq_(
+            m1.mock_calls,
+            [
+                call(
+                    f1,
+                    b1,
+                    append_token,
+                    key="k1",
+                ),
+                call(
+                    f1,
+                    b2,
+                    append_token,
+                    key="k2",
+                ),
+                call(
+                    f1,
+                    b3,
+                    append_token,
+                    key="k3",
+                ),
+                call(
+                    f1,
+                    b2,
+                    remove_token,
+                    key="k2",
+                ),
+            ],
+        )
+
+    def test_dict_bulk_replace_w_key(self, dict_collection):
+        Foo, Bar = dict_collection
+
+        m1 = Mock()
+
+        event.listen(Foo.someattr, "bulk_replace", m1, include_key=True)
+        event.listen(Foo.someattr, "append", m1, include_key=True)
+        event.listen(Foo.someattr, "remove", m1, include_key=True)
+
+        f1 = Foo()
+        b1, b2, b3, b4 = Bar("b1"), Bar("b2"), Bar("b3"), Bar("b4")
+        f1.someattr = {"b1": b1, "b3": b3}
+        f1.someattr = {"b2": b2, "b3": b3, "b4": b4}
+
+        bulk_replace_token = Foo.someattr.impl._bulk_replace_token
+
+        eq_(
+            m1.mock_calls,
+            [
+                call(f1, [b1, b3], bulk_replace_token, keys=["b1", "b3"]),
+                call(f1, b1, bulk_replace_token, key="b1"),
+                call(f1, b3, bulk_replace_token, key="b3"),
+                call(
+                    f1,
+                    [b2, b3, b4],
+                    bulk_replace_token,
+                    keys=["b2", "b3", "b4"],
+                ),
+                call(f1, b2, bulk_replace_token, key="b2"),
+                call(f1, b4, bulk_replace_token, key="b4"),
+                call(f1, b1, bulk_replace_token, key=NO_KEY),
+            ],
+        )
+
+    def test_listen_wo_dict_key(self, dict_collection):
+        Foo, Bar = dict_collection
+
+        m1 = Mock()
+
+        event.listen(Foo.someattr, "append", m1)
+        event.listen(Foo.someattr, "remove", m1)
+
+        f1 = Foo()
+        b1, b2, b3 = Bar("b1"), Bar("b2"), Bar("b3")
+        f1.someattr["k1"] = b1
+        f1.someattr.update({"k2": b2, "k3": b3})
+
+        del f1.someattr["k2"]
+
+        append_token, remove_token = (
+            Foo.someattr.impl._append_token,
+            Foo.someattr.impl._remove_token,
+        )
+
+        eq_(
+            m1.mock_calls,
+            [
+                call(
+                    f1,
+                    b1,
+                    append_token,
+                ),
+                call(
+                    f1,
+                    b2,
+                    append_token,
+                ),
+                call(
+                    f1,
+                    b3,
+                    append_token,
+                ),
+                call(
+                    f1,
+                    b2,
+                    remove_token,
+                ),
+            ],
+        )
+
+
 class ListenerTest(fixtures.ORMTest):
     def test_receive_changes(self):
         """test that Listeners can mutate the given value."""