From 9c217ea2e95d720928e40fb3a16c4f2706738868 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Mon, 26 Feb 2024 22:16:18 +0100 Subject: [PATCH] add doctests to asyncio tutorial Change-Id: I28c94a7bc1e7ae572af0d206b8e63a110dc6fd7a (cherry picked from commit e32954b91eef968be33ac4b46c16055daffa90dd) --- doc/build/orm/extensions/asyncio.rst | 366 ++++++++++++++++----------- test/base/test_tutorials.py | 103 ++------ 2 files changed, 243 insertions(+), 226 deletions(-) diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index 6649a98103..fbd965d15d 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -64,47 +64,64 @@ methods which both deliver asynchronous context managers. The :class:`_asyncio.AsyncConnection` can then invoke statements using either the :meth:`_asyncio.AsyncConnection.execute` method to deliver a buffered :class:`_engine.Result`, or the :meth:`_asyncio.AsyncConnection.stream` method -to deliver a streaming server-side :class:`_asyncio.AsyncResult`:: - - import asyncio - - from sqlalchemy import Column - from sqlalchemy import MetaData - from sqlalchemy import select - from sqlalchemy import String - from sqlalchemy import Table - from sqlalchemy.ext.asyncio import create_async_engine - - meta = MetaData() - t1 = Table("t1", meta, Column("name", String(50), primary_key=True)) - - - async def async_main() -> None: - engine = create_async_engine( - "postgresql+asyncpg://scott:tiger@localhost/test", - echo=True, - ) - - async with engine.begin() as conn: - await conn.run_sync(meta.create_all) - - await conn.execute( - t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}] - ) - - async with engine.connect() as conn: - # select a Result, which will be delivered with buffered - # results - result = await conn.execute(select(t1).where(t1.c.name == "some name 1")) - - print(result.fetchall()) - - # for AsyncEngine created in function scope, close and - # clean-up pooled connections - await engine.dispose() - - - asyncio.run(async_main()) +to deliver a streaming server-side :class:`_asyncio.AsyncResult`: + +.. sourcecode:: pycon+sql + + >>> import asyncio + + >>> from sqlalchemy import Column + >>> from sqlalchemy import MetaData + >>> from sqlalchemy import select + >>> from sqlalchemy import String + >>> from sqlalchemy import Table + >>> from sqlalchemy.ext.asyncio import create_async_engine + + >>> meta = MetaData() + >>> t1 = Table("t1", meta, Column("name", String(50), primary_key=True)) + + + >>> async def async_main() -> None: + ... engine = create_async_engine("sqlite+aiosqlite://", echo=True) + ... + ... async with engine.begin() as conn: + ... await conn.run_sync(meta.drop_all) + ... await conn.run_sync(meta.create_all) + ... + ... await conn.execute( + ... t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}] + ... ) + ... + ... async with engine.connect() as conn: + ... # select a Result, which will be delivered with buffered + ... # results + ... result = await conn.execute(select(t1).where(t1.c.name == "some name 1")) + ... + ... print(result.fetchall()) + ... + ... # for AsyncEngine created in function scope, close and + ... # clean-up pooled connections + ... await engine.dispose() + + + >>> asyncio.run(async_main()) + {execsql}BEGIN (implicit) + ... + CREATE TABLE t1 ( + name VARCHAR(50) NOT NULL, + PRIMARY KEY (name) + ) + ... + INSERT INTO t1 (name) VALUES (?) + [...] [('some name 1',), ('some name 2',)] + COMMIT + BEGIN (implicit) + SELECT t1.name + FROM t1 + WHERE t1.name = ? + [...] ('some name 1',) + [('some name 1',)] + ROLLBACK Above, the :meth:`_asyncio.AsyncConnection.run_sync` method may be used to invoke special DDL functions such as :meth:`_schema.MetaData.create_all` that @@ -154,114 +171,165 @@ this. :ref:`asyncio_concurrency` and :ref:`session_faq_threadsafe` for background. The example below illustrates a complete example including mapper and session -configuration:: - - from __future__ import annotations - - import asyncio - import datetime - from typing import List - - from sqlalchemy import ForeignKey - from sqlalchemy import func - from sqlalchemy import select - from sqlalchemy.ext.asyncio import AsyncAttrs - from sqlalchemy.ext.asyncio import async_sessionmaker - from sqlalchemy.ext.asyncio import AsyncSession - from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import mapped_column - from sqlalchemy.orm import relationship - from sqlalchemy.orm import selectinload - - - class Base(AsyncAttrs, DeclarativeBase): - pass - - - class A(Base): - __tablename__ = "a" - - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[str] - create_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now()) - bs: Mapped[List[B]] = relationship() - - - class B(Base): - __tablename__ = "b" - id: Mapped[int] = mapped_column(primary_key=True) - a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) - data: Mapped[str] - - - async def insert_objects(async_session: async_sessionmaker[AsyncSession]) -> None: - async with async_session() as session: - async with session.begin(): - session.add_all( - [ - A(bs=[B(data="b1"), B(data="b2")], data="a1"), - A(bs=[], data="a2"), - A(bs=[B(data="b3"), B(data="b4")], data="a3"), - ] - ) - - - async def select_and_update_objects( - async_session: async_sessionmaker[AsyncSession], - ) -> None: - async with async_session() as session: - stmt = select(A).options(selectinload(A.bs)) - - result = await session.execute(stmt) - - for a in result.scalars(): - print(a) - print(f"created at: {a.create_date}") - for b in a.bs: - print(b, b.data) - - result = await session.execute(select(A).order_by(A.id).limit(1)) - - a1 = result.scalars().one() - - a1.data = "new data" - - await session.commit() - - # access attribute subsequent to commit; this is what - # expire_on_commit=False allows - print(a1.data) - - # alternatively, AsyncAttrs may be used to access any attribute - # as an awaitable (new in 2.0.13) - for b1 in await a1.awaitable_attrs.bs: - print(b1, b1.data) - - - async def async_main() -> None: - engine = create_async_engine( - "postgresql+asyncpg://scott:tiger@localhost/test", - echo=True, - ) - - # async_sessionmaker: a factory for new AsyncSession objects. - # expire_on_commit - don't expire objects after transaction commit - async_session = async_sessionmaker(engine, expire_on_commit=False) - - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - await insert_objects(async_session) - await select_and_update_objects(async_session) - - # for AsyncEngine created in function scope, close and - # clean-up pooled connections - await engine.dispose() - - - asyncio.run(async_main()) +configuration: + +.. sourcecode:: pycon+sql + + >>> from __future__ import annotations + + >>> import asyncio + >>> import datetime + >>> from typing import List + + >>> from sqlalchemy import ForeignKey + >>> from sqlalchemy import func + >>> from sqlalchemy import select + >>> from sqlalchemy.ext.asyncio import AsyncAttrs + >>> from sqlalchemy.ext.asyncio import async_sessionmaker + >>> from sqlalchemy.ext.asyncio import AsyncSession + >>> from sqlalchemy.ext.asyncio import create_async_engine + >>> from sqlalchemy.orm import DeclarativeBase + >>> from sqlalchemy.orm import Mapped + >>> from sqlalchemy.orm import mapped_column + >>> from sqlalchemy.orm import relationship + >>> from sqlalchemy.orm import selectinload + + + >>> class Base(AsyncAttrs, DeclarativeBase): + ... pass + + >>> class B(Base): + ... __tablename__ = "b" + ... + ... id: Mapped[int] = mapped_column(primary_key=True) + ... a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + ... data: Mapped[str] + + >>> class A(Base): + ... __tablename__ = "a" + ... + ... id: Mapped[int] = mapped_column(primary_key=True) + ... data: Mapped[str] + ... create_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now()) + ... bs: Mapped[List[B]] = relationship() + + >>> async def insert_objects(async_session: async_sessionmaker[AsyncSession]) -> None: + ... async with async_session() as session: + ... async with session.begin(): + ... session.add_all( + ... [ + ... A(bs=[B(data="b1"), B(data="b2")], data="a1"), + ... A(bs=[], data="a2"), + ... A(bs=[B(data="b3"), B(data="b4")], data="a3"), + ... ] + ... ) + + + >>> async def select_and_update_objects( + ... async_session: async_sessionmaker[AsyncSession], + ... ) -> None: + ... async with async_session() as session: + ... stmt = select(A).order_by(A.id).options(selectinload(A.bs)) + ... + ... result = await session.execute(stmt) + ... + ... for a in result.scalars(): + ... print(a, a.data) + ... print(f"created at: {a.create_date}") + ... for b in a.bs: + ... print(b, b.data) + ... + ... result = await session.execute(select(A).order_by(A.id).limit(1)) + ... + ... a1 = result.scalars().one() + ... + ... a1.data = "new data" + ... + ... await session.commit() + ... + ... # access attribute subsequent to commit; this is what + ... # expire_on_commit=False allows + ... print(a1.data) + ... + ... # alternatively, AsyncAttrs may be used to access any attribute + ... # as an awaitable (new in 2.0.13) + ... for b1 in await a1.awaitable_attrs.bs: + ... print(b1, b1.data) + + + >>> async def async_main() -> None: + ... engine = create_async_engine("sqlite+aiosqlite://", echo=True) + ... + ... # async_sessionmaker: a factory for new AsyncSession objects. + ... # expire_on_commit - don't expire objects after transaction commit + ... async_session = async_sessionmaker(engine, expire_on_commit=False) + ... + ... async with engine.begin() as conn: + ... await conn.run_sync(Base.metadata.create_all) + ... + ... await insert_objects(async_session) + ... await select_and_update_objects(async_session) + ... + ... # for AsyncEngine created in function scope, close and + ... # clean-up pooled connections + ... await engine.dispose() + + + >>> asyncio.run(async_main()) + {execsql}BEGIN (implicit) + ... + CREATE TABLE a ( + id INTEGER NOT NULL, + data VARCHAR NOT NULL, + create_date DATETIME DEFAULT (CURRENT_TIMESTAMP) NOT NULL, + PRIMARY KEY (id) + ) + ... + CREATE TABLE b ( + id INTEGER NOT NULL, + a_id INTEGER NOT NULL, + data VARCHAR NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY(a_id) REFERENCES a (id) + ) + ... + COMMIT + BEGIN (implicit) + INSERT INTO a (data) VALUES (?) RETURNING id, create_date + [...] ('a1',) + ... + INSERT INTO b (a_id, data) VALUES (?, ?) RETURNING id + [...] (1, 'b2') + ... + COMMIT + BEGIN (implicit) + SELECT a.id, a.data, a.create_date + FROM a ORDER BY a.id + [...] () + SELECT b.a_id AS b_a_id, b.id AS b_id, b.data AS b_data + FROM b + WHERE b.a_id IN (?, ?, ?) + [...] (1, 2, 3) + a1 + created at: ... + b1 + b2 + a2 + created at: ... + a3 + created at: ... + b3 + b4 + SELECT a.id, a.data, a.create_date + FROM a ORDER BY a.id + LIMIT ? OFFSET ? + [...] (1, 0) + UPDATE a SET data=? WHERE a.id = ? + [...] ('new data', 1) + COMMIT + new data + b1 + b2 In the example above, the :class:`_asyncio.AsyncSession` is instantiated using the optional :class:`_asyncio.async_sessionmaker` helper, which provides diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py index b920f25f0a..7543b1c100 100644 --- a/test/base/test_tutorials.py +++ b/test/base/test_tutorials.py @@ -1,14 +1,17 @@ from __future__ import annotations +import asyncio import doctest import logging import os import re import sys +from sqlalchemy.engine.url import make_url from sqlalchemy.testing import config from sqlalchemy.testing import fixtures from sqlalchemy.testing import requires +from sqlalchemy.testing import skip_test class DocTest(fixtures.TestBase): @@ -65,12 +68,9 @@ class DocTest(fixtures.TestBase): doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.IGNORE_EXCEPTION_DETAIL - | _get_allow_unicode_flag() ) runner = doctest.DocTestRunner( - verbose=None, - optionflags=optionflags, - checker=_get_unicode_checker(), + verbose=config.options.verbose >= 2, optionflags=optionflags ) parser = doctest.DocTestParser() globs = {"print_function": print} @@ -163,90 +163,39 @@ class DocTest(fixtures.TestBase): ) def test_orm_queryguide_inheritance(self): - self._run_doctest( - "orm/queryguide/inheritance.rst", - ) + self._run_doctest("orm/queryguide/inheritance.rst") @requires.update_from def test_orm_queryguide_dml(self): - self._run_doctest( - "orm/queryguide/dml.rst", - ) + self._run_doctest("orm/queryguide/dml.rst") def test_orm_large_collections(self): - self._run_doctest( - "orm/large_collections.rst", - ) + self._run_doctest("orm/large_collections.rst") def test_orm_queryguide_columns(self): - self._run_doctest( - "orm/queryguide/columns.rst", - ) + self._run_doctest("orm/queryguide/columns.rst") def test_orm_quickstart(self): self._run_doctest("orm/quickstart.rst") - -# unicode checker courtesy pytest - - -def _get_unicode_checker(): - """ - Returns a doctest.OutputChecker subclass that takes in account the - ALLOW_UNICODE option to ignore u'' prefixes in strings. Useful - when the same doctest should run in Python 2 and Python 3. - - An inner class is used to avoid importing "doctest" at the module - level. - """ - if hasattr(_get_unicode_checker, "UnicodeOutputChecker"): - return _get_unicode_checker.UnicodeOutputChecker() - - import doctest - import re - - class UnicodeOutputChecker(doctest.OutputChecker): - """ - Copied from doctest_nose_plugin.py from the nltk project: - https://github.com/nltk/nltk - """ - - _literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE) - - def check_output(self, want, got, optionflags): - res = doctest.OutputChecker.check_output( - self, want, got, optionflags - ) - if res: - return True - - if not (optionflags & _get_allow_unicode_flag()): - return False - - else: # pragma: no cover - # the code below will end up executed only in Python 2 in - # our tests, and our coverage check runs in Python 3 only - def remove_u_prefixes(txt): - return re.sub(self._literal_re, r"\1\2", txt) - - want = remove_u_prefixes(want) - got = remove_u_prefixes(got) - res = doctest.OutputChecker.check_output( - self, want, got, optionflags - ) - return res - - _get_unicode_checker.UnicodeOutputChecker = UnicodeOutputChecker - return _get_unicode_checker.UnicodeOutputChecker() - - -def _get_allow_unicode_flag(): - """ - Registers and returns the ALLOW_UNICODE flag. - """ - import doctest - - return doctest.register_optionflag("ALLOW_UNICODE") + @config.fixture(scope="class") + def restore_asyncio(self): + # NOTE: this is required since test_asyncio will remove the global + # loop. 2.1 uses runners that don't require this hack + yield + ep = asyncio.get_event_loop_policy() + try: + ep.get_event_loop() + except RuntimeError: + ep.set_event_loop(ep.new_event_loop()) + + @requires.greenlet + def test_asyncio(self, restore_asyncio): + try: + make_url("sqlite+aiosqlite://").get_dialect().import_dbapi() + except ImportError: + skip_test("missing aiosqile") + self._run_doctest("orm/extensions/asyncio.rst") # increase number to force pipeline run. 1 -- 2.47.2