From: Mike Bayer Date: Wed, 10 Aug 2022 14:53:11 +0000 (-0400) Subject: Propagate key for collection events X-Git-Tag: rel_2_0_0b1~112^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6cef8526226ab6033dfef1f793be87bff2160c04;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Propagate key for collection events Added new parameter :paramref:`_orm.AttributeEvents.include_key`, which will include the dictionary or list key for operations such as ``__setitem__()`` (e.g. ``obj[key] = value``) and ``__delitem__()`` (e.g. ``del obj[key]``), using a new keyword parameter "key" or "keys", depending on event, e.g. :paramref:`_orm.AttributeEvents.append.key`, :paramref:`_orm.AttributeEvents.bulk_replace.keys`. This allows event handlers to take into account the key that was passed to the operation and is of particular importance for dictionary operations working with :class:`_orm.MappedCollection`. Fixes: #8375 Change-Id: Icc472f7c28848f94e15c94a399cc13a88782e1e4 --- diff --git a/doc/build/changelog/unreleased_20/8375.rst b/doc/build/changelog/unreleased_20/8375.rst new file mode 100644 index 0000000000..0fb03275b3 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8375.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: feature, orm + :tickets: 8375 + + Added new parameter :paramref:`_orm.AttributeEvents.include_key`, which + will include the dictionary or list key for operations such as + ``__setitem__()`` (e.g. ``obj[key] = value``) and ``__delitem__()`` (e.g. + ``del obj[key]``), using a new keyword parameter "key" or "keys", depending + on event, e.g. :paramref:`_orm.AttributeEvents.append.key`, + :paramref:`_orm.AttributeEvents.bulk_replace.keys`. This allows event + handlers to take into account the key that was passed to the operation and + is of particular importance for dictionary operations working with + :class:`_orm.MappedCollection`. + diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index cda58d6a5f..3a0f425fc4 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -82,6 +82,7 @@ from .interfaces import InspectionAttrInfo as InspectionAttrInfo from .interfaces import MANYTOMANY as MANYTOMANY from .interfaces import MANYTOONE as MANYTOONE from .interfaces import MapperProperty as MapperProperty +from .interfaces import NO_KEY as NO_KEY from .interfaces import ONETOMANY as ONETOMANY from .interfaces import PropComparator as PropComparator from .interfaces import UserDefinedOption as UserDefinedOption diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index bb7eda5ac2..db86d0810a 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1717,9 +1717,10 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, value: _T, initiator: Optional[AttributeEventToken], + key: Optional[Any], ) -> _T: for fn in self.dispatch.append: - value = fn(state, value, initiator or self._append_token) + value = fn(state, value, initiator or self._append_token, key=key) state._modified_event(dict_, self, NO_VALUE, True) @@ -1734,9 +1735,10 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, value: _T, initiator: Optional[AttributeEventToken], + key: Optional[Any], ) -> _T: for fn in self.dispatch.append_wo_mutation: - value = fn(state, value, initiator or self._append_token) + value = fn(state, value, initiator or self._append_token, key=key) return value @@ -1745,6 +1747,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): state: InstanceState[Any], dict_: _InstanceDict, initiator: Optional[AttributeEventToken], + key: Optional[Any], ) -> None: """A special event used for pop() operations. @@ -1762,12 +1765,13 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, value: Any, initiator: Optional[AttributeEventToken], + key: Optional[Any], ) -> None: if self.trackparent and value is not None: self.sethasparent(instance_state(value), state, False) for fn in self.dispatch.remove: - fn(state, value, initiator or self._remove_token) + fn(state, value, initiator or self._remove_token, key=key) state._modified_event(dict_, self, NO_VALUE, True) @@ -1825,7 +1829,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): state, dict_, user_data=None, passive=passive ) if collection is PASSIVE_NO_RESULT: - value = self.fire_append_event(state, dict_, value, initiator) + value = self.fire_append_event( + state, dict_, value, initiator, key=NO_KEY + ) assert ( self.key not in dict_ ), "Collection was loaded during event handling." @@ -1847,7 +1853,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): state, state.dict, user_data=None, passive=passive ) if collection is PASSIVE_NO_RESULT: - self.fire_remove_event(state, dict_, value, initiator) + self.fire_remove_event(state, dict_, value, initiator, key=NO_KEY) assert ( self.key not in dict_ ), "Collection was loaded during event handling." @@ -1885,6 +1891,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): _adapt: bool = True, ) -> None: iterable = orig_iterable = value + new_keys = None # pulling a new collection first so that an adaptation exception does # not trigger a lazy load of the old collection. @@ -1913,14 +1920,18 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): if hasattr(iterable, "_sa_iterator"): iterable = iterable._sa_iterator() elif setting_type is dict: + new_keys = list(iterable) iterable = iterable.values() else: iterable = iter(iterable) + elif util.duck_type_collection(iterable) is dict: + new_keys = list(value) + new_values = list(iterable) evt = self._bulk_replace_token - self.dispatch.bulk_replace(state, new_values, evt) + self.dispatch.bulk_replace(state, new_values, evt, keys=new_keys) old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT) if old is PASSIVE_NO_RESULT: @@ -2081,7 +2092,9 @@ def backref_listeners( ) ) - def emit_backref_from_scalar_set_event(state, child, oldchild, initiator): + def emit_backref_from_scalar_set_event( + state, child, oldchild, initiator, **kw + ): if oldchild is child: return child if ( @@ -2146,7 +2159,9 @@ def backref_listeners( ) return child - def emit_backref_from_collection_append_event(state, child, initiator): + def emit_backref_from_collection_append_event( + state, child, initiator, **kw + ): if child is None: return @@ -2180,7 +2195,9 @@ def backref_listeners( ) return child - def emit_backref_from_collection_remove_event(state, child, initiator): + def emit_backref_from_collection_remove_event( + state, child, initiator, **kw + ): if ( child is not None and child is not PASSIVE_NO_RESULT @@ -2234,6 +2251,7 @@ def backref_listeners( emit_backref_from_collection_append_event, retval=True, raw=True, + include_key=True, ) else: event.listen( @@ -2242,6 +2260,7 @@ def backref_listeners( emit_backref_from_scalar_set_event, retval=True, raw=True, + include_key=True, ) # TODO: need coverage in test/orm/ of remove event event.listen( @@ -2250,6 +2269,7 @@ def backref_listeners( emit_backref_from_collection_remove_event, retval=True, raw=True, + include_key=True, ) diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index fa653a472d..66b7b8c2e3 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -191,9 +191,21 @@ class PassiveFlag(FastIntFlag): DEFAULT_MANAGER_ATTR = "_sa_class_manager" DEFAULT_STATE_ATTR = "_sa_instance_state" -EXT_CONTINUE = util.symbol("EXT_CONTINUE") -EXT_STOP = util.symbol("EXT_STOP") -EXT_SKIP = util.symbol("EXT_SKIP") + +class EventConstants(Enum): + EXT_CONTINUE = 1 + EXT_STOP = 2 + EXT_SKIP = 3 + NO_KEY = 4 + """indicates an :class:`.AttributeEvent` event that did not have any + key argument. + + .. versionadded:: 2.0 + + """ + + +EXT_CONTINUE, EXT_STOP, EXT_SKIP, NO_KEY = tuple(EventConstants) class RelationshipDirection(Enum): diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index f47d00634e..5dbd2dc305 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -125,6 +125,7 @@ from typing import TypeVar from typing import Union import weakref +from .base import NO_KEY from .. import exc as sa_exc from .. import util from ..util.compat import inspect_getfullargspec @@ -614,7 +615,7 @@ class CollectionAdapter: def __bool__(self): return True - def fire_append_wo_mutation_event(self, item, initiator=None): + def fire_append_wo_mutation_event(self, item, initiator=None, key=NO_KEY): """Notify that a entity is entering the collection but is already present. @@ -635,12 +636,12 @@ class CollectionAdapter: self._reset_empty() return self.attr.fire_append_wo_mutation_event( - self.owner_state, self.owner_state.dict, item, initiator + self.owner_state, self.owner_state.dict, item, initiator, key ) else: return item - def fire_append_event(self, item, initiator=None): + def fire_append_event(self, item, initiator=None, key=NO_KEY): """Notify that a entity has entered the collection. Initiator is a token owned by the InstrumentedAttribute that @@ -657,12 +658,12 @@ class CollectionAdapter: self._reset_empty() return self.attr.fire_append_event( - self.owner_state, self.owner_state.dict, item, initiator + self.owner_state, self.owner_state.dict, item, initiator, key ) else: return item - def fire_remove_event(self, item, initiator=None): + def fire_remove_event(self, item, initiator=None, key=NO_KEY): """Notify that a entity has been removed from the collection. Initiator is the InstrumentedAttribute that initiated the membership @@ -678,10 +679,10 @@ class CollectionAdapter: self._reset_empty() self.attr.fire_remove_event( - self.owner_state, self.owner_state.dict, item, initiator + self.owner_state, self.owner_state.dict, item, initiator, key ) - def fire_pre_remove_event(self, initiator=None): + def fire_pre_remove_event(self, initiator=None, key=NO_KEY): """Notify that an entity is about to be removed from the collection. Only called if the entity cannot be removed after calling @@ -691,7 +692,10 @@ class CollectionAdapter: if self.invalidated: self._warn_invalidated() self.attr.fire_pre_remove_event( - self.owner_state, self.owner_state.dict, initiator=initiator + self.owner_state, + self.owner_state.dict, + initiator=initiator, + key=key, ) def __getstate__(self): @@ -1025,10 +1029,12 @@ def __set_wo_mutation(collection, item, _sa_initiator=None): if _sa_initiator is not False: executor = collection._sa_adapter if executor: - executor.fire_append_wo_mutation_event(item, _sa_initiator) + executor.fire_append_wo_mutation_event( + item, _sa_initiator, key=None + ) -def __set(collection, item, _sa_initiator=None): +def __set(collection, item, _sa_initiator, key): """Run set events. This event always occurs before the collection is actually mutated. @@ -1038,11 +1044,11 @@ def __set(collection, item, _sa_initiator=None): if _sa_initiator is not False: executor = collection._sa_adapter if executor: - item = executor.fire_append_event(item, _sa_initiator) + item = executor.fire_append_event(item, _sa_initiator, key=key) return item -def __del(collection, item, _sa_initiator=None): +def __del(collection, item, _sa_initiator, key): """Run del events. This event occurs before the collection is actually mutated, *except* @@ -1054,7 +1060,7 @@ def __del(collection, item, _sa_initiator=None): if _sa_initiator is not False: executor = collection._sa_adapter if executor: - executor.fire_remove_event(item, _sa_initiator) + executor.fire_remove_event(item, _sa_initiator, key=key) def __before_pop(collection, _sa_initiator=None): @@ -1073,7 +1079,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: def append(fn): def append(self, item, _sa_initiator=None): - item = __set(self, item, _sa_initiator) + item = __set(self, item, _sa_initiator, NO_KEY) fn(self, item) _tidy(append) @@ -1081,7 +1087,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: def remove(fn): def remove(self, value, _sa_initiator=None): - __del(self, value, _sa_initiator) + __del(self, value, _sa_initiator, NO_KEY) # testlib.pragma exempt:__eq__ fn(self, value) @@ -1090,7 +1096,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: def insert(fn): def insert(self, index, value): - value = __set(self, value) + value = __set(self, value, None, index) fn(self, index, value) _tidy(insert) @@ -1101,8 +1107,8 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: if not isinstance(index, slice): existing = self[index] if existing is not None: - __del(self, existing) - value = __set(self, value) + __del(self, existing, None, index) + value = __set(self, value, None, index) fn(self, index, value) else: # slice assignment requires __delitem__, insert, __len__ @@ -1144,14 +1150,14 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: def __delitem__(self, index): if not isinstance(index, slice): item = self[index] - __del(self, item) + __del(self, item, None, index) fn(self, index) else: # slice deletion requires __getslice__ and a slice-groking # __getitem__ for stepped deletion # note: not breaking this into atomic dels for item in self[index]: - __del(self, item) + __del(self, item, None, index) fn(self, index) _tidy(__delitem__) @@ -1180,7 +1186,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: def pop(self, index=-1): __before_pop(self) item = fn(self, index) - __del(self, item) + __del(self, item, None, index) return item _tidy(pop) @@ -1189,7 +1195,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: def clear(fn): def clear(self, index=-1): for item in self: - __del(self, item) + __del(self, item, None, index) fn(self) _tidy(clear) @@ -1217,8 +1223,8 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: def __setitem__(fn): def __setitem__(self, key, value, _sa_initiator=None): if key in self: - __del(self, self[key], _sa_initiator) - value = __set(self, value, _sa_initiator) + __del(self, self[key], _sa_initiator, key) + value = __set(self, value, _sa_initiator, key) fn(self, key, value) _tidy(__setitem__) @@ -1227,7 +1233,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: def __delitem__(fn): def __delitem__(self, key, _sa_initiator=None): if key in self: - __del(self, self[key], _sa_initiator) + __del(self, self[key], _sa_initiator, key) fn(self, key) _tidy(__delitem__) @@ -1236,7 +1242,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: def clear(fn): def clear(self): for key in self: - __del(self, self[key]) + __del(self, self[key], None, key) fn(self) _tidy(clear) @@ -1251,7 +1257,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: else: item = fn(self, key, default) if _to_del: - __del(self, item) + __del(self, item, None, key) return item _tidy(pop) @@ -1261,7 +1267,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: def popitem(self): __before_pop(self) item = fn(self) - __del(self, item[1]) + __del(self, item[1], None, 1) return item _tidy(popitem) @@ -1341,7 +1347,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: def add(fn): def add(self, value, _sa_initiator=None): if value not in self: - value = __set(self, value, _sa_initiator) + value = __set(self, value, _sa_initiator, NO_KEY) else: __set_wo_mutation(self, value, _sa_initiator) # testlib.pragma exempt:__hash__ @@ -1354,7 +1360,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: def discard(self, value, _sa_initiator=None): # testlib.pragma exempt:__hash__ if value in self: - __del(self, value, _sa_initiator) + __del(self, value, _sa_initiator, NO_KEY) # testlib.pragma exempt:__hash__ fn(self, value) @@ -1365,7 +1371,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: def remove(self, value, _sa_initiator=None): # testlib.pragma exempt:__hash__ if value in self: - __del(self, value, _sa_initiator) + __del(self, value, _sa_initiator, NO_KEY) # testlib.pragma exempt:__hash__ fn(self, value) @@ -1378,7 +1384,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: item = fn(self) # for set in particular, we have no way to access the item # that will be popped before pop is called. - __del(self, item) + __del(self, item, None, NO_KEY) return item _tidy(pop) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 680e499815..c17ea1abed 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -22,6 +22,7 @@ from . import interfaces from . import mapperlib from .attributes import QueryableAttribute from .base import _mapper_or_none +from .base import NO_KEY from .query import Query from .scoping import scoped_session from .session import Session @@ -2288,6 +2289,7 @@ class AttributeEvents(event.Events): raw=False, retval=False, propagate=False, + include_key=False, ): target, fn = event_key.dispatch_target, event_key._listen_fn @@ -2295,9 +2297,9 @@ class AttributeEvents(event.Events): if active_history: target.dispatch._active_history = True - if not raw or not retval: + if not raw or not retval or not include_key: - def wrap(target, *arg): + def wrap(target, *arg, **kw): if not raw: target = target.obj() if not retval: @@ -2305,10 +2307,16 @@ class AttributeEvents(event.Events): value = arg[0] else: value = None - fn(target, *arg) + if include_key: + fn(target, *arg, **kw) + else: + fn(target, *arg) return value else: - return fn(target, *arg) + if include_key: + return fn(target, *arg, **kw) + else: + return fn(target, *arg) event_key = event_key.with_wrapper(wrap) @@ -2324,7 +2332,7 @@ class AttributeEvents(event.Events): if active_history: mgr[target.key].dispatch._active_history = True - def append(self, target, value, initiator): + def append(self, target, value, initiator, *, key=NO_KEY): """Receive a collection append event. The append event is invoked for each element as it is appended @@ -2343,6 +2351,19 @@ class AttributeEvents(event.Events): from its original value by backref handlers in order to control chained event propagation, as well as be inspected for information about the source of the event. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``collection[some_key_or_index] = value``. + The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + :return: if the event was registered with ``retval=True``, the given value, or a new effective value, should be returned. @@ -2355,7 +2376,7 @@ class AttributeEvents(event.Events): """ - def append_wo_mutation(self, target, value, initiator): + def append_wo_mutation(self, target, value, initiator, *, key=NO_KEY): """Receive a collection append event where the collection was not actually mutated. @@ -2378,6 +2399,18 @@ class AttributeEvents(event.Events): from its original value by backref handlers in order to control chained event propagation, as well as be inspected for information about the source of the event. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``collection[some_key_or_index] = value``. + The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 :return: No return value is defined for this event. @@ -2385,7 +2418,7 @@ class AttributeEvents(event.Events): """ - def bulk_replace(self, target, values, initiator): + def bulk_replace(self, target, values, initiator, *, keys=None): """Receive a collection 'bulk replace' event. This event is invoked for a sequence of values as they are incoming @@ -2428,6 +2461,17 @@ class AttributeEvents(event.Events): handler can modify this list in place. :param initiator: An instance of :class:`.attributes.Event` representing the initiation of the event. + :param keys: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the sequence of keys used in the operation, + typically only for a dictionary update. The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 .. seealso:: @@ -2437,7 +2481,7 @@ class AttributeEvents(event.Events): """ - def remove(self, target, value, initiator): + def remove(self, target, value, initiator, *, key=NO_KEY): """Receive a collection remove event. :param target: the object instance receiving the event. @@ -2453,6 +2497,17 @@ class AttributeEvents(event.Events): passed as a :class:`.attributes.Event` object, and may be modified by backref handlers within a chain of backref-linked events. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``del collection[some_key_or_index]``. The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 :return: No return value is defined for this event. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 16062fffa3..72f5c6a7b1 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -49,6 +49,7 @@ from .base import InspectionAttr as InspectionAttr # noqa: F401 from .base import InspectionAttrInfo as InspectionAttrInfo from .base import MANYTOMANY as MANYTOMANY # noqa: F401 from .base import MANYTOONE as MANYTOONE # noqa: F401 +from .base import NO_KEY as NO_KEY # noqa: F401 from .base import NotExtension as NotExtension # noqa: F401 from .base import ONETOMANY as ONETOMANY # noqa: F401 from .base import RelationshipDirection as RelationshipDirection # noqa: F401 diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index c83ffdb598..5e66653a38 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -47,7 +47,7 @@ def track_cascade_events(descriptor, prop): """ key = prop.key - def append(state, item, initiator): + def append(state, item, initiator, **kw): # process "save_update" cascade rules for when # an instance is appended to the list of another instance @@ -70,7 +70,7 @@ def track_cascade_events(descriptor, prop): sess._save_or_update_state(item_state) return item - def remove(state, item, initiator): + def remove(state, item, initiator, **kw): if item is None: return @@ -104,7 +104,7 @@ def track_cascade_events(descriptor, prop): # item item_state._orphaned_outside_of_session = True - def set_(state, newvalue, oldvalue, initiator): + def set_(state, newvalue, oldvalue, initiator, **kw): # process "save_update" cascade rules for when an instance # is attached to another instance if oldvalue is newvalue: @@ -141,10 +141,18 @@ def track_cascade_events(descriptor, prop): sess.expunge(oldvalue) return newvalue - event.listen(descriptor, "append_wo_mutation", append, raw=True) - event.listen(descriptor, "append", append, raw=True, retval=True) - event.listen(descriptor, "remove", remove, raw=True, retval=True) - event.listen(descriptor, "set", set_, raw=True, retval=True) + event.listen( + descriptor, "append_wo_mutation", append, raw=True, include_key=True + ) + event.listen( + descriptor, "append", append, raw=True, retval=True, include_key=True + ) + event.listen( + descriptor, "remove", remove, raw=True, retval=True, include_key=True + ) + event.listen( + descriptor, "set", set_, raw=True, retval=True, include_key=True + ) class UOWTransaction: diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index e1274a8051..53b306f5b1 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -8,6 +8,8 @@ from sqlalchemy import testing from sqlalchemy.orm import attributes from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import instrumentation +from sqlalchemy.orm import NO_KEY +from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.collections import collection from sqlalchemy.orm.state import InstanceState from sqlalchemy.testing import assert_raises @@ -23,7 +25,6 @@ from sqlalchemy.testing.assertions import assert_warns from sqlalchemy.testing.util import all_partial_orderings from sqlalchemy.testing.util import gc_collect - # global for pickling tests MyTest = None MyTest2 = None @@ -2576,8 +2577,6 @@ class HistoryTest(fixtures.TestBase): class Bar(fixtures.BasicEntity): pass - from sqlalchemy.orm.collections import attribute_mapped_collection - instrumentation.register_class(Foo) instrumentation.register_class(Bar) _register_attribute( @@ -3193,6 +3192,239 @@ class LazyloadHistoryTest(fixtures.TestBase): ) +class CollectionKeyTest(fixtures.ORMTest): + @testing.fixture + def dict_collection(self): + class Foo(fixtures.BasicEntity): + pass + + class Bar(fixtures.BasicEntity): + def __init__(self, name): + self.name = name + + instrumentation.register_class(Foo) + instrumentation.register_class(Bar) + _register_attribute( + Foo, + "someattr", + uselist=True, + useobject=True, + typecallable=attribute_mapped_collection("name"), + ) + _register_attribute( + Bar, + "name", + uselist=False, + useobject=False, + ) + + return Foo, Bar + + @testing.fixture + def list_collection(self): + class Foo(fixtures.BasicEntity): + pass + + class Bar(fixtures.BasicEntity): + pass + + instrumentation.register_class(Foo) + instrumentation.register_class(Bar) + _register_attribute( + Foo, + "someattr", + uselist=True, + useobject=True, + ) + + return Foo, Bar + + def test_listen_w_list_key(self, list_collection): + Foo, Bar = list_collection + + m1 = Mock() + + event.listen(Foo.someattr, "append", m1, include_key=True) + event.listen(Foo.someattr, "remove", m1, include_key=True) + + f1 = Foo() + b1, b2, b3 = Bar(), Bar(), Bar() + f1.someattr.append(b1) + f1.someattr.append(b2) + f1.someattr[1] = b3 + del f1.someattr[0] + append_token, remove_token = ( + Foo.someattr.impl._append_token, + Foo.someattr.impl._remove_token, + ) + + eq_( + m1.mock_calls, + [ + call( + f1, + b1, + append_token, + key=NO_KEY, + ), + call( + f1, + b2, + append_token, + key=NO_KEY, + ), + call( + f1, + b2, + remove_token, + key=1, + ), + call( + f1, + b3, + append_token, + key=1, + ), + call( + f1, + b1, + remove_token, + key=0, + ), + ], + ) + + def test_listen_w_dict_key(self, dict_collection): + Foo, Bar = dict_collection + + m1 = Mock() + + event.listen(Foo.someattr, "append", m1, include_key=True) + event.listen(Foo.someattr, "remove", m1, include_key=True) + + f1 = Foo() + b1, b2, b3 = Bar("b1"), Bar("b2"), Bar("b3") + f1.someattr["k1"] = b1 + f1.someattr.update({"k2": b2, "k3": b3}) + + del f1.someattr["k2"] + + append_token, remove_token = ( + Foo.someattr.impl._append_token, + Foo.someattr.impl._remove_token, + ) + + eq_( + m1.mock_calls, + [ + call( + f1, + b1, + append_token, + key="k1", + ), + call( + f1, + b2, + append_token, + key="k2", + ), + call( + f1, + b3, + append_token, + key="k3", + ), + call( + f1, + b2, + remove_token, + key="k2", + ), + ], + ) + + def test_dict_bulk_replace_w_key(self, dict_collection): + Foo, Bar = dict_collection + + m1 = Mock() + + event.listen(Foo.someattr, "bulk_replace", m1, include_key=True) + event.listen(Foo.someattr, "append", m1, include_key=True) + event.listen(Foo.someattr, "remove", m1, include_key=True) + + f1 = Foo() + b1, b2, b3, b4 = Bar("b1"), Bar("b2"), Bar("b3"), Bar("b4") + f1.someattr = {"b1": b1, "b3": b3} + f1.someattr = {"b2": b2, "b3": b3, "b4": b4} + + bulk_replace_token = Foo.someattr.impl._bulk_replace_token + + eq_( + m1.mock_calls, + [ + call(f1, [b1, b3], bulk_replace_token, keys=["b1", "b3"]), + call(f1, b1, bulk_replace_token, key="b1"), + call(f1, b3, bulk_replace_token, key="b3"), + call( + f1, + [b2, b3, b4], + bulk_replace_token, + keys=["b2", "b3", "b4"], + ), + call(f1, b2, bulk_replace_token, key="b2"), + call(f1, b4, bulk_replace_token, key="b4"), + call(f1, b1, bulk_replace_token, key=NO_KEY), + ], + ) + + def test_listen_wo_dict_key(self, dict_collection): + Foo, Bar = dict_collection + + m1 = Mock() + + event.listen(Foo.someattr, "append", m1) + event.listen(Foo.someattr, "remove", m1) + + f1 = Foo() + b1, b2, b3 = Bar("b1"), Bar("b2"), Bar("b3") + f1.someattr["k1"] = b1 + f1.someattr.update({"k2": b2, "k3": b3}) + + del f1.someattr["k2"] + + append_token, remove_token = ( + Foo.someattr.impl._append_token, + Foo.someattr.impl._remove_token, + ) + + eq_( + m1.mock_calls, + [ + call( + f1, + b1, + append_token, + ), + call( + f1, + b2, + append_token, + ), + call( + f1, + b3, + append_token, + ), + call( + f1, + b2, + remove_token, + ), + ], + ) + + class ListenerTest(fixtures.ORMTest): def test_receive_changes(self): """test that Listeners can mutate the given value."""