]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add scalars method to connection and session classes
authorMiguel Grinberg <miguel.grinberg@gmail.com>
Mon, 13 Sep 2021 18:41:13 +0000 (14:41 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Sep 2021 16:58:37 +0000 (12:58 -0400)
Added new methods :meth:`_orm.Session.scalars`,
:meth:`_engine.Connection.scalars`, :meth:`_asyncio.AsyncSession.scalars`
and :meth:`_asyncio.AsyncSession.stream_scalars`, which provide a short cut
to the use case of receiving a row-oriented :class:`_result.Result` object
and converting it to a :class:`_result.ScalarResult` object via the
:meth:`_engine.Result.scalars` method, to return a list of values rather
than a list of rows. The new methods are analogous to the long existing
:meth:`_orm.Session.scalar` and :meth:`_engine.Connection.scalar` methods
used to return a single value from the first row only. Pull request
courtesy Miguel Grinberg.

Fixes: #6990
Closes: #6991
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/6991
Pull-request-sha: b3e0bb3042c55b0cc5af6a25cb3f31b929f88a47

Change-Id: Ia445775e24ca964b0162c2c8e5ca67dd1e39199f

doc/build/changelog/unreleased_14/6990.rst [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/session.py
test/engine/test_execute.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py
test/orm/test_bundle.py
test/orm/test_session.py

diff --git a/doc/build/changelog/unreleased_14/6990.rst b/doc/build/changelog/unreleased_14/6990.rst
new file mode 100644 (file)
index 0000000..1a53ad2
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: usecase, engine, orm
+    :tickets: 6990
+
+    Added new methods :meth:`_orm.Session.scalars`,
+    :meth:`_engine.Connection.scalars`, :meth:`_asyncio.AsyncSession.scalars`
+    and :meth:`_asyncio.AsyncSession.stream_scalars`, which provide a short cut
+    to the use case of receiving a row-oriented :class:`_result.Result` object
+    and converting it to a :class:`_result.ScalarResult` object via the
+    :meth:`_engine.Result.scalars` method, to return a list of values rather
+    than a list of rows. The new methods are analogous to the long existing
+    :meth:`_orm.Session.scalar` and :meth:`_engine.Connection.scalar` methods
+    used to return a single value from the first row only. Pull request
+    courtesy Miguel Grinberg.
index a316f904f01c478284a3df1750cd895a4f436864..25ced0343ef7ad7b2437970cd047c380f4793b99 100644 (file)
@@ -1157,10 +1157,28 @@ class Connection(Connectable):
         """Executes and returns the first column of the first row.
 
         The underlying result/cursor is closed after execution.
+
         """
 
         return self.execute(object_, *multiparams, **params).scalar()
 
+    def scalars(self, object_, *multiparams, **params):
+        """Executes and returns a scalar result set, which yields scalar values
+        from the first column of each row.
+
+        This method is equivalent to calling :meth:`_engine.Connection.execute`
+        to receive a :class:`_result.Result` object, then invoking the
+        :meth:`_result.Result.scalars` method to produce a
+        :class:`_result.ScalarResult` instance.
+
+        :return: a :class:`_result.ScalarResult`
+
+        .. versionadded:: 1.4.24
+
+        """
+
+        return self.execute(object_, *multiparams, **params).scalars()
+
     def execute(self, statement, *multiparams, **params):
         r"""Executes a SQL statement construct and returns a
         :class:`_engine.CursorResult`.
index 5a692ffb1be71e852d28f3e94b329af44e066bc2..ab29438ed0e72b89b6ea124a5ad3f2e64bf6da8c 100644 (file)
@@ -439,6 +439,47 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         result = await self.execute(statement, parameters, execution_options)
         return result.scalar()
 
+    async def scalars(
+        self,
+        statement,
+        parameters=None,
+        execution_options=util.EMPTY_DICT,
+    ):
+        r"""Executes a SQL statement construct and returns a scalar objects.
+
+        This method is shorthand for invoking the
+        :meth:`_engine.Result.scalars` method after invoking the
+        :meth:`_future.Connection.execute` method.  Parameters are equivalent.
+
+        :return: a :class:`_engine.ScalarResult` object.
+
+        .. versionadded:: 1.4.24
+
+        """
+        result = await self.execute(statement, parameters, execution_options)
+        return result.scalars()
+
+    async def stream_scalars(
+        self,
+        statement,
+        parameters=None,
+        execution_options=util.EMPTY_DICT,
+    ):
+        r"""Executes a SQL statement and returns a streaming scalar result
+        object.
+
+        This method is shorthand for invoking the
+        :meth:`_engine.AsyncResult.scalars` method after invoking the
+        :meth:`_future.Connection.stream` method.  Parameters are equivalent.
+
+        :return: an :class:`_asyncio.AsyncScalarResult` object.
+
+        .. versionadded:: 1.4.24
+
+        """
+        result = await self.stream(statement, parameters, execution_options)
+        return result.scalars()
+
     async def run_sync(self, fn, *arg, **kw):
         """Invoke the given sync callable passing self as the first argument.
 
index e4372f448450aa96c497c681e5fc0e300cd7b662..6e3ac5a900f741d9ecba104bcaa6a9f09357f950 100644 (file)
@@ -231,6 +231,37 @@ class AsyncSession(ReversibleProxy):
         )
         return result.scalar()
 
+    async def scalars(
+        self,
+        statement,
+        params=None,
+        execution_options=util.EMPTY_DICT,
+        bind_arguments=None,
+        **kw
+    ):
+        """Execute a statement and return scalar results.
+
+        :return: a :class:`_result.ScalarResult` object
+
+        .. versionadded:: 1.4.24
+
+        .. seealso::
+
+            :meth:`_orm.Session.scalars` - main documentation for scalars
+
+            :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version
+
+        """
+
+        result = await self.execute(
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        )
+        return result.scalars()
+
     async def get(
         self,
         entity,
@@ -287,6 +318,37 @@ class AsyncSession(ReversibleProxy):
         )
         return _result.AsyncResult(result)
 
+    async def stream_scalars(
+        self,
+        statement,
+        params=None,
+        execution_options=util.EMPTY_DICT,
+        bind_arguments=None,
+        **kw
+    ):
+        """Execute a statement and return a stream of scalar results.
+
+        :return: an :class:`_asyncio.AsyncScalarResult` object
+
+        .. versionadded:: 1.4.24
+
+        .. seealso::
+
+            :meth:`_orm.Session.scalars` - main documentation for scalars
+
+            :meth:`_asyncio.AsyncSession.scalars` - non streaming version
+
+        """
+
+        result = await self.stream(
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        )
+        return result.scalars()
+
     async def delete(self, instance):
         """Mark an instance as deleted.
 
index 0368bf83a39bc4c0f9b3cdd5e012dfadeefaa239..f051d8df2c4e8f5cf3def951929a859a823769b5 100644 (file)
@@ -1724,6 +1724,35 @@ class Session(_SessionClassMethods):
             **kw
         ).scalar()
 
+    def scalars(
+        self,
+        statement,
+        params=None,
+        execution_options=util.EMPTY_DICT,
+        bind_arguments=None,
+        **kw
+    ):
+        """Execute a statement and return the results as scalars.
+
+        Usage and parameters are the same as that of
+        :meth:`_orm.Session.execute`; the return result is a
+        :class:`_result.ScalarResult` filtering object which
+        will return single elements rather than :class:`_row.Row` objects.
+
+        :return:  a :class:`_result.ScalarResult` object
+
+        .. versionadded:: 1.4.24
+
+        """
+
+        return self.execute(
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        ).scalars()
+
     def close(self):
         """Close out the transactional resources and ORM objects used by this
         :class:`_orm.Session`.
index dd4ee32f8c49ee84bcaf31cdb3abd4927c7569a2..791e42bc0fa3424fb03cb6b11fe25829ee98ae72 100644 (file)
@@ -707,6 +707,32 @@ class ExecuteTest(fixtures.TablesTest):
                 eq_(conn.scalar(select(1)), 1)
             eng.dispose()
 
+    def test_scalar(self, connection):
+        conn = connection
+        users = self.tables.users
+        conn.execute(
+            users.insert(),
+            [
+                {"user_id": 1, "user_name": "sandy"},
+                {"user_id": 2, "user_name": "spongebob"},
+            ],
+        )
+        res = conn.scalar(select(users.c.user_name).order_by(users.c.user_id))
+        eq_(res, "sandy")
+
+    def test_scalars(self, connection):
+        conn = connection
+        users = self.tables.users
+        conn.execute(
+            users.insert(),
+            [
+                {"user_id": 1, "user_name": "sandy"},
+                {"user_id": 2, "user_name": "spongebob"},
+            ],
+        )
+        res = conn.scalars(select(users.c.user_name).order_by(users.c.user_id))
+        eq_(res.all(), ["sandy", "spongebob"])
+
 
 class UnicodeReturnsTest(fixtures.TestBase):
     @testing.requires.python3
index 0450e30544d1be9214d934d97c397d169ae50cda..01e3e3040ffc341c7aa1b78ed72cb05d67eba981 100644 (file)
@@ -889,6 +889,20 @@ class AsyncResultTest(EngineFixture):
             ):
                 await result.one()
 
+    @testing.combinations(
+        ("scalars",), ("stream_scalars",), argnames="filter_"
+    )
+    @async_test
+    async def test_scalars(self, async_engine, filter_):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            if filter_ == "scalars":
+                result = (await conn.scalars(select(users))).all()
+            elif filter_ == "stream_scalars":
+                result = await (await conn.stream_scalars(select(users))).all()
+
+        eq_(result, list(range(1, 20)))
+
 
 class TextSyncDBAPI(fixtures.TestBase):
     def test_sync_dbapi_raises(self):
index cd90547406cb7df8484881841601b2e7dedcf4dc..4e475b2122e2b9c601a09e344d05d1148f8c909e 100644 (file)
@@ -91,6 +91,25 @@ class AsyncSessionQueryTest(AsyncFixture):
         result = await async_session.scalar(stmt)
         eq_(result, 7)
 
+    @testing.combinations(
+        ("scalars",), ("stream_scalars",), argnames="filter_"
+    )
+    @async_test
+    async def test_scalars(self, async_session, filter_):
+        User = self.classes.User
+
+        stmt = (
+            select(User)
+            .options(selectinload(User.addresses))
+            .order_by(User.id)
+        )
+
+        if filter_ == "scalars":
+            result = (await async_session.scalars(stmt)).all()
+        elif filter_ == "stream_scalars":
+            result = await (await async_session.stream_scalars(stmt)).all()
+        eq_(result, self.static.user_address_result)
+
     @async_test
     async def test_get(self, async_session):
         User = self.classes.User
index b0113f1fcc4066129fcad1870c72d304826f2c72..db4267fc39a88dae3745a86a1813b831f3b9fc21 100644 (file)
@@ -304,7 +304,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
 
         stmt = select(b1).filter(b1.c.x.between("d3d1", "d5d1"))
         eq_(
-            sess.execute(stmt).scalars().all(),
+            sess.scalars(stmt).all(),
             [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")],
         )
 
@@ -335,7 +335,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
 
         stmt = select(b1).filter(b1.c.d1.between("d3d1", "d5d1"))
         eq_(
-            sess.execute(stmt).scalars().all(),
+            sess.scalars(stmt).all(),
             [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")],
         )
 
index e4010c635fe7c447dfa20d8d7459956dea1be828..b8997015f5212d5d3922c1c7be6ae799ce6bcf84 100644 (file)
@@ -1950,7 +1950,9 @@ class DisposedStates(fixtures.MappedTest):
 class SessionInterface(fixtures.TestBase):
     """Bogus args to Session methods produce actionable exceptions."""
 
-    _class_methods = set(("connection", "execute", "get_bind", "scalar"))
+    _class_methods = set(
+        ("connection", "execute", "get_bind", "scalar", "scalars")
+    )
 
     def _public_session_methods(self):
         Session = sa.orm.session.Session
@@ -2078,6 +2080,10 @@ class SessionInterface(fixtures.TestBase):
             "scalar", text("SELECT 1"), bind_arguments=dict(mapper=user_arg)
         )
 
+        raises_(
+            "scalars", text("SELECT 1"), bind_arguments=dict(mapper=user_arg)
+        )
+
         eq_(
             watchdog,
             self._class_methods,