From b3e0bb3042c55b0cc5af6a25cb3f31b929f88a47 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 5 Sep 2021 23:47:30 +0100 Subject: [PATCH] Add scalars method to connection and session classes --- doc/build/changelog/unreleased_14/6990.rst | 14 ++++++ lib/sqlalchemy/engine/base.py | 16 +++++- lib/sqlalchemy/ext/asyncio/engine.py | 41 +++++++++++++++ lib/sqlalchemy/ext/asyncio/session.py | 58 ++++++++++++++++++++++ lib/sqlalchemy/orm/session.py | 29 +++++++++++ test/engine/test_execute.py | 26 ++++++++++ test/ext/asyncio/test_engine_py3k.py | 14 ++++++ test/ext/asyncio/test_session_py3k.py | 19 +++++++ test/orm/test_bundle.py | 4 +- test/orm/test_session.py | 8 ++- 10 files changed, 224 insertions(+), 5 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6990.rst diff --git a/doc/build/changelog/unreleased_14/6990.rst b/doc/build/changelog/unreleased_14/6990.rst new file mode 100644 index 0000000000..1a53ad2fa6 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6990.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: usecase, engine, orm + :tickets: 6990 + + Added new methods :meth:`_orm.Session.scalars`, + :meth:`_engine.Connection.scalars`, :meth:`_asyncio.AsyncSession.scalars` + and :meth:`_asyncio.AsyncSession.stream_scalars`, which provide a short cut + to the use case of receiving a row-oriented :class:`_result.Result` object + and converting it to a :class:`_result.ScalarResult` object via the + :meth:`_engine.Result.scalars` method, to return a list of values rather + than a list of rows. The new methods are analogous to the long existing + :meth:`_orm.Session.scalar` and :meth:`_engine.Connection.scalar` methods + used to return a single value from the first row only. Pull request + courtesy Miguel Grinberg. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index a316f904f0..ac62f41e58 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1154,13 +1154,25 @@ class Connection(Connectable): self.__can_reconnect = False def scalar(self, object_, *multiparams, **params): - """Executes and returns the first column of the first row. + """Executes and returns a scalar result set. + + :return: a :class:_result.ScalarResult - The underlying result/cursor is closed after execution. """ return self.execute(object_, *multiparams, **params).scalar() + def scalars(self, object_, *multiparams, **params): + """Executes and returns a scalar result set. + + :return: a :class:_result.ScalarResult + + .. versionadded:: 1.4.24 + + """ + + return self.execute(object_, *multiparams, **params).scalars() + def execute(self, statement, *multiparams, **params): r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 5a692ffb1b..90ee470608 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -439,6 +439,47 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): result = await self.execute(statement, parameters, execution_options) return result.scalar() + async def scalars( + self, + statement, + parameters=None, + execution_options=util.EMPTY_DICT, + ): + r"""Executes a SQL statement construct and returns a scalar objects. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalars` method after invoking the + :meth:`_future.Connection.execute` method. Parameters are equivalent. + + :return: a :class:`_engine.ScalarResult` object. + + .. versionadded:: 1.4.24 + + """ + result = await self.execute(statement, parameters, execution_options) + return result.scalars() + + async def stream_scalars( + self, + statement, + parameters=None, + execution_options=util.EMPTY_DICT, + ): + r"""Executes a SQL statement and returns a streaming scalar result + object. + + This method is shorthand for invoking the + :meth:`_engine.AsyncResult.scalars` method after invoking the + :meth:`_future.Connection.stream` method. Parameters are equivalent. + + :return: a :class:`_asyncio.AsyncScalarResult` object. + + .. versionadded:: 1.4.24 + + """ + result = await self.stream(statement, parameters, execution_options) + return result.scalars() + async def run_sync(self, fn, *arg, **kw): """Invoke the given sync callable passing self as the first argument. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index e4372f4484..3d5d187249 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -231,6 +231,35 @@ class AsyncSession(ReversibleProxy): ) return result.scalar() + async def scalars( + self, + statement, + params=None, + execution_options=util.EMPTY_DICT, + bind_arguments=None, + **kw + ): + """Execute a statement and return scalar results. + + :return: an :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + """ + + result = await self.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + return result.scalars() + async def get( self, entity, @@ -287,6 +316,35 @@ class AsyncSession(ReversibleProxy): ) return _result.AsyncResult(result) + async def stream_scalars( + self, + statement, + params=None, + execution_options=util.EMPTY_DICT, + bind_arguments=None, + **kw + ): + """Execute a statement and return a stream of scalar results. + + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + """ + + result = await self.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + return result.scalars() + async def delete(self, instance): """Mark an instance as deleted. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a93684126b..09bda2d114 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1724,6 +1724,35 @@ class Session(_SessionClassMethods): **kw ).scalar() + def scalars( + self, + statement, + params=None, + execution_options=util.EMPTY_DICT, + bind_arguments=None, + **kw + ): + """Execute a statement and return the results as scalars. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a + :class:`_result.ScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 + + """ + + return self.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ).scalars() + def close(self): """Close out the transactional resources and ORM objects used by this :class:`_orm.Session`. diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index dd4ee32f8c..791e42bc0f 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -707,6 +707,32 @@ class ExecuteTest(fixtures.TablesTest): eq_(conn.scalar(select(1)), 1) eng.dispose() + def test_scalar(self, connection): + conn = connection + users = self.tables.users + conn.execute( + users.insert(), + [ + {"user_id": 1, "user_name": "sandy"}, + {"user_id": 2, "user_name": "spongebob"}, + ], + ) + res = conn.scalar(select(users.c.user_name).order_by(users.c.user_id)) + eq_(res, "sandy") + + def test_scalars(self, connection): + conn = connection + users = self.tables.users + conn.execute( + users.insert(), + [ + {"user_id": 1, "user_name": "sandy"}, + {"user_id": 2, "user_name": "spongebob"}, + ], + ) + res = conn.scalars(select(users.c.user_name).order_by(users.c.user_id)) + eq_(res.all(), ["sandy", "spongebob"]) + class UnicodeReturnsTest(fixtures.TestBase): @testing.requires.python3 diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index c75dd86655..de6d13472a 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -888,6 +888,20 @@ class AsyncResultTest(EngineFixture): ): await result.one() + @testing.combinations( + ("scalars",), ("stream_scalars",), argnames="filter_" + ) + @async_test + async def test_scalars(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + if filter_ == "scalars": + result = (await conn.scalars(select(users))).all() + elif filter_ == "stream_scalars": + result = await (await conn.stream_scalars(select(users))).all() + + eq_(result, list(range(1, 20))) + class TextSyncDBAPI(fixtures.TestBase): def test_sync_dbapi_raises(self): diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index cd90547406..4e475b2122 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -91,6 +91,25 @@ class AsyncSessionQueryTest(AsyncFixture): result = await async_session.scalar(stmt) eq_(result, 7) + @testing.combinations( + ("scalars",), ("stream_scalars",), argnames="filter_" + ) + @async_test + async def test_scalars(self, async_session, filter_): + User = self.classes.User + + stmt = ( + select(User) + .options(selectinload(User.addresses)) + .order_by(User.id) + ) + + if filter_ == "scalars": + result = (await async_session.scalars(stmt)).all() + elif filter_ == "stream_scalars": + result = await (await async_session.stream_scalars(stmt)).all() + eq_(result, self.static.user_address_result) + @async_test async def test_get(self, async_session): User = self.classes.User diff --git a/test/orm/test_bundle.py b/test/orm/test_bundle.py index b0113f1fcc..db4267fc39 100644 --- a/test/orm/test_bundle.py +++ b/test/orm/test_bundle.py @@ -304,7 +304,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): stmt = select(b1).filter(b1.c.x.between("d3d1", "d5d1")) eq_( - sess.execute(stmt).scalars().all(), + sess.scalars(stmt).all(), [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")], ) @@ -335,7 +335,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): stmt = select(b1).filter(b1.c.d1.between("d3d1", "d5d1")) eq_( - sess.execute(stmt).scalars().all(), + sess.scalars(stmt).all(), [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")], ) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index e4010c635f..b8997015f5 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1950,7 +1950,9 @@ class DisposedStates(fixtures.MappedTest): class SessionInterface(fixtures.TestBase): """Bogus args to Session methods produce actionable exceptions.""" - _class_methods = set(("connection", "execute", "get_bind", "scalar")) + _class_methods = set( + ("connection", "execute", "get_bind", "scalar", "scalars") + ) def _public_session_methods(self): Session = sa.orm.session.Session @@ -2078,6 +2080,10 @@ class SessionInterface(fixtures.TestBase): "scalar", text("SELECT 1"), bind_arguments=dict(mapper=user_arg) ) + raises_( + "scalars", text("SELECT 1"), bind_arguments=dict(mapper=user_arg) + ) + eq_( watchdog, self._class_methods, -- 2.47.2