From: Carlos Sousa Date: Mon, 25 Sep 2023 17:03:26 +0000 (-0400) Subject: Add get_one to Session, AsyncSession, scoped, etc X-Git-Tag: rel_2_0_22~21^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=dc8b7cb5fdb556d78145c1f67737671307f3604d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add get_one to Session, AsyncSession, scoped, etc Added method :meth:`_orm.Session.get_one` that behaves like meth:`_orm.Session.get` but raises an exception instead of returning None`` if no instance was found with the provided primary key. Pull request courtesy of Carlos Sousa. Fixed the :paramref:`_asyncio.AsyncSession.get.execution_options` parameter which was not being propagated to the underlying :class:`_orm.Session` and was instead being ignored. Fixes #10292 Closes: #10376 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10376 Pull-request-sha: 70e4505e93905ee3cebc52f828a95c6bf987c9be Change-Id: I78eb9816c26446757b6c6c171df2e400777a3d36 --- diff --git a/doc/build/changelog/unreleased_20/10292.rst b/doc/build/changelog/unreleased_20/10292.rst new file mode 100644 index 0000000000..1ca2dfb897 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10292.rst @@ -0,0 +1,16 @@ +.. change:: + :tags: orm, usecase + :tickets: 10202 + + Added method :meth:`_orm.Session.get_one` that behaves like + :meth:`_orm.Session.get` but raises an exception instead of returning + ``None`` if no instance was found with the provided primary key. + Pull request courtesy of Carlos Sousa. + + +.. change:: + :tags: asyncio, bug + + Fixed the :paramref:`_asyncio.AsyncSession.get.execution_options` parameter + which was not being propagated to the underlying :class:`_orm.Session` and + was instead being ignored. diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index b70c3366b1..d0228b84c4 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -94,6 +94,8 @@ _T = TypeVar("_T", bound=Any) "rollback", "scalar", "scalars", + "get", + "get_one", "stream", "stream_scalars", ], @@ -108,6 +110,7 @@ _T = TypeVar("_T", bound=Any) "no_autoflush", "info", ], + use_intermediate_variable=["get"], ) class async_scoped_session(Generic[_AS]): """Provides scoped management of :class:`.AsyncSession` objects. @@ -213,49 +216,6 @@ class async_scoped_session(Generic[_AS]): await self.registry().close() self.registry.clear() - async def get( - self, - entity: _EntityBindKey[_O], - ident: _PKIdentityArgument, - *, - options: Optional[Sequence[ORMOption]] = None, - populate_existing: bool = False, - with_for_update: ForUpdateParameter = None, - identity_token: Optional[Any] = None, - execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - ) -> Optional[_O]: - r"""Return an instance based on the given primary key identifier, - or ``None`` if not found. - - .. container:: class_bases - - Proxied for the :class:`_asyncio.AsyncSession` class on - behalf of the :class:`_asyncio.scoping.async_scoped_session` class. - - .. seealso:: - - :meth:`_orm.Session.get` - main documentation for get - - - - """ # noqa: E501 - - # this was proxied but Mypy is requiring the return type to be - # clarified - - # work around: - # https://github.com/python/typing/discussions/1143 - return_value = await self._proxied.get( - entity, - ident, - options=options, - populate_existing=populate_existing, - with_for_update=with_for_update, - identity_token=identity_token, - execution_options=execution_options, - ) - return return_value - # START PROXY METHODS async_scoped_session # code within this block is **programmatically, @@ -1137,6 +1097,85 @@ class async_scoped_session(Generic[_AS]): **kw, ) + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Union[_O, None]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + + """ # noqa: E501 + + result = await self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + return result + + async def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> _O: + r"""Return an instance based on the given primary key identifier, + or raise an exception if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. + + ..versionadded: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.get_one` - main documentation for get_one + + + """ # noqa: E501 + + return await self._proxied.get_one( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + @overload async def stream( self, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index da69c4fb3e..b768b22526 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -509,7 +509,7 @@ class AsyncSession(ReversibleProxy[Session]): else: execution_options = _EXECUTE_OPTIONS - result = await greenlet_spawn( + return await greenlet_spawn( self.sync_session.scalar, statement, params=params, @@ -517,7 +517,6 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments=bind_arguments, **kw, ) - return result @overload async def scalars( @@ -588,7 +587,7 @@ class AsyncSession(ReversibleProxy[Session]): with_for_update: ForUpdateParameter = None, identity_token: Optional[Any] = None, execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, - ) -> Optional[_O]: + ) -> Union[_O, None]: """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -599,9 +598,7 @@ class AsyncSession(ReversibleProxy[Session]): """ - # result_obj = self.sync_session.get(entity, ident) - - result_obj = await greenlet_spawn( + return await greenlet_spawn( cast("Callable[..., _O]", self.sync_session.get), entity, ident, @@ -609,8 +606,44 @@ class AsyncSession(ReversibleProxy[Session]): populate_existing=populate_existing, with_for_update=with_for_update, identity_token=identity_token, + execution_options=execution_options, + ) + + async def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> _O: + """Return an instance based on the given primary key identifier, + or raise an exception if not found. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. + + ..versionadded: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.get_one` - main documentation for get_one + + """ + + return await greenlet_spawn( + cast("Callable[..., _O]", self.sync_session.get_one), + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, ) - return result_obj @overload async def stream( diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index fc144d98c4..39f69d589f 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -118,6 +118,7 @@ __all__ = ["scoped_session"] "expunge_all", "flush", "get", + "get_one", "get_bind", "is_modified", "bulk_save_objects", @@ -1028,6 +1029,56 @@ class scoped_session(Generic[_S]): bind_arguments=bind_arguments, ) + def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> _O: + r"""Return exactly one instance based on the given primary key + identifier, or raise an exception if not found. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query + selects no rows. + + For a detailed documentation of the arguments see the + method :meth:`.Session.get`. + + ..versionadded: 2.0.22 + + :return: The object instance, or ``None``. + + .. seealso:: + + :meth:`.Session.get` - equivalent method that instead + returns ``None`` if no row was found with the provided primary + key + + + """ # noqa: E501 + + return self._proxied.get_one( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + def get_bind( self, mapper: Optional[_EntityBindKey[_O]] = None, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e5eb5036dd..2490dd1311 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -3580,6 +3580,57 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments=bind_arguments, ) + def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> _O: + """Return exactly one instance based on the given primary key + identifier, or raise an exception if not found. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query + selects no rows. + + For a detailed documentation of the arguments see the + method :meth:`.Session.get`. + + ..versionadded: 2.0.22 + + :return: The object instance, or ``None``. + + .. seealso:: + + :meth:`.Session.get` - equivalent method that instead + returns ``None`` if no row was found with the provided primary + key + + """ + + instance = self.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + if instance is None: + raise sa_exc.NoResultFound( + "No row was found when one was required" + ) + + return instance + def _get_impl( self, entity: _EntityBindKey[_O], diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 38c324ea0d..c4527e123f 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -690,6 +690,7 @@ def create_proxy_methods( classmethods: Sequence[str] = (), methods: Sequence[str] = (), attributes: Sequence[str] = (), + use_intermediate_variable: Sequence[str] = (), ) -> Callable[[_T], _T]: """A class decorator indicating attributes should refer to a proxy class. diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 8fa174eeba..42687a3987 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -193,6 +193,24 @@ class AsyncSessionQueryTest(AsyncFixture): u3 = await async_session.get(User, 12) is_(u3, None) + @async_test + async def test_get_one(self, async_session): + User = self.classes.User + + u1 = await async_session.get_one(User, 7) + u2 = await async_session.get_one(User, 10) + u3 = await async_session.get_one(User, 7) + + is_(u1, u3) + eq_(u1.name, "jack") + eq_(u2.name, "chuck") + + with testing.expect_raises_message( + exc.NoResultFound, + "No row was found when one was required", + ): + await async_session.get_one(User, 12) + @async_test async def test_force_a_lazyload(self, async_session): """test for #9298""" diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 389dbe00a0..3ff49fc82f 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -5,6 +5,7 @@ from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import delete from sqlalchemy import event +from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import inspect @@ -35,6 +36,7 @@ from sqlalchemy.testing import expect_deprecated from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import provision +from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.engines import testing_reaper @@ -196,7 +198,7 @@ class ShardTest: toronto = WeatherLocation("North America", "Toronto") london = WeatherLocation("Europe", "London") dublin = WeatherLocation("Europe", "Dublin") - brasilia = WeatherLocation("South America", "Brasila") + brasilia = WeatherLocation("South America", "Brasilia") quito = WeatherLocation("South America", "Quito") tokyo.reports.append(Report(80.0, id_=1)) newyork.reports.append(Report(75, id_=1)) @@ -226,6 +228,21 @@ class ShardTest: t2 = sess.get(WeatherLocation, 1) is_(t2, tokyo) + def test_get_one(self): + sess = self._fixture_data() + brasilia = sess.get_one(WeatherLocation, 6) + eq_(brasilia.id, 6) + eq_(brasilia.city, "Brasilia") + + toronto = sess.get_one(WeatherLocation, 3) + eq_(toronto.id, 3) + eq_(toronto.city, "Toronto") + + with expect_raises_message( + exc.NoResultFound, "No row was found when one was required" + ): + sess.get_one(WeatherLocation, 25) + def test_get_explicit_shard(self): sess = self._fixture_data() tokyo = ( diff --git a/test/orm/test_query.py b/test/orm/test_query.py index ce5c64a43a..85307add81 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -1213,18 +1213,6 @@ class GetTest(QueryTest): {"i": 1, "j": "2", "k": 3}, ) - def test_get(self): - User = self.classes.User - - s = fixture_session() - assert s.get(User, 19) is None - u = s.get(User, 7) - u2 = s.get(User, 7) - assert u is u2 - s.expunge_all() - u2 = s.get(User, 7) - assert u is not u2 - def test_get_synonym_direct_name(self, decl_base): """test #8753""" diff --git a/test/orm/test_session.py b/test/orm/test_session.py index b304ac5745..c9a47efc5a 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -609,6 +609,59 @@ class SessionUtilTest(_fixtures.FixtureTest): is_true(called) + def test_get(self): + users, User = self.tables.users, self.classes.User + self.mapper_registry.map_imperatively(User, users) + + s = fixture_session() + s.execute( + insert(self.tables.users), + [{"id": 7, "name": "7"}, {"id": 19, "name": "19"}], + ) + assertions.is_not_none(s.get(User, 19)) + u = s.get(User, 7) + u2 = s.get(User, 7) + assertions.is_not_none(u) + is_(u, u2) + s.expunge_all() + u2 = s.get(User, 7) + is_not(u, u2) + + def test_get_one(self): + users, User = self.tables.users, self.classes.User + self.mapper_registry.map_imperatively(User, users) + + s = fixture_session() + s.execute( + insert(self.tables.users), + [{"id": 7, "name": "7"}, {"id": 19, "name": "19"}], + ) + u = s.get_one(User, 7) + u2 = s.get_one(User, 7) + assertions.is_not_none(u) + is_(u, u2) + s.expunge_all() + u2 = s.get_one(User, 7) + is_not(u, u2) + + def test_get_one_2(self): + users, User = self.tables.users, self.classes.User + self.mapper_registry.map_imperatively(User, users) + + sess = fixture_session() + user1 = User(id=1, name="u1") + + sess.add(user1) + sess.commit() + + u1 = sess.get_one(User, user1.id) + eq_(user1.name, u1.name) + + with expect_raises_message( + sa.exc.NoResultFound, "No row was found when one was required" + ): + sess.get_one(User, 2) + class SessionStateTest(_fixtures.FixtureTest): run_inserts = None @@ -1928,7 +1981,14 @@ class SessionInterface(fixtures.MappedTest): def _public_session_methods(self): Session = sa.orm.session.Session - blocklist = {"begin", "query", "bind_mapper", "get", "bind_table"} + blocklist = { + "begin", + "query", + "bind_mapper", + "get", + "get_one", + "bind_table", + } specials = {"__iter__", "__contains__"} ok = set() for name in dir(Session): diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index 857f8eb718..9881d26426 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -99,6 +99,7 @@ def create_proxy_methods( classmethods: Iterable[str] = (), methods: Iterable[str] = (), attributes: Iterable[str] = (), + use_intermediate_variable: Iterable[str] = (), ) -> Callable[[Type[_T]], Type[_T]]: """A class decorator that will copy attributes to a proxy class. @@ -120,6 +121,7 @@ def create_proxy_methods( classmethods, methods, attributes, + use_intermediate_variable, cls, ) return cls @@ -180,6 +182,7 @@ def process_class( classmethods: Iterable[str], methods: Iterable[str], attributes: Iterable[str], + use_intermediate_variable: Iterable[str], cls: Type[Any], ): sphinx_symbol_match = re.match(r":class:`(.+)`", target_cls_sphinx_name) @@ -192,6 +195,8 @@ def process_class( sphinx_symbol = sphinx_symbol_match.group(1) + require_intermediate = set(use_intermediate_variable) + def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None: fn = getattr(target_cls, name) @@ -255,19 +260,34 @@ def process_class( ).lstrip(), } + if fn.__name__ in require_intermediate: + metadata["line_prefix"] = "result =" + metadata["after_line"] = "return result\n" + else: + metadata["line_prefix"] = "return" + metadata["after_line"] = "" + if clslevel: code = ( - "@classmethod\n" - "%(async)sdef %(name)s%(grouped_args)s:\n" - ' r"""%(doc)s\n """ # noqa: E501\n\n' - " return %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s)\n\n" # noqa: E501 + '''\ +@classmethod +%(async)sdef %(name)s%(grouped_args)s: + r"""%(doc)s\n """ # noqa: E501 + + %(line_prefix)s %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s) + %(after_line)s +''' % metadata ) else: code = ( - "%(async)sdef %(name)s%(grouped_args)s:\n" - ' r"""%(doc)s\n """ # noqa: E501\n\n' - " return %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)\n\n" # noqa: E501 + '''\ +%(async)sdef %(name)s%(grouped_args)s: + r"""%(doc)s\n """ # noqa: E501 + + %(line_prefix)s %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s) + %(after_line)s +''' # noqa: E501 % metadata )