From: Federico Caselli Date: Fri, 3 Jan 2025 20:42:48 +0000 (+0100) Subject: Added `merge_all` and `delete_all` X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=41c30cc031266d2e3a02ccc0d6cd2ab91bc725fa;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added `merge_all` and `delete_all` Added the utility method :meth:`_orm.Session.merge_all` and :meth:`_orm.Session.delete_all` that operate on a collection of instances. Fixes: #11776 Change-Id: Ifd70ba2850db7c5e7aee482799fd65c348c2899a --- diff --git a/doc/build/changelog/unreleased_21/11776.rst b/doc/build/changelog/unreleased_21/11776.rst new file mode 100644 index 0000000000..446c5e1717 --- /dev/null +++ b/doc/build/changelog/unreleased_21/11776.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: orm, usecase + :tickets: 11776 + + Added the utility method :meth:`_orm.Session.merge_all` and + :meth:`_orm.Session.delete_all` that operate on a collection + of instances. diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 027e6947db..823c354f3f 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -85,6 +85,7 @@ _Ts = TypeVarTuple("_Ts") "commit", "connection", "delete", + "delete_all", "execute", "expire", "expire_all", @@ -95,6 +96,7 @@ _Ts = TypeVarTuple("_Ts") "is_modified", "invalidate", "merge", + "merge_all", "refresh", "rollback", "scalar", @@ -287,7 +289,7 @@ class async_scoped_session(Generic[_AS]): return await self._proxied.aclose() - def add(self, instance: object, _warn: bool = True) -> None: + def add(self, instance: object, *, _warn: bool = True) -> None: r"""Place an object into this :class:`_orm.Session`. .. container:: class_bases @@ -530,6 +532,23 @@ class async_scoped_session(Generic[_AS]): return await self._proxied.delete(instance) + async def delete_all(self, instances: Iterable[object]) -> None: + r"""Calls :meth:`.AsyncSession.delete` on multiple instances. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.delete_all` - main documentation for delete_all + + + """ # noqa: E501 + + return await self._proxied.delete_all(instances) + @overload async def execute( self, @@ -958,6 +977,31 @@ class async_scoped_session(Generic[_AS]): return await self._proxied.merge(instance, load=load, options=options) + async def merge_all( + self, + instances: Iterable[_O], + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> Sequence[_O]: + r"""Calls :meth:`.AsyncSession.merge` on multiple instances. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.merge_all` - main documentation for merge_all + + + """ # noqa: E501 + + return await self._proxied.merge_all( + instances, load=load, options=options + ) + async def refresh( self, instance: object, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 65e3b541a7..adb88f53f6 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -775,6 +775,16 @@ class AsyncSession(ReversibleProxy[Session]): """ await greenlet_spawn(self.sync_session.delete, instance) + async def delete_all(self, instances: Iterable[object]) -> None: + """Calls :meth:`.AsyncSession.delete` on multiple instances. + + .. seealso:: + + :meth:`_orm.Session.delete_all` - main documentation for delete_all + + """ + await greenlet_spawn(self.sync_session.delete_all, instances) + async def merge( self, instance: _O, @@ -794,6 +804,24 @@ class AsyncSession(ReversibleProxy[Session]): self.sync_session.merge, instance, load=load, options=options ) + async def merge_all( + self, + instances: Iterable[_O], + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> Sequence[_O]: + """Calls :meth:`.AsyncSession.merge` on multiple instances. + + .. seealso:: + + :meth:`_orm.Session.merge_all` - main documentation for merge_all + + """ + return await greenlet_spawn( + self.sync_session.merge_all, instances, load=load, options=options + ) + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: """Flush all the object changes to the database. @@ -1122,7 +1150,7 @@ class AsyncSession(ReversibleProxy[Session]): return self._proxied.__iter__() - def add(self, instance: object, _warn: bool = True) -> None: + def add(self, instance: object, *, _warn: bool = True) -> None: r"""Place an object into this :class:`_orm.Session`. .. container:: class_bases diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index b5f51fee53..deee8bc3ad 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -327,9 +327,7 @@ def merge_frozen_result(session, statement, frozen_result, load=True): statement, legacy=False ) - autoflush = session.autoflush - try: - session.autoflush = False + with session.no_autoflush: mapped_entities = [ i for i, e in enumerate(ctx._entities) @@ -356,8 +354,6 @@ def merge_frozen_result(session, statement, frozen_result, load=True): result.append(keyed_tuple(newrow)) return frozen_result.with_new_rows(result) - finally: - session.autoflush = autoflush @util.became_legacy_20( diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 8a333401be..ac746ee056 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -116,6 +116,7 @@ __all__ = ["scoped_session"] "commit", "connection", "delete", + "delete_all", "execute", "expire", "expire_all", @@ -130,6 +131,7 @@ __all__ = ["scoped_session"] "bulk_insert_mappings", "bulk_update_mappings", "merge", + "merge_all", "query", "refresh", "rollback", @@ -350,7 +352,7 @@ class scoped_session(Generic[_S]): return self._proxied.__iter__() - def add(self, instance: object, _warn: bool = True) -> None: + def add(self, instance: object, *, _warn: bool = True) -> None: r"""Place an object into this :class:`_orm.Session`. .. container:: class_bases @@ -673,11 +675,32 @@ class scoped_session(Generic[_S]): :ref:`session_deleting` - at :ref:`session_basics` + :meth:`.Session.delete_all` - multiple instance version + """ # noqa: E501 return self._proxied.delete(instance) + def delete_all(self, instances: Iterable[object]) -> None: + r"""Calls :meth:`.Session.delete` on multiple instances. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. seealso:: + + :meth:`.Session.delete` - main documentation on delete + + .. versionadded: 2.1 + + + """ # noqa: E501 + + return self._proxied.delete_all(instances) + @overload def execute( self, @@ -1567,11 +1590,38 @@ class scoped_session(Generic[_S]): :func:`.make_transient_to_detached` - provides for an alternative means of "merging" a single object into the :class:`.Session` + :meth:`.Session.merge_all` - multiple instance version + """ # noqa: E501 return self._proxied.merge(instance, load=load, options=options) + def merge_all( + self, + instances: Iterable[_O], + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> Sequence[_O]: + r"""Calls :meth:`.Session.merge` on multiple instances. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. seealso:: + + :meth:`.Session.merge` - main documentation on merge + + .. versionadded: 2.1 + + + """ # noqa: E501 + + return self._proxied.merge_all(instances, load=load, options=options) + @overload def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 28a32b3f23..8e7c38061e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -3459,7 +3459,7 @@ class Session(_SessionClassMethods, EventTarget): if persistent_to_deleted is not None: persistent_to_deleted(self, state) - def add(self, instance: object, _warn: bool = True) -> None: + def add(self, instance: object, *, _warn: bool = True) -> None: """Place an object into this :class:`_orm.Session`. Objects that are in the :term:`transient` state when passed to the @@ -3544,16 +3544,30 @@ class Session(_SessionClassMethods, EventTarget): :ref:`session_deleting` - at :ref:`session_basics` + :meth:`.Session.delete_all` - multiple instance version + """ if self._warn_on_events: self._flush_warning("Session.delete()") - try: - state = attributes.instance_state(instance) - except exc.NO_STATE as err: - raise exc.UnmappedInstanceError(instance) from err + self._delete_impl(object_state(instance), instance, head=True) + + def delete_all(self, instances: Iterable[object]) -> None: + """Calls :meth:`.Session.delete` on multiple instances. - self._delete_impl(state, instance, head=True) + .. seealso:: + + :meth:`.Session.delete` - main documentation on delete + + .. versionadded: 2.1 + + """ + + if self._warn_on_events: + self._flush_warning("Session.delete_all()") + + for instance in instances: + self._delete_impl(object_state(instance), instance, head=True) def _delete_impl( self, state: InstanceState[Any], obj: object, head: bool @@ -3955,32 +3969,62 @@ class Session(_SessionClassMethods, EventTarget): :func:`.make_transient_to_detached` - provides for an alternative means of "merging" a single object into the :class:`.Session` + :meth:`.Session.merge_all` - multiple instance version + """ if self._warn_on_events: self._flush_warning("Session.merge()") - _recursive: Dict[InstanceState[Any], object] = {} - _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {} - if load: # flush current contents if we expect to load data self._autoflush() - object_mapper(instance) # verify mapped - autoflush = self.autoflush - try: - self.autoflush = False + with self.no_autoflush: return self._merge( - attributes.instance_state(instance), + object_state(instance), attributes.instance_dict(instance), load=load, options=options, - _recursive=_recursive, - _resolve_conflict_map=_resolve_conflict_map, + _recursive={}, + _resolve_conflict_map={}, ) - finally: - self.autoflush = autoflush + + def merge_all( + self, + instances: Iterable[_O], + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> Sequence[_O]: + """Calls :meth:`.Session.merge` on multiple instances. + + .. seealso:: + + :meth:`.Session.merge` - main documentation on merge + + .. versionadded: 2.1 + + """ + + if self._warn_on_events: + self._flush_warning("Session.merge_all()") + + if load: + # flush current contents if we expect to load data + self._autoflush() + + return [ + self._merge( + object_state(instance), + attributes.instance_dict(instance), + load=load, + options=options, + _recursive={}, + _resolve_conflict_map={}, + ) + for instance in instances + ] def _merge( self, diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index c313c4b33d..9fb16a2ce1 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -1806,6 +1806,29 @@ class MergeTest(_fixtures.FixtureTest): eq_(sess.query(Address).one(), Address(id=1, email_address="c")) + def test_merge_all(self): + User, users = self.classes.User, self.tables.users + + self.mapper_registry.map_imperatively(User, users) + sess = fixture_session() + load = self.load_tracker(User) + + ua = User(id=42, name="bob") + ub = User(id=7, name="fred") + eq_(load.called, 0) + uam, ubm = sess.merge_all([ua, ub]) + eq_(load.called, 2) + assert uam in sess + assert ubm in sess + eq_(uam, User(id=42, name="bob")) + eq_(ubm, User(id=7, name="fred")) + sess.flush() + sess.expunge_all() + eq_( + sess.query(User).order_by("id").all(), + [User(id=7, name="fred"), User(id=42, name="bob")], + ) + class M2ONoUseGetLoadingTest(fixtures.MappedTest): """Merge a one-to-many. The many-to-one on the other side is set up diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 1495932744..a59e9d33da 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -681,6 +681,23 @@ class SessionUtilTest(_fixtures.FixtureTest): ): sess.get_one(User, 2) + def test_delete_all(self): + users, User = self.tables.users, self.classes.User + self.mapper_registry.map_imperatively(User, users) + + sess = fixture_session() + + sess.add_all([User(id=1, name="u1"), User(id=2, name="u2")]) + sess.commit() + sess.close() + + ua, ub = sess.scalars(select(User)).all() + eq_([ua in sess, ub in sess], [True, True]) + sess.delete_all([ua, ub]) + sess.flush() + eq_([ua in sess, ub in sess], [False, False]) + eq_(sess.scalars(select(User)).all(), []) + class SessionStateTest(_fixtures.FixtureTest): run_inserts = None @@ -2109,7 +2126,8 @@ class SessionInterface(fixtures.MappedTest): ]: raises_(name, user_arg) - raises_("add_all", (user_arg,)) + for name in ["add_all", "merge_all", "delete_all"]: + raises_(name, (user_arg,)) # flush will no-op without something in the unit of work def _(): diff --git a/test/profiles.txt b/test/profiles.txt index 618002023e..eff6c5f46d 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -394,10 +394,10 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.12_ # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_no_load -test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 108,20 -test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 108,20 -test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 108,20 -test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_nocextensions 108,20 +test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 108,29 +test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 108,29 +test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 108,29 +test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_nocextensions 108,29 # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols