]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add get_one to Session, AsyncSession, scoped, etc
authorCarlos Sousa <edu-eduardo99@hotmail.com>
Mon, 25 Sep 2023 17:03:26 +0000 (13:03 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 1 Oct 2023 13:07:44 +0000 (09:07 -0400)
Added method :meth:`_orm.Session.get_one` that behaves like
meth:`_orm.Session.get` but raises an exception instead of returning
None`` if no instance was found with the provided primary key.
Pull request courtesy of Carlos Sousa.

Fixed the :paramref:`_asyncio.AsyncSession.get.execution_options` parameter
which was not being propagated to the underlying :class:`_orm.Session` and
was instead being ignored.

Fixes #10292
Closes: #10376
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10376
Pull-request-sha: 70e4505e93905ee3cebc52f828a95c6bf987c9be

Change-Id: I78eb9816c26446757b6c6c171df2e400777a3d36

doc/build/changelog/unreleased_20/10292.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/util/langhelpers.py
test/ext/asyncio/test_session_py3k.py
test/ext/test_horizontal_shard.py
test/orm/test_query.py
test/orm/test_session.py
tools/generate_proxy_methods.py

diff --git a/doc/build/changelog/unreleased_20/10292.rst b/doc/build/changelog/unreleased_20/10292.rst
new file mode 100644 (file)
index 0000000..1ca2dfb
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: orm, usecase
+    :tickets: 10202
+
+    Added method :meth:`_orm.Session.get_one` that behaves like
+    :meth:`_orm.Session.get` but raises an exception instead of returning
+    ``None`` if no instance was found with the provided primary key.
+    Pull request courtesy of Carlos Sousa.
+
+
+.. change::
+    :tags: asyncio, bug
+
+    Fixed the :paramref:`_asyncio.AsyncSession.get.execution_options` parameter
+    which was not being propagated to the underlying :class:`_orm.Session` and
+    was instead being ignored.
index b70c3366b16513045d421188ca56a98aa1287095..d0228b84c4554218739b2a796095a54c517b9a40 100644 (file)
@@ -94,6 +94,8 @@ _T = TypeVar("_T", bound=Any)
         "rollback",
         "scalar",
         "scalars",
+        "get",
+        "get_one",
         "stream",
         "stream_scalars",
     ],
@@ -108,6 +110,7 @@ _T = TypeVar("_T", bound=Any)
         "no_autoflush",
         "info",
     ],
+    use_intermediate_variable=["get"],
 )
 class async_scoped_session(Generic[_AS]):
     """Provides scoped management of :class:`.AsyncSession` objects.
@@ -213,49 +216,6 @@ class async_scoped_session(Generic[_AS]):
             await self.registry().close()
         self.registry.clear()
 
-    async def get(
-        self,
-        entity: _EntityBindKey[_O],
-        ident: _PKIdentityArgument,
-        *,
-        options: Optional[Sequence[ORMOption]] = None,
-        populate_existing: bool = False,
-        with_for_update: ForUpdateParameter = None,
-        identity_token: Optional[Any] = None,
-        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-    ) -> Optional[_O]:
-        r"""Return an instance based on the given primary key identifier,
-        or ``None`` if not found.
-
-        .. container:: class_bases
-
-            Proxied for the :class:`_asyncio.AsyncSession` class on
-            behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
-
-        .. seealso::
-
-            :meth:`_orm.Session.get` - main documentation for get
-
-
-
-        """  # noqa: E501
-
-        # this was proxied but Mypy is requiring the return type to be
-        # clarified
-
-        # work around:
-        # https://github.com/python/typing/discussions/1143
-        return_value = await self._proxied.get(
-            entity,
-            ident,
-            options=options,
-            populate_existing=populate_existing,
-            with_for_update=with_for_update,
-            identity_token=identity_token,
-            execution_options=execution_options,
-        )
-        return return_value
-
     # START PROXY METHODS async_scoped_session
 
     # code within this block is **programmatically,
@@ -1137,6 +1097,85 @@ class async_scoped_session(Generic[_AS]):
             **kw,
         )
 
+    async def get(
+        self,
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: ForUpdateParameter = None,
+        identity_token: Optional[Any] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> Union[_O, None]:
+        r"""Return an instance based on the given primary key identifier,
+        or ``None`` if not found.
+
+        .. container:: class_bases
+
+            Proxied for the :class:`_asyncio.AsyncSession` class on
+            behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
+
+        .. seealso::
+
+            :meth:`_orm.Session.get` - main documentation for get
+
+
+
+        """  # noqa: E501
+
+        result = await self._proxied.get(
+            entity,
+            ident,
+            options=options,
+            populate_existing=populate_existing,
+            with_for_update=with_for_update,
+            identity_token=identity_token,
+            execution_options=execution_options,
+        )
+        return result
+
+    async def get_one(
+        self,
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: ForUpdateParameter = None,
+        identity_token: Optional[Any] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> _O:
+        r"""Return an instance based on the given primary key identifier,
+        or raise an exception if not found.
+
+        .. container:: class_bases
+
+            Proxied for the :class:`_asyncio.AsyncSession` class on
+            behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
+
+        Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
+        no rows.
+
+        ..versionadded: 2.0.22
+
+        .. seealso::
+
+            :meth:`_orm.Session.get_one` - main documentation for get_one
+
+
+        """  # noqa: E501
+
+        return await self._proxied.get_one(
+            entity,
+            ident,
+            options=options,
+            populate_existing=populate_existing,
+            with_for_update=with_for_update,
+            identity_token=identity_token,
+            execution_options=execution_options,
+        )
+
     @overload
     async def stream(
         self,
index da69c4fb3efc651d6a6354f211a6d46866008130..b768b22526a35a5ae46efc1921af19bd26919e8f 100644 (file)
@@ -509,7 +509,7 @@ class AsyncSession(ReversibleProxy[Session]):
         else:
             execution_options = _EXECUTE_OPTIONS
 
-        result = await greenlet_spawn(
+        return await greenlet_spawn(
             self.sync_session.scalar,
             statement,
             params=params,
@@ -517,7 +517,6 @@ class AsyncSession(ReversibleProxy[Session]):
             bind_arguments=bind_arguments,
             **kw,
         )
-        return result
 
     @overload
     async def scalars(
@@ -588,7 +587,7 @@ class AsyncSession(ReversibleProxy[Session]):
         with_for_update: ForUpdateParameter = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
-    ) -> Optional[_O]:
+    ) -> Union[_O, None]:
         """Return an instance based on the given primary key identifier,
         or ``None`` if not found.
 
@@ -599,9 +598,7 @@ class AsyncSession(ReversibleProxy[Session]):
 
         """
 
-        # result_obj = self.sync_session.get(entity, ident)
-
-        result_obj = await greenlet_spawn(
+        return await greenlet_spawn(
             cast("Callable[..., _O]", self.sync_session.get),
             entity,
             ident,
@@ -609,8 +606,44 @@ class AsyncSession(ReversibleProxy[Session]):
             populate_existing=populate_existing,
             with_for_update=with_for_update,
             identity_token=identity_token,
+            execution_options=execution_options,
+        )
+
+    async def get_one(
+        self,
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: ForUpdateParameter = None,
+        identity_token: Optional[Any] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> _O:
+        """Return an instance based on the given primary key identifier,
+        or raise an exception if not found.
+
+        Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
+        no rows.
+
+        ..versionadded: 2.0.22
+
+        .. seealso::
+
+            :meth:`_orm.Session.get_one` - main documentation for get_one
+
+        """
+
+        return await greenlet_spawn(
+            cast("Callable[..., _O]", self.sync_session.get_one),
+            entity,
+            ident,
+            options=options,
+            populate_existing=populate_existing,
+            with_for_update=with_for_update,
+            identity_token=identity_token,
+            execution_options=execution_options,
         )
-        return result_obj
 
     @overload
     async def stream(
index fc144d98c4e0172537f20424094aa7b7c14a3456..39f69d589fc03cb0f79f218c99055e322e8a9da9 100644 (file)
@@ -118,6 +118,7 @@ __all__ = ["scoped_session"]
         "expunge_all",
         "flush",
         "get",
+        "get_one",
         "get_bind",
         "is_modified",
         "bulk_save_objects",
@@ -1028,6 +1029,56 @@ class scoped_session(Generic[_S]):
             bind_arguments=bind_arguments,
         )
 
+    def get_one(
+        self,
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: ForUpdateParameter = None,
+        identity_token: Optional[Any] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+    ) -> _O:
+        r"""Return exactly one instance based on the given primary key
+        identifier, or raise an exception if not found.
+
+        .. container:: class_bases
+
+            Proxied for the :class:`_orm.Session` class on
+            behalf of the :class:`_orm.scoping.scoped_session` class.
+
+        Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query
+        selects no rows.
+
+        For a detailed documentation of the arguments see the
+        method :meth:`.Session.get`.
+
+        ..versionadded: 2.0.22
+
+        :return: The object instance, or ``None``.
+
+        .. seealso::
+
+            :meth:`.Session.get` - equivalent method that instead
+              returns ``None`` if no row was found with the provided primary
+              key
+
+
+        """  # noqa: E501
+
+        return self._proxied.get_one(
+            entity,
+            ident,
+            options=options,
+            populate_existing=populate_existing,
+            with_for_update=with_for_update,
+            identity_token=identity_token,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+        )
+
     def get_bind(
         self,
         mapper: Optional[_EntityBindKey[_O]] = None,
index e5eb5036dd738b502fa89c7cd3548724c5620196..2490dd1311e32168736761fed1facd0f3a638912 100644 (file)
@@ -3580,6 +3580,57 @@ class Session(_SessionClassMethods, EventTarget):
             bind_arguments=bind_arguments,
         )
 
+    def get_one(
+        self,
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: ForUpdateParameter = None,
+        identity_token: Optional[Any] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+    ) -> _O:
+        """Return exactly one instance based on the given primary key
+        identifier, or raise an exception if not found.
+
+        Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query
+        selects no rows.
+
+        For a detailed documentation of the arguments see the
+        method :meth:`.Session.get`.
+
+        ..versionadded: 2.0.22
+
+        :return: The object instance, or ``None``.
+
+        .. seealso::
+
+            :meth:`.Session.get` - equivalent method that instead
+              returns ``None`` if no row was found with the provided primary
+              key
+
+        """
+
+        instance = self.get(
+            entity,
+            ident,
+            options=options,
+            populate_existing=populate_existing,
+            with_for_update=with_for_update,
+            identity_token=identity_token,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+        )
+
+        if instance is None:
+            raise sa_exc.NoResultFound(
+                "No row was found when one was required"
+            )
+
+        return instance
+
     def _get_impl(
         self,
         entity: _EntityBindKey[_O],
index 38c324ea0dbd09301ce43bd491bbcbafbcc7b55f..c4527e123f610d94d0dab8d441d09f510988c017 100644 (file)
@@ -690,6 +690,7 @@ def create_proxy_methods(
     classmethods: Sequence[str] = (),
     methods: Sequence[str] = (),
     attributes: Sequence[str] = (),
+    use_intermediate_variable: Sequence[str] = (),
 ) -> Callable[[_T], _T]:
     """A class decorator indicating attributes should refer to a proxy
     class.
index 8fa174eebaaea5f0e3f8ed220f8e86cb03bd1922..42687a398756d98e9317fa7bf1d8ed19f2c4b260 100644 (file)
@@ -193,6 +193,24 @@ class AsyncSessionQueryTest(AsyncFixture):
         u3 = await async_session.get(User, 12)
         is_(u3, None)
 
+    @async_test
+    async def test_get_one(self, async_session):
+        User = self.classes.User
+
+        u1 = await async_session.get_one(User, 7)
+        u2 = await async_session.get_one(User, 10)
+        u3 = await async_session.get_one(User, 7)
+
+        is_(u1, u3)
+        eq_(u1.name, "jack")
+        eq_(u2.name, "chuck")
+
+        with testing.expect_raises_message(
+            exc.NoResultFound,
+            "No row was found when one was required",
+        ):
+            await async_session.get_one(User, 12)
+
     @async_test
     async def test_force_a_lazyload(self, async_session):
         """test for #9298"""
index 389dbe00a0894155ef33b46b20a8286941391f89..3ff49fc82fe98e22d4c982a08b7810ba0e6ebc44 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import delete
 from sqlalchemy import event
+from sqlalchemy import exc
 from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
@@ -35,6 +36,7 @@ from sqlalchemy.testing import expect_deprecated
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import provision
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing.engines import testing_reaper
 
@@ -196,7 +198,7 @@ class ShardTest:
         toronto = WeatherLocation("North America", "Toronto")
         london = WeatherLocation("Europe", "London")
         dublin = WeatherLocation("Europe", "Dublin")
-        brasilia = WeatherLocation("South America", "Brasila")
+        brasilia = WeatherLocation("South America", "Brasilia")
         quito = WeatherLocation("South America", "Quito")
         tokyo.reports.append(Report(80.0, id_=1))
         newyork.reports.append(Report(75, id_=1))
@@ -226,6 +228,21 @@ class ShardTest:
         t2 = sess.get(WeatherLocation, 1)
         is_(t2, tokyo)
 
+    def test_get_one(self):
+        sess = self._fixture_data()
+        brasilia = sess.get_one(WeatherLocation, 6)
+        eq_(brasilia.id, 6)
+        eq_(brasilia.city, "Brasilia")
+
+        toronto = sess.get_one(WeatherLocation, 3)
+        eq_(toronto.id, 3)
+        eq_(toronto.city, "Toronto")
+
+        with expect_raises_message(
+            exc.NoResultFound, "No row was found when one was required"
+        ):
+            sess.get_one(WeatherLocation, 25)
+
     def test_get_explicit_shard(self):
         sess = self._fixture_data()
         tokyo = (
index ce5c64a43af6b52a98b6a89cb67364e006cfdd97..85307add814a80dc78f3fa55a9d496c78ac459e8 100644 (file)
@@ -1213,18 +1213,6 @@ class GetTest(QueryTest):
             {"i": 1, "j": "2", "k": 3},
         )
 
-    def test_get(self):
-        User = self.classes.User
-
-        s = fixture_session()
-        assert s.get(User, 19) is None
-        u = s.get(User, 7)
-        u2 = s.get(User, 7)
-        assert u is u2
-        s.expunge_all()
-        u2 = s.get(User, 7)
-        assert u is not u2
-
     def test_get_synonym_direct_name(self, decl_base):
         """test #8753"""
 
index b304ac574540085e2f5837fbfc35657fe617b5c9..c9a47efc5affaf10b40822b9db23a65986b40ff9 100644 (file)
@@ -609,6 +609,59 @@ class SessionUtilTest(_fixtures.FixtureTest):
 
         is_true(called)
 
+    def test_get(self):
+        users, User = self.tables.users, self.classes.User
+        self.mapper_registry.map_imperatively(User, users)
+
+        s = fixture_session()
+        s.execute(
+            insert(self.tables.users),
+            [{"id": 7, "name": "7"}, {"id": 19, "name": "19"}],
+        )
+        assertions.is_not_none(s.get(User, 19))
+        u = s.get(User, 7)
+        u2 = s.get(User, 7)
+        assertions.is_not_none(u)
+        is_(u, u2)
+        s.expunge_all()
+        u2 = s.get(User, 7)
+        is_not(u, u2)
+
+    def test_get_one(self):
+        users, User = self.tables.users, self.classes.User
+        self.mapper_registry.map_imperatively(User, users)
+
+        s = fixture_session()
+        s.execute(
+            insert(self.tables.users),
+            [{"id": 7, "name": "7"}, {"id": 19, "name": "19"}],
+        )
+        u = s.get_one(User, 7)
+        u2 = s.get_one(User, 7)
+        assertions.is_not_none(u)
+        is_(u, u2)
+        s.expunge_all()
+        u2 = s.get_one(User, 7)
+        is_not(u, u2)
+
+    def test_get_one_2(self):
+        users, User = self.tables.users, self.classes.User
+        self.mapper_registry.map_imperatively(User, users)
+
+        sess = fixture_session()
+        user1 = User(id=1, name="u1")
+
+        sess.add(user1)
+        sess.commit()
+
+        u1 = sess.get_one(User, user1.id)
+        eq_(user1.name, u1.name)
+
+        with expect_raises_message(
+            sa.exc.NoResultFound, "No row was found when one was required"
+        ):
+            sess.get_one(User, 2)
+
 
 class SessionStateTest(_fixtures.FixtureTest):
     run_inserts = None
@@ -1928,7 +1981,14 @@ class SessionInterface(fixtures.MappedTest):
     def _public_session_methods(self):
         Session = sa.orm.session.Session
 
-        blocklist = {"begin", "query", "bind_mapper", "get", "bind_table"}
+        blocklist = {
+            "begin",
+            "query",
+            "bind_mapper",
+            "get",
+            "get_one",
+            "bind_table",
+        }
         specials = {"__iter__", "__contains__"}
         ok = set()
         for name in dir(Session):
index 857f8eb71824e520e924c2ce0ad2c0f17ddcef2b..9881d26426fdc738e9717613a7d51b1e889ffc09 100644 (file)
@@ -99,6 +99,7 @@ def create_proxy_methods(
     classmethods: Iterable[str] = (),
     methods: Iterable[str] = (),
     attributes: Iterable[str] = (),
+    use_intermediate_variable: Iterable[str] = (),
 ) -> Callable[[Type[_T]], Type[_T]]:
     """A class decorator that will copy attributes to a proxy class.
 
@@ -120,6 +121,7 @@ def create_proxy_methods(
             classmethods,
             methods,
             attributes,
+            use_intermediate_variable,
             cls,
         )
         return cls
@@ -180,6 +182,7 @@ def process_class(
     classmethods: Iterable[str],
     methods: Iterable[str],
     attributes: Iterable[str],
+    use_intermediate_variable: Iterable[str],
     cls: Type[Any],
 ):
     sphinx_symbol_match = re.match(r":class:`(.+)`", target_cls_sphinx_name)
@@ -192,6 +195,8 @@ def process_class(
 
     sphinx_symbol = sphinx_symbol_match.group(1)
 
+    require_intermediate = set(use_intermediate_variable)
+
     def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None:
         fn = getattr(target_cls, name)
 
@@ -255,19 +260,34 @@ def process_class(
             ).lstrip(),
         }
 
+        if fn.__name__ in require_intermediate:
+            metadata["line_prefix"] = "result ="
+            metadata["after_line"] = "return result\n"
+        else:
+            metadata["line_prefix"] = "return"
+            metadata["after_line"] = ""
+
         if clslevel:
             code = (
-                "@classmethod\n"
-                "%(async)sdef %(name)s%(grouped_args)s:\n"
-                '    r"""%(doc)s\n    """  # noqa: E501\n\n'
-                "    return %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s)\n\n"  # noqa: E501
+                '''\
+@classmethod
+%(async)sdef %(name)s%(grouped_args)s:
+    r"""%(doc)s\n    """  # noqa: E501
+
+    %(line_prefix)s %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s)
+    %(after_line)s
+'''
                 % metadata
             )
         else:
             code = (
-                "%(async)sdef %(name)s%(grouped_args)s:\n"
-                '    r"""%(doc)s\n    """  # noqa: E501\n\n'
-                "    return %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)\n\n"  # noqa: E501
+                '''\
+%(async)sdef %(name)s%(grouped_args)s:
+    r"""%(doc)s\n    """  # noqa: E501
+
+    %(line_prefix)s %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)
+    %(after_line)s
+'''  # noqa: E501
                 % metadata
             )