]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add doctests to asyncio tutorial
authorFederico Caselli <cfederico87@gmail.com>
Mon, 26 Feb 2024 21:16:18 +0000 (22:16 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 11 Mar 2024 22:56:52 +0000 (23:56 +0100)
Change-Id: I28c94a7bc1e7ae572af0d206b8e63a110dc6fd7a
(cherry picked from commit e32954b91eef968be33ac4b46c16055daffa90dd)

doc/build/orm/extensions/asyncio.rst
test/base/test_tutorials.py

index 6649a9810374b0e5965dd76b1c8774ae96d3e333..fbd965d15d9fbb65a292289f2e20e4f08e81d914 100644 (file)
@@ -64,47 +64,64 @@ methods which both deliver asynchronous context managers.   The
 :class:`_asyncio.AsyncConnection` can then invoke statements using either the
 :meth:`_asyncio.AsyncConnection.execute` method to deliver a buffered
 :class:`_engine.Result`, or the :meth:`_asyncio.AsyncConnection.stream` method
-to deliver a streaming server-side :class:`_asyncio.AsyncResult`::
-
-    import asyncio
-
-    from sqlalchemy import Column
-    from sqlalchemy import MetaData
-    from sqlalchemy import select
-    from sqlalchemy import String
-    from sqlalchemy import Table
-    from sqlalchemy.ext.asyncio import create_async_engine
-
-    meta = MetaData()
-    t1 = Table("t1", meta, Column("name", String(50), primary_key=True))
-
-
-    async def async_main() -> None:
-        engine = create_async_engine(
-            "postgresql+asyncpg://scott:tiger@localhost/test",
-            echo=True,
-        )
-
-        async with engine.begin() as conn:
-            await conn.run_sync(meta.create_all)
-
-            await conn.execute(
-                t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}]
-            )
-
-        async with engine.connect() as conn:
-            # select a Result, which will be delivered with buffered
-            # results
-            result = await conn.execute(select(t1).where(t1.c.name == "some name 1"))
-
-            print(result.fetchall())
-
-        # for AsyncEngine created in function scope, close and
-        # clean-up pooled connections
-        await engine.dispose()
-
-
-    asyncio.run(async_main())
+to deliver a streaming server-side :class:`_asyncio.AsyncResult`:
+
+.. sourcecode:: pycon+sql
+
+    >>> import asyncio
+
+    >>> from sqlalchemy import Column
+    >>> from sqlalchemy import MetaData
+    >>> from sqlalchemy import select
+    >>> from sqlalchemy import String
+    >>> from sqlalchemy import Table
+    >>> from sqlalchemy.ext.asyncio import create_async_engine
+
+    >>> meta = MetaData()
+    >>> t1 = Table("t1", meta, Column("name", String(50), primary_key=True))
+
+
+    >>> async def async_main() -> None:
+    ...     engine = create_async_engine("sqlite+aiosqlite://", echo=True)
+    ...
+    ...     async with engine.begin() as conn:
+    ...         await conn.run_sync(meta.drop_all)
+    ...         await conn.run_sync(meta.create_all)
+    ...
+    ...         await conn.execute(
+    ...             t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}]
+    ...         )
+    ...
+    ...     async with engine.connect() as conn:
+    ...         # select a Result, which will be delivered with buffered
+    ...         # results
+    ...         result = await conn.execute(select(t1).where(t1.c.name == "some name 1"))
+    ...
+    ...         print(result.fetchall())
+    ...
+    ...     # for AsyncEngine created in function scope, close and
+    ...     # clean-up pooled connections
+    ...     await engine.dispose()
+
+
+    >>> asyncio.run(async_main())
+    {execsql}BEGIN (implicit)
+    ...
+    CREATE TABLE t1 (
+        name VARCHAR(50) NOT NULL,
+        PRIMARY KEY (name)
+    )
+    ...
+    INSERT INTO t1 (name) VALUES (?)
+    [...] [('some name 1',), ('some name 2',)]
+    COMMIT
+    BEGIN (implicit)
+    SELECT t1.name
+    FROM t1
+    WHERE t1.name = ?
+    [...] ('some name 1',)
+    [('some name 1',)]
+    ROLLBACK
 
 Above, the :meth:`_asyncio.AsyncConnection.run_sync` method may be used to
 invoke special DDL functions such as :meth:`_schema.MetaData.create_all` that
@@ -154,114 +171,165 @@ this.
     :ref:`asyncio_concurrency` and :ref:`session_faq_threadsafe` for background.
 
 The example below illustrates a complete example including mapper and session
-configuration::
-
-    from __future__ import annotations
-
-    import asyncio
-    import datetime
-    from typing import List
-
-    from sqlalchemy import ForeignKey
-    from sqlalchemy import func
-    from sqlalchemy import select
-    from sqlalchemy.ext.asyncio import AsyncAttrs
-    from sqlalchemy.ext.asyncio import async_sessionmaker
-    from sqlalchemy.ext.asyncio import AsyncSession
-    from sqlalchemy.ext.asyncio import create_async_engine
-    from sqlalchemy.orm import DeclarativeBase
-    from sqlalchemy.orm import Mapped
-    from sqlalchemy.orm import mapped_column
-    from sqlalchemy.orm import relationship
-    from sqlalchemy.orm import selectinload
-
-
-    class Base(AsyncAttrs, DeclarativeBase):
-        pass
-
-
-    class A(Base):
-        __tablename__ = "a"
-
-        id: Mapped[int] = mapped_column(primary_key=True)
-        data: Mapped[str]
-        create_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now())
-        bs: Mapped[List[B]] = relationship()
-
-
-    class B(Base):
-        __tablename__ = "b"
-        id: Mapped[int] = mapped_column(primary_key=True)
-        a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
-        data: Mapped[str]
-
-
-    async def insert_objects(async_session: async_sessionmaker[AsyncSession]) -> None:
-        async with async_session() as session:
-            async with session.begin():
-                session.add_all(
-                    [
-                        A(bs=[B(data="b1"), B(data="b2")], data="a1"),
-                        A(bs=[], data="a2"),
-                        A(bs=[B(data="b3"), B(data="b4")], data="a3"),
-                    ]
-                )
-
-
-    async def select_and_update_objects(
-        async_session: async_sessionmaker[AsyncSession],
-    ) -> None:
-        async with async_session() as session:
-            stmt = select(A).options(selectinload(A.bs))
-
-            result = await session.execute(stmt)
-
-            for a in result.scalars():
-                print(a)
-                print(f"created at: {a.create_date}")
-                for b in a.bs:
-                    print(b, b.data)
-
-            result = await session.execute(select(A).order_by(A.id).limit(1))
-
-            a1 = result.scalars().one()
-
-            a1.data = "new data"
-
-            await session.commit()
-
-            # access attribute subsequent to commit; this is what
-            # expire_on_commit=False allows
-            print(a1.data)
-
-            # alternatively, AsyncAttrs may be used to access any attribute
-            # as an awaitable (new in 2.0.13)
-            for b1 in await a1.awaitable_attrs.bs:
-                print(b1, b1.data)
-
-
-    async def async_main() -> None:
-        engine = create_async_engine(
-            "postgresql+asyncpg://scott:tiger@localhost/test",
-            echo=True,
-        )
-
-        # async_sessionmaker: a factory for new AsyncSession objects.
-        # expire_on_commit - don't expire objects after transaction commit
-        async_session = async_sessionmaker(engine, expire_on_commit=False)
-
-        async with engine.begin() as conn:
-            await conn.run_sync(Base.metadata.create_all)
-
-        await insert_objects(async_session)
-        await select_and_update_objects(async_session)
-
-        # for AsyncEngine created in function scope, close and
-        # clean-up pooled connections
-        await engine.dispose()
-
-
-    asyncio.run(async_main())
+configuration:
+
+.. sourcecode:: pycon+sql
+
+    >>> from __future__ import annotations
+
+    >>> import asyncio
+    >>> import datetime
+    >>> from typing import List
+
+    >>> from sqlalchemy import ForeignKey
+    >>> from sqlalchemy import func
+    >>> from sqlalchemy import select
+    >>> from sqlalchemy.ext.asyncio import AsyncAttrs
+    >>> from sqlalchemy.ext.asyncio import async_sessionmaker
+    >>> from sqlalchemy.ext.asyncio import AsyncSession
+    >>> from sqlalchemy.ext.asyncio import create_async_engine
+    >>> from sqlalchemy.orm import DeclarativeBase
+    >>> from sqlalchemy.orm import Mapped
+    >>> from sqlalchemy.orm import mapped_column
+    >>> from sqlalchemy.orm import relationship
+    >>> from sqlalchemy.orm import selectinload
+
+
+    >>> class Base(AsyncAttrs, DeclarativeBase):
+    ...     pass
+
+    >>> class B(Base):
+    ...     __tablename__ = "b"
+    ...
+    ...     id: Mapped[int] = mapped_column(primary_key=True)
+    ...     a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+    ...     data: Mapped[str]
+
+    >>> class A(Base):
+    ...     __tablename__ = "a"
+    ...
+    ...     id: Mapped[int] = mapped_column(primary_key=True)
+    ...     data: Mapped[str]
+    ...     create_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now())
+    ...     bs: Mapped[List[B]] = relationship()
+
+    >>> async def insert_objects(async_session: async_sessionmaker[AsyncSession]) -> None:
+    ...     async with async_session() as session:
+    ...         async with session.begin():
+    ...             session.add_all(
+    ...                 [
+    ...                     A(bs=[B(data="b1"), B(data="b2")], data="a1"),
+    ...                     A(bs=[], data="a2"),
+    ...                     A(bs=[B(data="b3"), B(data="b4")], data="a3"),
+    ...                 ]
+    ...             )
+
+
+    >>> async def select_and_update_objects(
+    ...     async_session: async_sessionmaker[AsyncSession],
+    ... ) -> None:
+    ...     async with async_session() as session:
+    ...         stmt = select(A).order_by(A.id).options(selectinload(A.bs))
+    ...
+    ...         result = await session.execute(stmt)
+    ...
+    ...         for a in result.scalars():
+    ...             print(a, a.data)
+    ...             print(f"created at: {a.create_date}")
+    ...             for b in a.bs:
+    ...                 print(b, b.data)
+    ...
+    ...         result = await session.execute(select(A).order_by(A.id).limit(1))
+    ...
+    ...         a1 = result.scalars().one()
+    ...
+    ...         a1.data = "new data"
+    ...
+    ...         await session.commit()
+    ...
+    ...         # access attribute subsequent to commit; this is what
+    ...         # expire_on_commit=False allows
+    ...         print(a1.data)
+    ...
+    ...         # alternatively, AsyncAttrs may be used to access any attribute
+    ...         # as an awaitable (new in 2.0.13)
+    ...         for b1 in await a1.awaitable_attrs.bs:
+    ...             print(b1, b1.data)
+
+
+    >>> async def async_main() -> None:
+    ...     engine = create_async_engine("sqlite+aiosqlite://", echo=True)
+    ...
+    ...     # async_sessionmaker: a factory for new AsyncSession objects.
+    ...     # expire_on_commit - don't expire objects after transaction commit
+    ...     async_session = async_sessionmaker(engine, expire_on_commit=False)
+    ...
+    ...     async with engine.begin() as conn:
+    ...         await conn.run_sync(Base.metadata.create_all)
+    ...
+    ...     await insert_objects(async_session)
+    ...     await select_and_update_objects(async_session)
+    ...
+    ...     # for AsyncEngine created in function scope, close and
+    ...     # clean-up pooled connections
+    ...     await engine.dispose()
+
+
+    >>> asyncio.run(async_main())
+    {execsql}BEGIN (implicit)
+    ...
+    CREATE TABLE a (
+        id INTEGER NOT NULL,
+        data VARCHAR NOT NULL,
+        create_date DATETIME DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
+        PRIMARY KEY (id)
+    )
+    ...
+    CREATE TABLE b (
+        id INTEGER NOT NULL,
+        a_id INTEGER NOT NULL,
+        data VARCHAR NOT NULL,
+        PRIMARY KEY (id),
+        FOREIGN KEY(a_id) REFERENCES a (id)
+    )
+    ...
+    COMMIT
+    BEGIN (implicit)
+    INSERT INTO a (data) VALUES (?) RETURNING id, create_date
+    [...] ('a1',)
+    ...
+    INSERT INTO b (a_id, data) VALUES (?, ?) RETURNING id
+    [...] (1, 'b2')
+    ...
+    COMMIT
+    BEGIN (implicit)
+    SELECT a.id, a.data, a.create_date
+    FROM a ORDER BY a.id
+    [...] ()
+    SELECT b.a_id AS b_a_id, b.id AS b_id, b.data AS b_data
+    FROM b
+    WHERE b.a_id IN (?, ?, ?)
+    [...] (1, 2, 3)
+    <A object at ...> a1
+    created at: ...
+    <B object at ...> b1
+    <B object at ...> b2
+    <A object at ...> a2
+    created at: ...
+    <A object at ...> a3
+    created at: ...
+    <B object at ...> b3
+    <B object at ...> b4
+    SELECT a.id, a.data, a.create_date
+    FROM a ORDER BY a.id
+    LIMIT ? OFFSET ?
+    [...] (1, 0)
+    UPDATE a SET data=? WHERE a.id = ?
+    [...] ('new data', 1)
+    COMMIT
+    new data
+    <B object at ...> b1
+    <B object at ...> b2
 
 In the example above, the :class:`_asyncio.AsyncSession` is instantiated using
 the optional :class:`_asyncio.async_sessionmaker` helper, which provides
index b920f25f0a5d325172492fe1a536e42e0d3b5087..7543b1c100c25f87663763b85df0032e029f63aa 100644 (file)
@@ -1,14 +1,17 @@
 from __future__ import annotations
 
+import asyncio
 import doctest
 import logging
 import os
 import re
 import sys
 
+from sqlalchemy.engine.url import make_url
 from sqlalchemy.testing import config
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import requires
+from sqlalchemy.testing import skip_test
 
 
 class DocTest(fixtures.TestBase):
@@ -65,12 +68,9 @@ class DocTest(fixtures.TestBase):
             doctest.ELLIPSIS
             | doctest.NORMALIZE_WHITESPACE
             | doctest.IGNORE_EXCEPTION_DETAIL
-            | _get_allow_unicode_flag()
         )
         runner = doctest.DocTestRunner(
-            verbose=None,
-            optionflags=optionflags,
-            checker=_get_unicode_checker(),
+            verbose=config.options.verbose >= 2, optionflags=optionflags
         )
         parser = doctest.DocTestParser()
         globs = {"print_function": print}
@@ -163,90 +163,39 @@ class DocTest(fixtures.TestBase):
         )
 
     def test_orm_queryguide_inheritance(self):
-        self._run_doctest(
-            "orm/queryguide/inheritance.rst",
-        )
+        self._run_doctest("orm/queryguide/inheritance.rst")
 
     @requires.update_from
     def test_orm_queryguide_dml(self):
-        self._run_doctest(
-            "orm/queryguide/dml.rst",
-        )
+        self._run_doctest("orm/queryguide/dml.rst")
 
     def test_orm_large_collections(self):
-        self._run_doctest(
-            "orm/large_collections.rst",
-        )
+        self._run_doctest("orm/large_collections.rst")
 
     def test_orm_queryguide_columns(self):
-        self._run_doctest(
-            "orm/queryguide/columns.rst",
-        )
+        self._run_doctest("orm/queryguide/columns.rst")
 
     def test_orm_quickstart(self):
         self._run_doctest("orm/quickstart.rst")
 
-
-# unicode checker courtesy pytest
-
-
-def _get_unicode_checker():
-    """
-    Returns a doctest.OutputChecker subclass that takes in account the
-    ALLOW_UNICODE option to ignore u'' prefixes in strings. Useful
-    when the same doctest should run in Python 2 and Python 3.
-
-    An inner class is used to avoid importing "doctest" at the module
-    level.
-    """
-    if hasattr(_get_unicode_checker, "UnicodeOutputChecker"):
-        return _get_unicode_checker.UnicodeOutputChecker()
-
-    import doctest
-    import re
-
-    class UnicodeOutputChecker(doctest.OutputChecker):
-        """
-        Copied from doctest_nose_plugin.py from the nltk project:
-            https://github.com/nltk/nltk
-        """
-
-        _literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE)
-
-        def check_output(self, want, got, optionflags):
-            res = doctest.OutputChecker.check_output(
-                self, want, got, optionflags
-            )
-            if res:
-                return True
-
-            if not (optionflags & _get_allow_unicode_flag()):
-                return False
-
-            else:  # pragma: no cover
-                # the code below will end up executed only in Python 2 in
-                # our tests, and our coverage check runs in Python 3 only
-                def remove_u_prefixes(txt):
-                    return re.sub(self._literal_re, r"\1\2", txt)
-
-                want = remove_u_prefixes(want)
-                got = remove_u_prefixes(got)
-                res = doctest.OutputChecker.check_output(
-                    self, want, got, optionflags
-                )
-                return res
-
-    _get_unicode_checker.UnicodeOutputChecker = UnicodeOutputChecker
-    return _get_unicode_checker.UnicodeOutputChecker()
-
-
-def _get_allow_unicode_flag():
-    """
-    Registers and returns the ALLOW_UNICODE flag.
-    """
-    import doctest
-
-    return doctest.register_optionflag("ALLOW_UNICODE")
+    @config.fixture(scope="class")
+    def restore_asyncio(self):
+        # NOTE: this is required since test_asyncio will remove the global
+        # loop. 2.1 uses runners that don't require this hack
+        yield
+        ep = asyncio.get_event_loop_policy()
+        try:
+            ep.get_event_loop()
+        except RuntimeError:
+            ep.set_event_loop(ep.new_event_loop())
+
+    @requires.greenlet
+    def test_asyncio(self, restore_asyncio):
+        try:
+            make_url("sqlite+aiosqlite://").get_dialect().import_dbapi()
+        except ImportError:
+            skip_test("missing aiosqile")
+        self._run_doctest("orm/extensions/asyncio.rst")
 
 
 # increase number to force pipeline run. 1