From: Mike Bayer Date: Sat, 4 Jul 2020 16:21:36 +0000 (-0400) Subject: Implement rudimentary asyncio support w/ asyncpg X-Git-Tag: rel_1_4_0b1~178^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5fb0138a3220161703e6ab1087319a669d14e7f4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement rudimentary asyncio support w/ asyncpg Using the approach introduced at https://gist.github.com/zzzeek/6287e28054d3baddc07fa21a7227904e We can now create asyncio endpoints that are then handled in "implicit IO" form within the majority of the Core internals. Then coroutines are re-exposed at the point at which we call into asyncpg methods. Patch includes: * asyncpg dialect * asyncio package * engine, result, ORM session classes * new test fixtures, tests * some work with pep-484 and a short plugin for the pyannotate package, which seems to have so-so results Change-Id: Idbcc0eff72c4cad572914acdd6f40ddb1aef1a7d Fixes: #3414 --- diff --git a/.gitignore b/.gitignore index 4931017b78..3916fe299b 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,5 @@ test/test_schema.db /.ipynb_checkpoints/ *.ipynb /querytest.db +/.mypy_cache +/.pytest_cache \ No newline at end of file diff --git a/doc/build/changelog/migration_14.rst b/doc/build/changelog/migration_14.rst index 5753cb089d..14584fd430 100644 --- a/doc/build/changelog/migration_14.rst +++ b/doc/build/changelog/migration_14.rst @@ -20,8 +20,8 @@ What's New in SQLAlchemy 1.4? For the current status of SQLAlchemy 2.0, see :ref:`migration_20_toplevel`. -Behavioral Changes - General -============================ +Major API changes and features - General +========================================= .. _change_5159: @@ -224,6 +224,92 @@ driven in order to support this new feature. :ticket:`4808` :ticket:`5004` +.. _change_3414: + +Asynchronous IO Support for Core and ORM +------------------------------------------ + +SQLAlchemy now supports Python ``asyncio``-compatible database drivers using an +all-new asyncio front-end interface to :class:`_engine.Connection` for Core +usage as well as :class:`_orm.Session` for ORM use, using the +:class:`_asyncio.AsyncConnection` and :class:`_asyncio.AsyncSession` objects. + +.. note:: The new asyncio feature should be considered **alpha level** for + the initial releases of SQLAlchemy 1.4. This is super new stuff that uses + some previously unfamiliar programming techniques. + +The initial database API supported is the :ref:`dialect-postgresql-asyncpg` +asyncio driver for PostgreSQL. + +The internal features of SQLAlchemy are fully integrated by making use of +the `greenlet `_ library in order +to adapt the flow of execution within SQLAlchemy's internals to propagate +asyncio ``await`` keywords outwards from the database driver to the end-user +API, which features ``async`` methods. Using this approach, the asyncpg +driver is fully operational within SQLAlchemy's own test suite and features +compatibility with most psycopg2 features. The approach was vetted and +improved upon by developers of the greenlet project for which SQLAlchemy +is appreciative. + +.. sidebar:: greenlets are good + + Don't confuse the greenlet_ library with event-based IO libraries that build + on top of it such as ``gevent`` and ``eventlet``; while the use of these + libraries with SQLAlchemy is common, SQLAlchemy's asyncio integration + **does not** make use of these event based systems in any way. The asyncio + API integrates with the user-provided event loop, typically Python's own + asyncio event loop, without the use of additional threads or event systems. + The approach involves a single greenlet context switch per ``await`` call, + and the extension which makes it possible is less than 20 lines of code. + +The user facing ``async`` API itself is focused around IO-oriented methods such +as :meth:`_asyncio.AsyncEngine.connect` and +:meth:`_asyncio.AsyncConnection.execute`. The new Core constructs strictly +support :term:`2.0 style` usage only; which means all statements must be +invoked given a connection object, in this case +:class:`_asyncio.AsyncConnection`. + +Within the ORM, :term:`2.0 style` query execution is +supported, using :func:`_sql.select` constructs in conjunction with +:meth:`_asyncio.AsyncSession.execute`; the legacy :class:`_orm.Query` +object itself is not supported by the :class:`_asyncio.AsyncSession` class. + +ORM features such as lazy loading of related attributes as well as unexpiry of +expired attributes are by definition disallowed in the traditional asyncio +programming model, as they indicate IO operations that would run implicitly +within the scope of a Python ``getattr()`` operation. To overcome this, the +**traditional** asyncio application should make judicious use of :ref:`eager +loading ` techniques as well as forego the use of features +such as :ref:`expire on commit ` so that such loads are not +needed. + +For the asyncio application developer who **chooses to break** with +tradition, the new API provides a **strictly optional +feature** such that applications that wish to make use of such ORM features +can opt to organize database-related code into functions which can then be +run within greenlets using the :meth:`_asyncio.AsyncSession.run_sync` +method. See the ``greenlet_orm.py`` example at :ref:`examples_asyncio` +for a demonstration. + +Support for asynchronous cursors is also provided using new methods +:meth:`_asyncio.AsyncConnection.stream` and +:meth:`_asyncio.AsyncSession.stream`, which support a new +:class:`_asyncio.AsyncResult` object that itself provides awaitable +versions of common methods like +:meth:`_asyncio.AsyncResult.all` and +:meth:`_asyncio.AsyncResult.fetchmany`. Both Core and ORM are integrated +with the feature which corresponds to the use of "server side cursors" +in traditional SQLAlchemy. + +.. seealso:: + + :ref:`asyncio_toplevel` + + :ref:`examples_asyncio` + + + +:ticket:`3414` .. _change_deferred_construction: diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index 535756f53e..7b3d23c8ca 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -1252,10 +1252,7 @@ Asyncio Support .. admonition:: Certainty: definite - A surprising development will allow asyncio support including with the - ORM to be fully implemented. There will even be a **completely optional** - path to having lazy loading be available, for those willing to make use of - some "controversial" patterns. + This is now implemented in 1.4. There was previously an entire section here detailing how asyncio is a nice to have, but not really necessary from a technical standpoint, there are some @@ -1267,113 +1264,7 @@ an entirely separate version of everything be maintained, therefore this makes it feasible to deliver this feature to those users who prefer an all-async application style without impact on the traditional blocking archictecture. -The proof of concept at https://gist.github.com/zzzeek/4e89ce6226826e7a8df13e1b573ad354 -illustrates how to write an asyncio application that makes use of a pure asyncio -driver (asyncpg), with part of the code **in between** remaining as sync code -without the use of any await/async keywords. The central technique involves -minimal use of a greenlet (e.g. stackless Python) to perform the necessary -context switches when an "await" occurs. The approach has been vetted -both with asyncio developers as well as greenlet developers, the latter -of which contributed a great degree of simplification the already simple recipe -such that can context switch async coroutines with no decrease in performance. - -The proof of concept has then been expanded to work within SQLAlchemy Core -and is presently in a Gerrit review. A SQLAlchemy dialect for the asyncpg -driver has been written and it passes most tests. - -Example ORM use will look similar to the following; this example is already -runnable with the in-review codebase:: - - import asyncio - - from sqlalchemy.asyncio import create_async_engine - from sqlalchemy.asyncio import AsyncSession - # ... other imports ... - - async def async_main(): - engine = create_async_engine( - "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, - ) - - - # assume a typical ORM model with classes A and B - - session = AsyncSession(engine) - session.add_all( - [ - A(bs=[B(), B()], data="a1"), - A(bs=[B()], data="a2"), - A(bs=[B(), B()], data="a3"), - ] - ) - await session.commit() - stmt = select(A).options(selectinload(A.bs)) - result = await session.execute(stmt) - for a1 in result.scalars(): - print(a1) - for b1 in a1.bs: - print(b1) - - result = await session.execute(select(A).order_by(A.id)) - - a1 = result.scalars().first() - a1.data = "new data" - await session.commit() - - asyncio.run(async_main()) - -The "controversial" feature, if provided, would include that the "greenlet" -context would be supplied as front-facing API. This would allow an asyncio -application to spawn a greenlet that contains sync-code, which could use the -Core and ORM in a fully traditional manner including that lazy loading -for columns and relationships would be present. This mode of use is -somewhat similar to running an application under an event-based -programming library such as gevent or eventlet, however the underyling -network calls would be within a pure asyncio context, i.e. like that of the -asyncpg driver. An example of this use, which is also runnable with -the in-review codebase:: - - import asyncio - - from sqlalchemy.asyncio import greenlet_spawn - - from sqlalchemy import create_engine - from sqlalchemy.orm import Session - # ... other imports ... - - def main(): - # standard "sync" engine with the "async" driver. - engine = create_engine( - "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, - ) - - # assume a typical ORM model with classes A and B - - session = Session(engine) - session.add_all( - [ - A(bs=[B(), B()], data="a1"), - A(bs=[B()], data="a2"), - A(bs=[B(), B()], data="a3"), - ] - ) - session.commit() - for a1 in session.query(A).all(): - print("a: %s" % a1) - print("bs: %s" % (a1.bs)) # emits a lazyload. - - asyncio.run(greenlet_spawn(main)) - - -Above, we see a ``main()`` function that contains within it a 100% normal -looking Python program using the SQLAlchemy ORM, using plain ORM imports and -basically absolutely nothing out of the ordinary. It just happens to be called -from inside of an ``asyncio.run()`` call rather than directly, and it uses a -DBAPI that is only compatible with asyncio. There is no "monkeypatching" or -anything else like that involved. Any asyncio program can opt -to place it's database-related business methods into the above pattern, -if preferred, rather than using the asyncio SQLAlchemy API directly. This -technique is also being adapted to other frameworks such as Flask and will -hopefully lead to greater interoperability between blocking and non-blocking -libraries and frameworks. +SQLAlchemy 1.4 now includes full asyncio capability with initial support +using the :ref:`dialect-postgresql-asyncpg` Python database driver; +see :ref:`asyncio_toplevel`. diff --git a/doc/build/changelog/unreleased_14/3414.rst b/doc/build/changelog/unreleased_14/3414.rst new file mode 100644 index 0000000000..a278244622 --- /dev/null +++ b/doc/build/changelog/unreleased_14/3414.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: feature, engine, orm + :tickets: 3414 + + SQLAlchemy now includes support for Python asyncio within both Core and + ORM, using the included :ref:`asyncio extension `. The + extension makes use of the `greenlet + `_ library in order to adapt + SQLAlchemy's sync-oriented internals such that an asyncio interface that + ultimately interacts with an asyncio database adapter is now feasible. The + single driver supported at the moment is the + :ref:`dialect-postgresql-asyncpg` driver for PostgreSQL. + + .. seealso:: + + :ref:`change_3414` + diff --git a/doc/build/conf.py b/doc/build/conf.py index 13d5732960..d4fdf58a00 100644 --- a/doc/build/conf.py +++ b/doc/build/conf.py @@ -106,6 +106,9 @@ autodocmods_convert_modname = { "sqlalchemy.engine.row": "sqlalchemy.engine", "sqlalchemy.engine.cursor": "sqlalchemy.engine", "sqlalchemy.engine.result": "sqlalchemy.engine", + "sqlalchemy.ext.asyncio.result": "sqlalchemy.ext.asyncio", + "sqlalchemy.ext.asyncio.engine": "sqlalchemy.ext.asyncio", + "sqlalchemy.ext.asyncio.session": "sqlalchemy.ext.asyncio", "sqlalchemy.util._collections": "sqlalchemy.util", "sqlalchemy.orm.relationships": "sqlalchemy.orm", "sqlalchemy.orm.interfaces": "sqlalchemy.orm", @@ -128,6 +131,7 @@ zzzeeksphinx_module_prefixes = { "_row": "sqlalchemy.engine", "_schema": "sqlalchemy.schema", "_types": "sqlalchemy.types", + "_asyncio": "sqlalchemy.ext.asyncio", "_expression": "sqlalchemy.sql.expression", "_sql": "sqlalchemy.sql.expression", "_dml": "sqlalchemy.sql.expression", diff --git a/doc/build/core/connections.rst b/doc/build/core/connections.rst index c6186cbaa3..b9605bb498 100644 --- a/doc/build/core/connections.rst +++ b/doc/build/core/connections.rst @@ -1225,18 +1225,9 @@ The above will respond to ``create_engine("mysql+foodialect://")`` and load the Connection / Engine API ======================= -.. autoclass:: BaseCursorResult - :members: - -.. autoclass:: ChunkedIteratorResult - :members: - .. autoclass:: Connection :members: -.. autoclass:: Connectable - :members: - .. autoclass:: CreateEnginePlugin :members: @@ -1246,6 +1237,25 @@ Connection / Engine API .. autoclass:: ExceptionContext :members: +.. autoclass:: NestedTransaction + :members: + +.. autoclass:: Transaction + :members: + +.. autoclass:: TwoPhaseTransaction + :members: + + +Result Set API +================= + +.. autoclass:: BaseCursorResult + :members: + +.. autoclass:: ChunkedIteratorResult + :members: + .. autoclass:: FrozenResult :members: @@ -1258,9 +1268,6 @@ Connection / Engine API .. autoclass:: MergedResult :members: -.. autoclass:: NestedTransaction - :members: - .. autoclass:: Result :members: :inherited-members: @@ -1291,9 +1298,3 @@ Connection / Engine API .. autoclass:: RowMapping :members: -.. autoclass:: Transaction - :members: - -.. autoclass:: TwoPhaseTransaction - :members: - diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index 35ed285eb2..6c36e58147 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -196,6 +196,13 @@ pg8000 .. automodule:: sqlalchemy.dialects.postgresql.pg8000 +.. _dialect-postgresql-asyncpg: + +asyncpg +------- + +.. automodule:: sqlalchemy.dialects.postgresql.asyncpg + psycopg2cffi ------------ diff --git a/doc/build/index.rst b/doc/build/index.rst index 6afef50833..bee062f89d 100644 --- a/doc/build/index.rst +++ b/doc/build/index.rst @@ -44,7 +44,8 @@ of Python objects, proceed first to the tutorial. * **ORM Usage:** :doc:`Session Usage and Guidelines ` | - :doc:`Loading Objects ` + :doc:`Loading Objects ` | + :doc:`AsyncIO Support ` * **Extending the ORM:** :doc:`ORM Events and Internals ` @@ -68,6 +69,7 @@ are documented here. In contrast to the ORM's domain-centric mode of usage, the * **Engines, Connections, Pools:** :doc:`Engine Configuration ` | :doc:`Connections, Transactions ` | + :doc:`AsyncIO Support ` | :doc:`Connection Pooling ` * **Schema Definition:** diff --git a/doc/build/intro.rst b/doc/build/intro.rst index 828ba31b31..4b9376ab0f 100644 --- a/doc/build/intro.rst +++ b/doc/build/intro.rst @@ -146,7 +146,6 @@ mechanism:: setuptools. - Installing a Database API ---------------------------------- diff --git a/doc/build/orm/examples.rst b/doc/build/orm/examples.rst index 7a79104b9b..10cafb2d2a 100644 --- a/doc/build/orm/examples.rst +++ b/doc/build/orm/examples.rst @@ -30,6 +30,13 @@ Associations .. automodule:: examples.association +.. _examples_asyncio: + +Asyncio Integration +------------------- + +.. automodule:: examples.asyncio + Directed Graphs --------------- diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst new file mode 100644 index 0000000000..388dee949b --- /dev/null +++ b/doc/build/orm/extensions/asyncio.rst @@ -0,0 +1,292 @@ +.. _asyncio_toplevel: + +asyncio +======= + +Support for Python asyncio. Support for Core and ORM usage is +included, using asyncio-compatible dialects. + +.. versionadded:: 1.4 + + +.. note:: The asyncio should be regarded as **alpha level** for the + 1.4 release of SQLAlchemy. API details are **subject to change** at + any time. + + +.. seealso:: + + :ref:`change_3414` - initial feature announcement + + :ref:`examples_asyncio` - example scripts illustrating working examples + of Core and ORM use within the asyncio extension. + +Synopsis - Core +--------------- + +For Core use, the :func:`_asyncio.create_async_engine` function creates an +instance of :class:`_asyncio.AsyncEngine` which then offers an async version of +the traditional :class:`_engine.Engine` API. The +:class:`_asyncio.AsyncEngine` delivers an :class:`_asyncio.AsyncConnection` via +its :meth:`_asyncio.AsyncEngine.connect` and :meth:`_asyncio.AsyncEngine.begin` +methods which both deliver asynchronous context managers. The +:class:`_asyncio.AsyncConnection` can then invoke statements using either the +:meth:`_asyncio.AsyncConnection.execute` method to deliver a buffered +:class:`_engine.Result`, or the :meth:`_asyncio.AsyncConnection.stream` method +to deliver a streaming server-side :class:`_asyncio.AsyncResult`:: + + import asyncio + + from sqlalchemy.ext.asyncio import create_async_engine + + async def async_main(): + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, + ) + + async with engine.begin() as conn: + await conn.run_sync(meta.drop_all) + await conn.run_sync(meta.create_all) + + await conn.execute( + t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}] + ) + + async with engine.connect() as conn: + + # select a Result, which will be delivered with buffered + # results + result = await conn.execute(select(t1).where(t1.c.name == "some name 1")) + + print(result.fetchall()) + + + asyncio.run(async_main()) + +Above, the :meth:`_asyncio.AsyncConnection.run_sync` method may be used to +invoke special DDL functions such as :meth:`_schema.MetaData.create_all` that +don't include an awaitable hook. + +The :class:`_asyncio.AsyncConnection` also features a "streaming" API via +the :meth:`_asyncio.AsyncConnection.stream` method that returns an +:class:`_asyncio.AsyncResult` object. This result object uses a server-side +cursor and provides an async/await API, such as an async iterator:: + + async with engine.connect() as conn: + async_result = await conn.stream(select(t1)) + + async for row in async_result: + print("row: %s" % (row, )) + + +Synopsis - ORM +--------------- + +Using :term:`2.0 style` querying, the :class:`_asyncio.AsyncSession` class +provides full ORM functionality. Within the default mode of use, special care +must be taken to avoid :term:`lazy loading` of ORM relationships and column +attributes, as below where the :func:`_orm.selectinload` eager loading strategy +is used to ensure the ``A.bs`` on each ``A`` object is loaded:: + + import asyncio + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import AsyncSession + + async def async_main(): + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + + async with AsyncSession(engine) as session: + async with session.begin(): + session.add_all( + [ + A(bs=[B(), B()], data="a1"), + A(bs=[B()], data="a2"), + A(bs=[B(), B()], data="a3"), + ] + ) + + stmt = select(A).options(selectinload(A.bs)) + + result = await session.execute(stmt) + + for a1 in result.scalars(): + print(a1) + for b1 in a1.bs: + print(b1) + + result = await session.execute(select(A).order_by(A.id)) + + a1 = result.scalars().first() + + a1.data = "new data" + + await session.commit() + + asyncio.run(async_main()) + +Above, the :func:`_orm.selectinload` eager loader is employed in order +to eagerly load the ``A.bs`` collection within the scope of the +``await session.execute()`` call. If the default loader strategy of +"lazyload" were left in place, the access of the ``A.bs`` attribute would +raise an asyncio exception. Using traditional asyncio, the application +needs to avoid any points at which IO-on-attribute access may occur. +This also includes that methods such as :meth:`_orm.Session.expire` should be +avoided in favor of :meth:`_asyncio.AsyncSession.refresh`, and that +appropriate loader options should be employed for :func:`_orm.deferred` +columns as well as for :func:`_orm.relationship` constructs. + +Adapting ORM Lazy loads to asyncio +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. deepalchemy:: This approach is essentially exposing publicly the + mechanism by which SQLAlchemy is able to provide the asyncio interface + in the first place. While there is no technical issue with doing so, overall + the approach can probably be considered "controversial" as it works against + some of the central philosophies of the asyncio programming model, which + is essentially that any programming statement that can potentially result + in IO being invoked **must** have an ``await`` call, lest the program + does not make it explicitly clear every line at which IO may occur. + This approach does not change that general idea, except that it allows + a series of synchronous IO instructions to be exempted from this rule + within the scope of a function call, essentially bundled up into a single + awaitable. + +As an alternative means of integrating traditional SQLAlchemy "lazy loading" +within an asyncio event loop, an **optional** method known as +:meth:`_asyncio.AsyncSession.run_sync` is provided which will run any +Python function inside of a greenlet, where traditional synchronous +programming concepts will be translated to use ``await`` when they reach the +database driver. A hypothetical approach here is an asyncio-oriented +application can package up database-related methods into functions that are +invoked using :meth:`_asyncio.AsyncSession.run_sync`. + +Altering the above example, if we didn't use :func:`_orm.selectinload` +for the ``A.bs`` collection, we could accomplish our treatment of these +attribute accesses within a separate function:: + + import asyncio + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import AsyncSession + + def fetch_and_update_objects(session): + """run traditional sync-style ORM code in a function that will be + invoked within an awaitable. + + """ + + # the session object here is a traditional ORM Session. + # all features are available here including legacy Query use. + + stmt = select(A) + + result = session.execute(stmt) + for a1 in result.scalars(): + print(a1) + + # lazy loads + for b1 in a1.bs: + print(b1) + + # legacy Query use + a1 = session.query(A).order_by(A.id).first() + + a1.data = "new data" + + + async def async_main(): + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + + async with AsyncSession(engine) as session: + async with session.begin(): + session.add_all( + [ + A(bs=[B(), B()], data="a1"), + A(bs=[B()], data="a2"), + A(bs=[B(), B()], data="a3"), + ] + ) + + session.run_sync(fetch_and_update_objects) + + await session.commit() + + asyncio.run(async_main()) + +The above approach of running certain functions within a "sync" runner +has some parallels to an application that runs a SQLAlchemy application +on top of an event-based programming library such as ``gevent``. The +differences are as follows: + +1. unlike when using ``gevent``, we can continue to use the standard Python + asyncio event loop, or any custom event loop, without the need to integrate + into the ``gevent`` event loop. + +2. There is no "monkeypatching" whatsoever. The above example makes use of + a real asyncio driver and the underlying SQLAlchemy connection pool is also + using the Python built-in ``asyncio.Queue`` for pooling connections. + +3. The program can freely switch between async/await code and contained + functions that use sync code with virtually no performance penalty. There + is no "thread executor" or any additional waiters or synchronization in use. + +4. The underlying network drivers are also using pure Python asyncio + concepts, no third party networking libraries as ``gevent`` and ``eventlet`` + provides are in use. + +.. currentmodule:: sqlalchemy.ext.asyncio + +Engine API Documentation +------------------------- + +.. autofunction:: create_async_engine + +.. autoclass:: AsyncEngine + :members: + +.. autoclass:: AsyncConnection + :members: + +.. autoclass:: AsyncTransaction + :members: + +Result Set API Documentation +---------------------------------- + +The :class:`_asyncio.AsyncResult` object is an async-adapted version of the +:class:`_result.Result` object. It is only returned when using the +:meth:`_asyncio.AsyncConnection.stream` or :meth:`_asyncio.AsyncSession.stream` +methods, which return a result object that is on top of an active database +cursor. + +.. autoclass:: AsyncResult + :members: + +.. autoclass:: AsyncScalarResult + :members: + +.. autoclass:: AsyncMappingResult + :members: + +ORM Session API Documentation +----------------------------- + +.. autoclass:: AsyncSession + :members: + +.. autoclass:: AsyncSessionTransaction + :members: + + + diff --git a/doc/build/orm/extensions/index.rst b/doc/build/orm/extensions/index.rst index e23fd55ee7..ba040b9f65 100644 --- a/doc/build/orm/extensions/index.rst +++ b/doc/build/orm/extensions/index.rst @@ -15,6 +15,7 @@ behavior. In particular the "Horizontal Sharding", "Hybrid Attributes", and .. toctree:: :maxdepth: 1 + asyncio associationproxy automap baked diff --git a/examples/asyncio/__init__.py b/examples/asyncio/__init__.py new file mode 100644 index 0000000000..c53120f54b --- /dev/null +++ b/examples/asyncio/__init__.py @@ -0,0 +1,6 @@ +""" +Examples illustrating the asyncio engine feature of SQLAlchemy. + +.. autosource:: + +""" diff --git a/examples/asyncio/async_orm.py b/examples/asyncio/async_orm.py new file mode 100644 index 0000000000..b1054a239f --- /dev/null +++ b/examples/asyncio/async_orm.py @@ -0,0 +1,89 @@ +"""Illustrates use of the sqlalchemy.ext.asyncio.AsyncSession object +for asynchronous ORM use. + +""" + +import asyncio + +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.future import select +from sqlalchemy.orm import relationship +from sqlalchemy.orm import selectinload + +Base = declarative_base() + + +class A(Base): + __tablename__ = "a" + + id = Column(Integer, primary_key=True) + data = Column(String) + bs = relationship("B") + + +class B(Base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey("a.id")) + data = Column(String) + + +async def async_main(): + """Main program function.""" + + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + + async with AsyncSession(engine) as session: + async with session.begin(): + session.add_all( + [ + A(bs=[B(), B()], data="a1"), + A(bs=[B()], data="a2"), + A(bs=[B(), B()], data="a3"), + ] + ) + + # for relationship loading, eager loading should be applied. + stmt = select(A).options(selectinload(A.bs)) + + # AsyncSession.execute() is used for 2.0 style ORM execution + # (same as the synchronous API). + result = await session.execute(stmt) + + # result is a buffered Result object. + for a1 in result.scalars(): + print(a1) + for b1 in a1.bs: + print(b1) + + # for streaming ORM results, AsyncSession.stream() may be used. + result = await session.stream(stmt) + + # result is a streaming AsyncResult object. + async for a1 in result.scalars(): + print(a1) + for b1 in a1.bs: + print(b1) + + result = await session.execute(select(A).order_by(A.id)) + + a1 = result.scalars().first() + + a1.data = "new data" + + await session.commit() + + +asyncio.run(async_main()) diff --git a/examples/asyncio/basic.py b/examples/asyncio/basic.py new file mode 100644 index 0000000000..05cdd8a05c --- /dev/null +++ b/examples/asyncio/basic.py @@ -0,0 +1,71 @@ +"""Illustrates the asyncio engine / connection interface. + +In this example, we have an async engine created by +:func:`_engine.create_async_engine`. We then use it using await +within a coroutine. + +""" + + +import asyncio + +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.ext.asyncio import create_async_engine + + +meta = MetaData() + +t1 = Table( + "t1", meta, Column("id", Integer, primary_key=True), Column("name", String) +) + + +async def async_main(): + # engine is an instance of AsyncEngine + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, + ) + + # conn is an instance of AsyncConnection + async with engine.begin() as conn: + + # to support SQLAlchemy DDL methods as well as legacy functions, the + # AsyncConnection.run_sync() awaitable method will pass a "sync" + # version of the AsyncConnection object to any synchronous method, + # where synchronous IO calls will be transparently translated for + # await. + await conn.run_sync(meta.drop_all) + await conn.run_sync(meta.create_all) + + # for normal statement execution, a traditional "await execute()" + # pattern is used. + await conn.execute( + t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}] + ) + + async with engine.connect() as conn: + + # the default result object is the + # sqlalchemy.engine.Result object + result = await conn.execute(t1.select()) + + # the results are buffered so no await call is necessary + # for this case. + print(result.fetchall()) + + # for a streaming result that buffers only segments of the + # result at time, the AsyncConnection.stream() method is used. + # this returns a sqlalchemy.ext.asyncio.AsyncResult object. + async_result = await conn.stream(t1.select()) + + # this object supports async iteration and awaitable + # versions of methods like .all(), fetchmany(), etc. + async for row in async_result: + print(row) + + +asyncio.run(async_main()) diff --git a/examples/asyncio/greenlet_orm.py b/examples/asyncio/greenlet_orm.py new file mode 100644 index 0000000000..e0b568c4b8 --- /dev/null +++ b/examples/asyncio/greenlet_orm.py @@ -0,0 +1,92 @@ +"""Illustrates use of the sqlalchemy.ext.asyncio.AsyncSession object +for asynchronous ORM use, including the optional run_sync() method. + + +""" + +import asyncio + +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.future import select +from sqlalchemy.orm import relationship + +Base = declarative_base() + + +class A(Base): + __tablename__ = "a" + + id = Column(Integer, primary_key=True) + data = Column(String) + bs = relationship("B") + + +class B(Base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey("a.id")) + data = Column(String) + + +def run_queries(session): + """A function written in "synchronous" style that will be invoked + within the asyncio event loop. + + The session object passed is a traditional orm.Session object with + synchronous interface. + + """ + + stmt = select(A) + + result = session.execute(stmt) + + for a1 in result.scalars(): + print(a1) + # lazy loads + for b1 in a1.bs: + print(b1) + + result = session.execute(select(A).order_by(A.id)) + + a1 = result.scalars().first() + + a1.data = "new data" + + +async def async_main(): + """Main program function.""" + + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/test", echo=True, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + + async with AsyncSession(engine) as session: + async with session.begin(): + session.add_all( + [ + A(bs=[B(), B()], data="a1"), + A(bs=[B()], data="a2"), + A(bs=[B(), B()], data="a3"), + ] + ) + + # we have the option to run a function written in sync style + # within the AsyncSession.run_sync() method. The function will + # be passed a synchronous-style Session object and the function + # can use traditional ORM patterns. + await session.run_sync(run_queries) + + await session.commit() + + +asyncio.run(async_main()) diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 06d22872a9..2762a9971b 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - from . import base from . import pg8000 # noqa from . import psycopg2 # noqa @@ -58,7 +57,10 @@ from .ranges import INT8RANGE from .ranges import NUMRANGE from .ranges import TSRANGE from .ranges import TSTZRANGE +from ...util import compat +if compat.py3k: + from . import asyncpg # noqa base.dialect = dialect = psycopg2.dialect diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py new file mode 100644 index 0000000000..515ef6e288 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -0,0 +1,786 @@ +# postgresql/asyncpg.py +# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: postgresql+asyncpg + :name: asyncpg + :dbapi: asyncpg + :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://magicstack.github.io/asyncpg/ + +The asyncpg dialect is SQLAlchemy's first Python asyncio dialect. + +Using a special asyncio mediation layer, the asyncpg dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") + +The dialect can also be run as a "synchronous" dialect within the +:func:`_sa.create_engine` function, which will pass "await" calls into +an ad-hoc event loop. This mode of operation is of **limited use** +and is for special testing scenarios only. The mode can be enabled by +adding the SQLAlchemy-specific flag ``async_fallback`` to the URL +in conjunction with :func:`_sa.craete_engine`:: + + # for testing purposes only; do not use in production! + engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true") + + +.. versionadded:: 1.4 + +""" # noqa + +import collections +import decimal +import itertools +import re + +from . import json +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL +from .base import OID +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .base import REGCLASS +from .base import UUID +from ... import exc +from ... import pool +from ... import processors +from ... import util +from ...sql import sqltypes +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +try: + from uuid import UUID as _python_UUID # noqa +except ImportError: + _python_UUID = None + + +class AsyncpgTime(sqltypes.Time): + def get_dbapi_type(self, dbapi): + return dbapi.TIME + + +class AsyncpgDate(sqltypes.Date): + def get_dbapi_type(self, dbapi): + return dbapi.DATE + + +class AsyncpgDateTime(sqltypes.DateTime): + def get_dbapi_type(self, dbapi): + if self.timezone: + return dbapi.TIMESTAMP_W_TZ + else: + return dbapi.TIMESTAMP + + +class AsyncpgBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.BOOLEAN + + +class AsyncPgInterval(INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + + return AsyncPgInterval(precision=interval.second_precision) + + +class AsyncPgEnum(ENUM): + def get_dbapi_type(self, dbapi): + return dbapi.ENUM + + +class AsyncpgInteger(sqltypes.Integer): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class AsyncpgBigInteger(sqltypes.BigInteger): + def get_dbapi_type(self, dbapi): + return dbapi.BIGINTEGER + + +class AsyncpgJSON(json.JSON): + def get_dbapi_type(self, dbapi): + return dbapi.JSON + + +class AsyncpgJSONB(json.JSONB): + def get_dbapi_type(self, dbapi): + return dbapi.JSONB + + +class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType): + def get_dbapi_type(self, dbapi): + raise NotImplementedError("should not be here") + + +class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class AsyncpgJSONPathType(json.JSONPathType): + def bind_processor(self, dialect): + def process(value): + assert isinstance(value, util.collections_abc.Sequence) + tokens = [util.text_type(elem) for elem in value] + return tokens + + return process + + +class AsyncpgUUID(UUID): + def get_dbapi_type(self, dbapi): + return dbapi.UUID + + def bind_processor(self, dialect): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = str(value) + return value + + return process + + +class AsyncpgNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class AsyncpgREGCLASS(REGCLASS): + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class AsyncpgOID(OID): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class PGExecutionContext_asyncpg(PGExecutionContext): + def pre_exec(self): + if self.isddl: + self._dbapi_connection.reset_schema_state() + + if not self.compiled: + return + + # we have to exclude ENUM because "enum" not really a "type" + # we can cast to, it has to be the name of the type itself. + # for now we just omit it from casting + self.set_input_sizes(exclude_types={AsyncAdapt_asyncpg_dbapi.ENUM}) + + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class PGCompiler_asyncpg(PGCompiler): + pass + + +class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): + pass + + +class AsyncAdapt_asyncpg_cursor: + __slots__ = ( + "_adapt_connection", + "_connection", + "_rows", + "description", + "arraysize", + "rowcount", + "_inputsizes", + "_cursor", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self._rows = [] + self._cursor = None + self.description = None + self.arraysize = 1 + self.rowcount = -1 + self._inputsizes = None + + def close(self): + self._rows[:] = [] + + def _handle_exception(self, error): + self._adapt_connection._handle_exception(error) + + def _parameters(self): + if not self._inputsizes: + return ("$%d" % idx for idx in itertools.count(1)) + else: + + return ( + "$%d::%s" % (idx, typ) if typ else "$%d" % idx + for idx, typ in enumerate( + (_pg_types.get(typ) for typ in self._inputsizes), 1 + ) + ) + + async def _prepare_and_execute(self, operation, parameters): + # TODO: I guess cache these in an LRU cache, or see if we can + # use some asyncpg concept + + # TODO: would be nice to support the dollar numeric thing + # directly, this is much easier for now + + if not self._adapt_connection._started: + await self._adapt_connection._start_transaction() + + params = self._parameters() + operation = re.sub(r"\?", lambda m: next(params), operation) + + try: + prepared_stmt = await self._connection.prepare(operation) + + attributes = prepared_stmt.get_attributes() + if attributes: + self.description = [ + (attr.name, attr.type.oid, None, None, None, None, None) + for attr in prepared_stmt.get_attributes() + ] + else: + self.description = None + + if self.server_side: + self._cursor = await prepared_stmt.cursor(*parameters) + self.rowcount = -1 + else: + self._rows = await prepared_stmt.fetch(*parameters) + status = prepared_stmt.get_statusmsg() + + reg = re.match(r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status) + if reg: + self.rowcount = int(reg.group(1)) + else: + self.rowcount = -1 + + except Exception as error: + self._handle_exception(error) + + def execute(self, operation, parameters=()): + try: + self._adapt_connection.await_( + self._prepare_and_execute(operation, parameters) + ) + except Exception as error: + self._handle_exception(error) + + def executemany(self, operation, seq_of_parameters): + if not self._adapt_connection._started: + self._adapt_connection.await_( + self._adapt_connection._start_transaction() + ) + + params = self._parameters() + operation = re.sub(r"\?", lambda m: next(params), operation) + try: + return self._adapt_connection.await_( + self._connection.executemany(operation, seq_of_parameters) + ) + except Exception as error: + self._handle_exception(error) + + def setinputsizes(self, *inputsizes): + self._inputsizes = inputsizes + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): + + server_side = True + __slots__ = ("_rowbuffer",) + + def __init__(self, adapt_connection): + super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection) + self._rowbuffer = None + + def close(self): + self._cursor = None + self._rowbuffer = None + + def _buffer_rows(self): + new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) + self._rowbuffer = collections.deque(new_rows) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._rowbuffer: + self._buffer_rows() + + while True: + while self._rowbuffer: + yield self._rowbuffer.popleft() + + self._buffer_rows() + if not self._rowbuffer: + break + + def fetchone(self): + if not self._rowbuffer: + self._buffer_rows() + if not self._rowbuffer: + return None + return self._rowbuffer.popleft() + + def fetchmany(self, size=None): + if size is None: + return self.fetchall() + + if not self._rowbuffer: + self._buffer_rows() + + buf = list(self._rowbuffer) + lb = len(buf) + if size > lb: + buf.extend( + self._adapt_connection.await_(self._cursor.fetch(size - lb)) + ) + + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result + + def fetchall(self): + ret = list(self._rowbuffer) + list( + self._adapt_connection.await_(self._all()) + ) + self._rowbuffer.clear() + return ret + + async def _all(self): + rows = [] + + # TODO: looks like we have to hand-roll some kind of batching here. + # hardcoding for the moment but this should be improved. + while True: + batch = await self._cursor.fetch(1000) + if batch: + rows.extend(batch) + continue + else: + break + return rows + + def executemany(self, operation, seq_of_parameters): + raise NotImplementedError( + "server side cursor doesn't support executemany yet" + ) + + +class AsyncAdapt_asyncpg_connection: + __slots__ = ( + "dbapi", + "_connection", + "isolation_level", + "_transaction", + "_started", + ) + + await_ = staticmethod(await_only) + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self.isolation_level = "read_committed" + self._transaction = None + self._started = False + self.await_(self._setup_type_codecs()) + + async def _setup_type_codecs(self): + """set up type decoders at the asyncpg level. + + this is first to accommodate the "char" value of + pg_catalog.pg_attribute.attgenerated being returned as bytes. + Even though the doc at + https://magicstack.github.io/asyncpg/current/usage.html#type-conversion + claims "char" is returned as "str", it looks like this is actually + the 'bpchar' datatype, blank padded. 'char' seems to be some + more obscure type (oid 18) and asyncpg codes this to bytea: + https://github.com/MagicStack/asyncpg/blob/master/asyncpg/protocol/ + codecs/pgproto.pyx#L26 + + all the other drivers treat this as a string. + + """ + + await self._connection.set_type_codec( + "char", + schema="pg_catalog", + encoder=lambda value: value, + decoder=lambda value: value, + format="text", + ) + + def _handle_exception(self, error): + if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): + exception_mapping = self.dbapi._asyncpg_error_translate + + for super_ in type(error).__mro__: + if super_ in exception_mapping: + translated_error = exception_mapping[super_]( + "%s: %s" % (type(error), error) + ) + raise translated_error from error + else: + raise error + else: + raise error + + def set_isolation_level(self, level): + if self._started: + self.rollback() + self.isolation_level = level + + async def _start_transaction(self): + if self.isolation_level == "autocommit": + return + + try: + self._transaction = self._connection.transaction( + isolation=self.isolation_level + ) + await self._transaction.start() + except Exception as error: + self._handle_exception(error) + else: + self._started = True + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncpg_ss_cursor(self) + else: + return AsyncAdapt_asyncpg_cursor(self) + + def reset_schema_state(self): + self.await_(self._connection.reload_schema_state()) + + def rollback(self): + if self._started: + self.await_(self._transaction.rollback()) + + self._transaction = None + self._started = False + + def commit(self): + if self._started: + self.await_(self._transaction.commit()) + self._transaction = None + self._started = False + + def close(self): + self.rollback() + + self.await_(self._connection.close()) + + +class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_asyncpg_dbapi: + def __init__(self, asyncpg): + self.asyncpg = asyncpg + self.paramstyle = "qmark" + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + if async_fallback: + return AsyncAdaptFallback_asyncpg_connection( + self, await_fallback(self.asyncpg.connect(*arg, **kw)), + ) + else: + return AsyncAdapt_asyncpg_connection( + self, await_only(self.asyncpg.connect(*arg, **kw)), + ) + + class Error(Exception): + pass + + class Warning(Exception): # noqa + pass + + class InterfaceError(Error): + pass + + class DatabaseError(Error): + pass + + class InternalError(DatabaseError): + pass + + class OperationalError(DatabaseError): + pass + + class ProgrammingError(DatabaseError): + pass + + class IntegrityError(DatabaseError): + pass + + class DataError(DatabaseError): + pass + + class NotSupportedError(DatabaseError): + pass + + @util.memoized_property + def _asyncpg_error_translate(self): + import asyncpg + + return { + asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa + asyncpg.exceptions.PostgresError: self.Error, + asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError, + asyncpg.exceptions.InterfaceError: self.InterfaceError, + } + + def Binary(self, value): + return value + + STRING = util.symbol("STRING") + TIMESTAMP = util.symbol("TIMESTAMP") + TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ") + TIME = util.symbol("TIME") + DATE = util.symbol("DATE") + INTERVAL = util.symbol("INTERVAL") + NUMBER = util.symbol("NUMBER") + FLOAT = util.symbol("FLOAT") + BOOLEAN = util.symbol("BOOLEAN") + INTEGER = util.symbol("INTEGER") + BIGINTEGER = util.symbol("BIGINTEGER") + BYTES = util.symbol("BYTES") + DECIMAL = util.symbol("DECIMAL") + JSON = util.symbol("JSON") + JSONB = util.symbol("JSONB") + ENUM = util.symbol("ENUM") + UUID = util.symbol("UUID") + BYTEA = util.symbol("BYTEA") + + DATETIME = TIMESTAMP + BINARY = BYTEA + + +_pg_types = { + AsyncAdapt_asyncpg_dbapi.STRING: "varchar", + AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp", + AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone", + AsyncAdapt_asyncpg_dbapi.DATE: "date", + AsyncAdapt_asyncpg_dbapi.TIME: "time", + AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval", + AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric", + AsyncAdapt_asyncpg_dbapi.FLOAT: "float", + AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool", + AsyncAdapt_asyncpg_dbapi.INTEGER: "integer", + AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint", + AsyncAdapt_asyncpg_dbapi.BYTES: "bytes", + AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal", + AsyncAdapt_asyncpg_dbapi.JSON: "json", + AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb", + AsyncAdapt_asyncpg_dbapi.ENUM: "enum", + AsyncAdapt_asyncpg_dbapi.UUID: "uuid", + AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea", +} + + +class PGDialect_asyncpg(PGDialect): + driver = "asyncpg" + + supports_unicode_statements = True + supports_server_side_cursors = True + + supports_unicode_binds = True + + default_paramstyle = "qmark" + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_asyncpg + statement_compiler = PGCompiler_asyncpg + preparer = PGIdentifierPreparer_asyncpg + + use_native_uuid = True + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Time: AsyncpgTime, + sqltypes.Date: AsyncpgDate, + sqltypes.DateTime: AsyncpgDateTime, + sqltypes.Interval: AsyncPgInterval, + INTERVAL: AsyncPgInterval, + UUID: AsyncpgUUID, + sqltypes.Boolean: AsyncpgBoolean, + sqltypes.Integer: AsyncpgInteger, + sqltypes.BigInteger: AsyncpgBigInteger, + sqltypes.Numeric: AsyncpgNumeric, + sqltypes.JSON: AsyncpgJSON, + json.JSONB: AsyncpgJSONB, + sqltypes.JSON.JSONPathType: AsyncpgJSONPathType, + sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType, + sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType, + sqltypes.Enum: AsyncPgEnum, + OID: AsyncpgOID, + REGCLASS: AsyncpgREGCLASS, + }, + ) + + def __init__(self, server_side_cursors=False, **kwargs): + PGDialect.__init__(self, **kwargs) + self.server_side_cursors = server_side_cursors + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + @classmethod + def dbapi(cls): + return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg")) + + @util.memoized_property + def _isolation_lookup(self): + return { + "AUTOCOMMIT": "autocommit", + "READ COMMITTED": "read_committed", + "REPEATABLE READ": "repeatable_read", + "SERIALIZABLE": "serializable", + } + + def set_isolation_level(self, connection, level): + try: + level = self._isolation_lookup[level.replace("_", " ")] + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ), + replace_context=err, + ) + + connection.set_isolation_level(level) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + return ([], opts) + + @classmethod + def get_pool_class(self, url): + return pool.AsyncAdaptedQueuePool + + def is_disconnect(self, e, connection, cursor): + if connection: + return connection._connection.is_closed() + else: + return isinstance( + e, self.dbapi.InterfaceError + ) and "connection is closed" in str(e) + + +dialect = PGDialect_asyncpg diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3bd7e62d53..7717a2526b 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1299,6 +1299,14 @@ class UUID(sqltypes.TypeEngine): """ self.as_uuid = as_uuid + def coerce_compared_value(self, op, value): + """See :meth:`.TypeEngine.coerce_compared_value` for a description.""" + + if isinstance(value, util.string_types): + return self + else: + return super(UUID, self).coerce_compared_value(op, value) + def bind_processor(self, dialect): if self.as_uuid: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d60f14f315..34bf720b78 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -89,6 +89,7 @@ class Connection(Connectable): if connection is not None else engine.raw_connection() ) + self._transaction = self._nested_transaction = None self.__savepoint_seq = 0 self.__in_begin = False @@ -623,6 +624,9 @@ class Connection(Connectable): self._dbapi_connection.detach() + def _autobegin(self): + self.begin() + def begin(self): """Begin a transaction and return a transaction handle. @@ -1433,7 +1437,7 @@ class Connection(Connectable): self._invalid_transaction() if self._is_future and self._transaction is None: - self.begin() + self._autobegin() context.pre_exec() @@ -2592,6 +2596,7 @@ class Engine(Connectable, log.Identified): return self.conn def __exit__(self, type_, value, traceback): + if type_ is not None: self.transaction.rollback() else: diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index dc895ee15d..66173d9b03 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -553,7 +553,7 @@ def create_engine(url, **kwargs): poolclass = pop_kwarg("poolclass", None) if poolclass is None: - poolclass = dialect_cls.get_pool_class(u) + poolclass = dialect.get_dialect_pool_class(u) pool_args = {"dialect": dialect} # consume pool arguments from kwargs, translating a few of diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c76f820f9b..4fb20a3d50 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -317,6 +317,9 @@ class DefaultDialect(interfaces.Dialect): def get_pool_class(cls, url): return getattr(cls, "poolclass", pool.QueuePool) + def get_dialect_pool_class(self, url): + return self.get_pool_class(url) + @classmethod def load_provisioning(cls): package = ".".join(cls.__module__.split(".")[0:-1]) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 9badbffc3c..10a88c7d88 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -20,6 +20,7 @@ from ..sql.base import _generative from ..sql.base import HasMemoized from ..sql.base import InPlaceGenerative from ..util import collections_abc +from ..util import py2k if util.TYPE_CHECKING: from typing import Any @@ -616,6 +617,16 @@ class ResultInternal(InPlaceGenerative): else: return row + def _iter_impl(self): + return self._iterator_getter(self) + + def _next_impl(self): + row = self._onerow_getter(self) + if row is _NO_ROW: + raise StopIteration() + else: + return row + @_generative def _column_slices(self, indexes): real_result = self._real_result if self._real_result else self @@ -892,16 +903,15 @@ class Result(ResultInternal): raise NotImplementedError() def __iter__(self): - return self._iterator_getter(self) + return self._iter_impl() def __next__(self): - row = self._onerow_getter(self) - if row is _NO_ROW: - raise StopIteration() - else: - return row + return self._next_impl() + + if py2k: - next = __next__ + def next(self): # noqa + return self._next_impl() def partitions(self, size=None): # type: (Optional[Int]) -> Iterator[List[Row]] @@ -1015,12 +1025,10 @@ class Result(ResultInternal): column of the first row, use the :meth:`.Result.scalar` method, or combine :meth:`.Result.scalars` and :meth:`.Result.first`. - .. comment: A warning is emitted if additional rows remain. - :return: a :class:`.Row` object, or None if no rows remain. - .. seealso:: + .. seealso:: :meth:`_result.Result.scalar` @@ -1186,18 +1194,6 @@ class FilterResult(ResultInternal): def _attributes(self): return self._real_result._attributes - def __iter__(self): - return self._iterator_getter(self) - - def __next__(self): - row = self._onerow_getter(self) - if row is _NO_ROW: - raise StopIteration() - else: - return row - - next = __next__ - def _fetchiter_impl(self): return self._real_result._fetchiter_impl() @@ -1299,6 +1295,17 @@ class ScalarResult(FilterResult): """ return self._allrows() + def __iter__(self): + return self._iter_impl() + + def __next__(self): + return self._next_impl() + + if py2k: + + def next(self): # noqa + return self._next_impl() + def first(self): # type: () -> Optional[Any] """Fetch the first object or None if no object is present. @@ -1409,7 +1416,7 @@ class MappingResult(FilterResult): def fetchall(self): # type: () -> List[Mapping] - """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + """A synonym for the :meth:`_engine.MappingResult.all` method.""" return self._allrows() @@ -1453,6 +1460,17 @@ class MappingResult(FilterResult): return self._allrows() + def __iter__(self): + return self._iter_impl() + + def __next__(self): + return self._next_impl() + + if py2k: + + def next(self): # noqa + return self._next_impl() + def first(self): # type: () -> Optional[Mapping] """Fetch the first object or None if no object is present. @@ -1519,13 +1537,11 @@ class FrozenResult(object): .. seealso:: - .. seealso:: - - :ref:`do_orm_execute_re_executing` - example usage within the - ORM to implement a result-set cache. + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. - :func:`_orm.loading.merge_frozen_result` - ORM function to merge - a frozen result back into a :class:`_orm.Session`. + :func:`_orm.loading.merge_frozen_result` - ORM function to merge + a frozen result back into a :class:`_orm.Session`. """ @@ -1624,21 +1640,36 @@ class ChunkedIteratorResult(IteratorResult): """ def __init__( - self, cursor_metadata, chunks, source_supports_scalars=False, raw=None + self, + cursor_metadata, + chunks, + source_supports_scalars=False, + raw=None, + dynamic_yield_per=False, ): self._metadata = cursor_metadata self.chunks = chunks self._source_supports_scalars = source_supports_scalars self.raw = raw self.iterator = itertools.chain.from_iterable(self.chunks(None)) + self.dynamic_yield_per = dynamic_yield_per @_generative def yield_per(self, num): + # TODO: this throws away the iterator which may be holding + # onto a chunk. the yield_per cannot be changed once any + # rows have been fetched. either find a way to enforce this, + # or we can't use itertools.chain and will instead have to + # keep track. + self._yield_per = num - # TODO: this should raise if the iterator has already been started. - # we can't change the yield mid-stream like this self.iterator = itertools.chain.from_iterable(self.chunks(num)) + def _fetchmany_impl(self, size=None): + if self.dynamic_yield_per: + self.iterator = itertools.chain.from_iterable(self.chunks(size)) + return super(ChunkedIteratorResult, self)._fetchmany_impl(size=size) + class MergedResult(IteratorResult): """A :class:`_engine.Result` that is merged from any number of @@ -1677,6 +1708,5 @@ class MergedResult(IteratorResult): def _soft_close(self, hard=False): for r in self._results: r._soft_close(hard=hard) - if hard: self.closed = True diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py new file mode 100644 index 0000000000..fbbc958d42 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -0,0 +1,9 @@ +from .engine import AsyncConnection # noqa +from .engine import AsyncEngine # noqa +from .engine import AsyncTransaction # noqa +from .engine import create_async_engine # noqa +from .result import AsyncMappingResult # noqa +from .result import AsyncResult # noqa +from .result import AsyncScalarResult # noqa +from .session import AsyncSession # noqa +from .session import AsyncSessionTransaction # noqa diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py new file mode 100644 index 0000000000..051f9e21a1 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -0,0 +1,25 @@ +import abc + +from . import exc as async_exc + + +class StartableContext(abc.ABC): + @abc.abstractmethod + async def start(self) -> "StartableContext": + pass + + def __await__(self): + return self.start().__await__() + + async def __aenter__(self): + return await self.start() + + @abc.abstractmethod + async def __aexit__(self, type_, value, traceback): + pass + + def _raise_for_not_started(self): + raise async_exc.AsyncContextNotStarted( + "%s context has not been started and object has not been awaited." + % (self.__class__.__name__) + ) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py new file mode 100644 index 0000000000..2d9198d169 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -0,0 +1,461 @@ +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Optional + +from . import exc as async_exc +from .base import StartableContext +from .result import AsyncResult +from ... import exc +from ... import util +from ...engine import Connection +from ...engine import create_engine as _create_engine +from ...engine import Engine +from ...engine import Result +from ...engine import Transaction +from ...engine.base import OptionEngineMixin +from ...sql import Executable +from ...util.concurrency import greenlet_spawn + + +def create_async_engine(*arg, **kw): + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_engine` are mostly + identical to those passed to the :func:`_sa.create_engine` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. versionadded:: 1.4 + + """ + + if kw.get("server_side_cursors", False): + raise exc.AsyncMethodRequired( + "Can't set server_side_cursors for async engine globally; " + "use the connection.stream() method for an async " + "streaming result set" + ) + kw["future"] = True + sync_engine = _create_engine(*arg, **kw) + return AsyncEngine(sync_engine) + + +class AsyncConnection(StartableContext): + """An asyncio proxy for a :class:`_engine.Connection`. + + :class:`_asyncio.AsyncConnection` is acquired using the + :meth:`_asyncio.AsyncEngine.connect` + method of :class:`_asyncio.AsyncEngine`:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + async with engine.connect() as conn: + result = await conn.execute(select(table)) + + .. versionadded:: 1.4 + + """ # noqa + + __slots__ = ( + "sync_engine", + "sync_connection", + ) + + def __init__( + self, sync_engine: Engine, sync_connection: Optional[Connection] = None + ): + self.sync_engine = sync_engine + self.sync_connection = sync_connection + + async def start(self): + """Start this :class:`_asyncio.AsyncConnection` object's context + outside of using a Python ``with:`` block. + + """ + if self.sync_connection: + raise exc.InvalidRequestError("connection is already started") + self.sync_connection = await (greenlet_spawn(self.sync_engine.connect)) + return self + + def _sync_connection(self): + if not self.sync_connection: + self._raise_for_not_started() + return self.sync_connection + + def begin(self) -> "AsyncTransaction": + """Begin a transaction prior to autobegin occurring. + + """ + self._sync_connection() + return AsyncTransaction(self) + + def begin_nested(self) -> "AsyncTransaction": + """Begin a nested transaction and return a transaction handle. + + """ + self._sync_connection() + return AsyncTransaction(self, nested=True) + + async def commit(self): + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_future.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_future.Connection.begin` method is called. + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.commit) + + async def rollback(self): + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. + If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_future.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_future.Connection.begin` method is called. + + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.rollback) + + async def close(self): + """Close this :class:`_asyncio.AsyncConnection`. + + This has the effect of also rolling back the transaction if one + is in place. + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.close) + + async def exec_driver_sql( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> Result: + r"""Executes a driver-level SQL string and return buffered + :class:`_engine.Result`. + + """ + + conn = self._sync_connection() + + result = await greenlet_spawn( + conn.exec_driver_sql, statement, parameters, execution_options, + ) + if result.context._is_server_side: + raise async_exc.AsyncMethodRequired( + "Can't use the connection.exec_driver_sql() method with a " + "server-side cursor." + "Use the connection.stream() method for an async " + "streaming result set." + ) + + return result + + async def stream( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> AsyncResult: + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object.""" + + conn = self._sync_connection() + + result = await greenlet_spawn( + conn._execute_20, + statement, + parameters, + util.EMPTY_DICT.merge_with( + execution_options, {"stream_results": True} + ), + ) + if not result.context._is_server_side: + # TODO: real exception here + assert False, "server side result expected" + return AsyncResult(result) + + async def execute( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> Result: + r"""Executes a SQL statement construct and return a buffered + :class:`_engine.Result`. + + :param object: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.DDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_future.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. + + """ + conn = self._sync_connection() + + result = await greenlet_spawn( + conn._execute_20, statement, parameters, execution_options, + ) + if result.context._is_server_side: + raise async_exc.AsyncMethodRequired( + "Can't use the connection.execute() method with a " + "server-side cursor." + "Use the connection.stream() method for an async " + "streaming result set." + ) + return result + + async def scalar( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> Any: + r"""Executes a SQL statement construct and returns a scalar object. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_future.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. + + """ + result = await self.execute(statement, parameters, execution_options) + return result.scalar() + + async def run_sync(self, fn: Callable, *arg, **kw) -> Any: + """"Invoke the given sync callable passing self as the first argument. + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + E.g.:: + + with async_engine.begin() as conn: + await conn.run_sync(metadata.create_all) + + """ + + conn = self._sync_connection() + + return await greenlet_spawn(fn, conn, *arg, **kw) + + def __await__(self): + return self.start().__await__() + + async def __aexit__(self, type_, value, traceback): + await self.close() + + +class AsyncEngine: + """An asyncio proxy for a :class:`_engine.Engine`. + + :class:`_asyncio.AsyncEngine` is acquired using the + :func:`_asyncio.create_async_engine` function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + .. versionadded:: 1.4 + + + """ # noqa + + __slots__ = ("sync_engine",) + + _connection_cls = AsyncConnection + + _option_cls: type + + class _trans_ctx(StartableContext): + def __init__(self, conn): + self.conn = conn + + async def start(self): + await self.conn.start() + self.transaction = self.conn.begin() + await self.transaction.__aenter__() + + return self.conn + + async def __aexit__(self, type_, value, traceback): + if type_ is not None: + await self.transaction.rollback() + else: + if self.transaction.is_active: + await self.transaction.commit() + await self.conn.close() + + def __init__(self, sync_engine: Engine): + self.sync_engine = sync_engine + + def begin(self): + """Return a context manager which when entered will deliver an + :class:`_asyncio.AsyncConnection` with an + :class:`_asyncio.AsyncTransaction` established. + + E.g.:: + + async with async_engine.begin() as conn: + await conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + await conn.execute(text("my_special_procedure(5)")) + + + """ + conn = self.connect() + return self._trans_ctx(conn) + + def connect(self) -> AsyncConnection: + """Return an :class:`_asyncio.AsyncConnection` object. + + The :class:`_asyncio.AsyncConnection` will procure a database + connection from the underlying connection pool when it is entered + as an async context manager:: + + async with async_engine.connect() as conn: + result = await conn.execute(select(user_table)) + + The :class:`_asyncio.AsyncConnection` may also be started outside of a + context manager by invoking its :meth:`_asyncio.AsyncConnection.start` + method. + + """ + + return self._connection_cls(self.sync_engine) + + async def raw_connection(self) -> Any: + """Return a "raw" DBAPI connection from the connection pool. + + .. seealso:: + + :ref:`dbapi_connections` + + """ + return await greenlet_spawn(self.sync_engine.raw_connection) + + +class AsyncOptionEngine(OptionEngineMixin, AsyncEngine): + pass + + +AsyncEngine._option_cls = AsyncOptionEngine + + +class AsyncTransaction(StartableContext): + """An asyncio proxy for a :class:`_engine.Transaction`.""" + + __slots__ = ("connection", "sync_transaction", "nested") + + def __init__(self, connection: AsyncConnection, nested: bool = False): + self.connection = connection + self.sync_transaction: Optional[Transaction] = None + self.nested = nested + + def _sync_transaction(self): + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + @property + def is_valid(self) -> bool: + return self._sync_transaction().is_valid + + @property + def is_active(self) -> bool: + return self._sync_transaction().is_active + + async def close(self): + """Close this :class:`.Transaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + + """ + await greenlet_spawn(self._sync_transaction().close) + + async def rollback(self): + """Roll back this :class:`.Transaction`. + + """ + await greenlet_spawn(self._sync_transaction().rollback) + + async def commit(self): + """Commit this :class:`.Transaction`.""" + + await greenlet_spawn(self._sync_transaction().commit) + + async def start(self): + """Start this :class:`_asyncio.AsyncTransaction` object's context + outside of using a Python ``with:`` block. + + """ + + self.sync_transaction = await greenlet_spawn( + self.connection._sync_connection().begin_nested + if self.nested + else self.connection._sync_connection().begin + ) + return self + + async def __aexit__(self, type_, value, traceback): + if type_ is None and self.is_active: + try: + await self.commit() + except: + with util.safe_reraise(): + await self.rollback() + else: + await self.rollback() + + +def _get_sync_engine(async_engine): + try: + return async_engine.sync_engine + except AttributeError as e: + raise exc.ArgumentError( + "AsyncEngine expected, got %r" % async_engine + ) from e diff --git a/lib/sqlalchemy/ext/asyncio/exc.py b/lib/sqlalchemy/ext/asyncio/exc.py new file mode 100644 index 0000000000..6137bf6df6 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/exc.py @@ -0,0 +1,14 @@ +from ... import exc + + +class AsyncMethodRequired(exc.InvalidRequestError): + """an API can't be used because its result would not be + compatible with async""" + + +class AsyncContextNotStarted(exc.InvalidRequestError): + """a startable context manager has not been started.""" + + +class AsyncContextAlreadyStarted(exc.InvalidRequestError): + """a startable context manager is already started.""" diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py new file mode 100644 index 0000000000..52b40acbab --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -0,0 +1,669 @@ +import operator + +from ... import util +from ...engine.result import _NO_ROW +from ...engine.result import FilterResult +from ...engine.result import FrozenResult +from ...engine.result import MergedResult +from ...util.concurrency import greenlet_spawn + +if util.TYPE_CHECKING: + from typing import Any + from typing import List + from typing import Optional + from typing import Int + from typing import Iterator + from typing import Mapping + from ...engine.result import Row + + +class AsyncResult(FilterResult): + """An asyncio wrapper around a :class:`_result.Result` object. + + The :class:`_asyncio.AsyncResult` only applies to statement executions that + use a server-side cursor. It is returned only from the + :meth:`_asyncio.AsyncConnection.stream` and + :meth:`_asyncio.AsyncSession.stream` methods. + + .. versionadded:: 1.4 + + """ + + def __init__(self, real_result): + self._real_result = real_result + + self._metadata = real_result._metadata + self._unique_filter_state = real_result._unique_filter_state + + # BaseCursorResult pre-generates the "_row_getter". Use that + # if available rather than building a second one + if "_row_getter" in real_result.__dict__: + self._set_memoized_attribute( + "_row_getter", real_result.__dict__["_row_getter"] + ) + + def keys(self): + """Return the :meth:`_engine.Result.keys` collection from the + underlying :class:`_engine.Result`. + + """ + return self._metadata.keys + + def unique(self, strategy=None): + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncResult`. + + Refer to :meth:`_engine.Result.unique` in the synchronous + SQLAlchemy API for a complete behavioral description. + + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions): + # type: (*object) -> AsyncResult + r"""Establish the columns that should be returned in each row. + + Refer to :meth:`_engine.Result.columns` in the synchronous + SQLAlchemy API for a complete behavioral description. + + + """ + return self._column_slices(col_expressions) + + async def partitions(self, size=None): + # type: (Optional[Int]) -> Iterator[List[Any]] + """Iterate through sub-lists of rows of the size given. + + An async iterator is returned:: + + async def scroll_results(connection): + result = await connection.stream(select(users_table)) + + async for partition in result.partitions(100): + print("list of rows: %s" % partition) + + .. seealso:: + + :meth:`_engine.Result.partitions` + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchone(self): + # type: () -> Row + """Fetch one row. + + When all rows are exhausted, returns None. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch the first row of a result only, use the + :meth:`_engine.Result.first` method. To iterate through all + rows, iterate the :class:`_engine.Result` object directly. + + :return: a :class:`.Row` object if no filters are applied, or None + if no rows remain. + + """ + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany(self, size=None): + # type: (Optional[Int]) -> List[Row] + """Fetch many rows. + + When all rows are exhausted, returns an empty list. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch rows in groups, use the + :meth:`._asyncio.AsyncResult.partitions` method. + + :return: a list of :class:`.Row` objects. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.partitions` + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self): + # type: () -> List[Row] + """Return all rows in a list. + + Closes the result set after invocation. Subsequent invocations + will return an empty list. + + :return: a list of :class:`.Row` objects. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self): + return self + + async def __anext__(self): + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self): + # type: () -> Row + """Fetch the first row or None if no row is present. + + Closes the result set and discards remaining rows. + + .. note:: This method returns one **row**, e.g. tuple, by default. To + return exactly one single scalar value, that is, the first column of + the first row, use the :meth:`_asyncio.AsyncResult.scalar` method, + or combine :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.first`. + + :return: a :class:`.Row` object, or None + if no rows remain. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.scalar` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self): + # type: () -> Optional[Row] + """Return at most one result or raise an exception. + + Returns ``None`` if the result has no rows. + Raises :class:`.MultipleResultsFound` + if multiple rows are returned. + + .. versionadded:: 1.4 + + :return: The first :class:`.Row` or None if no row is available. + + :raises: :class:`.MultipleResultsFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def scalar_one(self): + # type: () -> Any + """Return exactly one scalar result or raise an exception. + + This is equvalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, True, True) + + async def scalar_one_or_none(self): + # type: () -> Optional[Any] + """Return exactly one or no scalar result. + + This is equvalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one_or_none`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, False, True) + + async def one(self): + # type: () -> Row + """Return exactly one row or raise an exception. + + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows + would be returned. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar_one` method, or combine + :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.one`. + + .. versionadded:: 1.4 + + :return: The first :class:`.Row`. + + :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalar_one` + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + async def scalar(self): + # type: () -> Optional[Any] + """Fetch the first column of the first row, and close the result set. + + Returns None if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or None if no rows remain. + + """ + return await greenlet_spawn(self._only_one_row, False, False, True) + + async def freeze(self): + """Return a callable object that will produce copies of this + :class:`_asyncio.AsyncResult` when invoked. + + The callable object returned is an instance of + :class:`_engine.FrozenResult`. + + This is used for result set caching. The method must be called + on the result when it has been unconsumed, and calling the method + will consume the result fully. When the :class:`_engine.FrozenResult` + is retrieved from a cache, it can be called any number of times where + it will produce a new :class:`_engine.Result` object each time + against its stored set of rows. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + """ + + return await greenlet_spawn(FrozenResult, self) + + def merge(self, *others): + """Merge this :class:`_asyncio.AsyncResult` with other compatible result + objects. + + The object returned is an instance of :class:`_engine.MergedResult`, + which will be composed of iterators from the given result + objects. + + The new result will use the metadata from this result object. + The subsequent result objects must be against an identical + set of result / cursor metadata, otherwise the behavior is + undefined. + + """ + return MergedResult(self._metadata, (self,) + others) + + def scalars(self, index=0): + # type: (Int) -> AsyncScalarResult + """Return an :class:`_asyncio.AsyncScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + Refer to :meth:`_result.Result.scalars` in the synchronous + SQLAlchemy API for a complete behavioral description. + + :param index: integer or row key indicating the column to be fetched + from each row, defaults to ``0`` indicating the first column. + + :return: a new :class:`_asyncio.AsyncScalarResult` filtering object + referring to this :class:`_asyncio.AsyncResult` object. + + """ + return AsyncScalarResult(self._real_result, index) + + def mappings(self): + # type() -> AsyncMappingResult + """Apply a mappings filter to returned rows, returning an instance of + :class:`_asyncio.AsyncMappingResult`. + + When this filter is applied, fetching rows will return + :class:`.RowMapping` objects instead of :class:`.Row` objects. + + Refer to :meth:`_result.Result.mappings` in the synchronous + SQLAlchemy API for a complete behavioral description. + + :return: a new :class:`_asyncio.AsyncMappingResult` filtering object + referring to the underlying :class:`_result.Result` object. + + """ + + return AsyncMappingResult(self._real_result) + + +class AsyncScalarResult(FilterResult): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values + rather than :class:`_row.Row` values. + + The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.scalars` method. + + Refer to the :class:`_result.ScalarResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + _generate_rows = False + + def __init__(self, real_result, index): + self._real_result = real_result + + if real_result._source_supports_scalars: + self._metadata = real_result._metadata + self._post_creational_filter = None + else: + self._metadata = real_result._metadata._reduce([index]) + self._post_creational_filter = operator.itemgetter(0) + + self._unique_filter_state = real_result._unique_filter_state + + def unique(self, strategy=None): + # type: () -> AsyncScalarResult + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncScalarResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + async def partitions(self, size=None): + # type: (Optional[Int]) -> Iterator[List[Any]] + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self): + # type: () -> List[Any] + """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchmany(self, size=None): + # type: (Optional[Int]) -> List[Any] + """Fetch many objects. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self): + # type: () -> List[Any] + """Return all scalar values in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._allrows) + + def __aiter__(self): + return self + + async def __anext__(self): + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self): + # type: () -> Optional[Any] + """Fetch the first object or None if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self): + # type: () -> Optional[Any] + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self): + # type: () -> Any + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncMappingResult(FilterResult): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values + rather than :class:`_engine.Row` values. + + The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.mappings` method. + + Refer to the :class:`_result.MappingResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + _generate_rows = True + + _post_creational_filter = operator.attrgetter("_mapping") + + def __init__(self, result): + self._real_result = result + self._unique_filter_state = result._unique_filter_state + self._metadata = result._metadata + if result._source_supports_scalars: + self._metadata = self._metadata._reduce([0]) + + def keys(self): + """Return an iterable view which yields the string keys that would + be represented by each :class:`.Row`. + + The view also can be tested for key containment using the Python + ``in`` operator, which will test both for the string keys represented + in the view, as well as for alternate keys such as column objects. + + .. versionchanged:: 1.4 a key view object is returned rather than a + plain list. + + + """ + return self._metadata.keys + + def unique(self, strategy=None): + # type: () -> AsyncMappingResult + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncMappingResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions): + # type: (*object) -> AsyncMappingResult + r"""Establish the columns that should be returned in each row. + + + """ + return self._column_slices(col_expressions) + + async def partitions(self, size=None): + # type: (Optional[Int]) -> Iterator[List[Mapping]] + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self): + # type: () -> List[Mapping] + """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchone(self): + # type: () -> Mapping + """Fetch one object. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + """ + + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany(self, size=None): + # type: (Optional[Int]) -> List[Mapping] + """Fetch many objects. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self): + # type: () -> List[Mapping] + """Return all scalar values in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self): + return self + + async def __anext__(self): + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self): + # type: () -> Optional[Mapping] + """Fetch the first object or None if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self): + # type: () -> Optional[Mapping] + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self): + # type: () -> Mapping + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py new file mode 100644 index 0000000000..1673017808 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -0,0 +1,293 @@ +from typing import Any +from typing import Callable +from typing import List +from typing import Mapping +from typing import Optional + +from . import engine +from . import result as _result +from .base import StartableContext +from .engine import AsyncEngine +from ... import util +from ...engine import Result +from ...orm import Session +from ...sql import Executable +from ...util.concurrency import greenlet_spawn + + +class AsyncSession: + """Asyncio version of :class:`_orm.Session`. + + + .. versionadded:: 1.4 + + """ + + def __init__( + self, + bind: AsyncEngine = None, + binds: Mapping[object, AsyncEngine] = None, + **kw + ): + kw["future"] = True + if bind: + bind = engine._get_sync_engine(bind) + + if binds: + binds = { + key: engine._get_sync_engine(b) for key, b in binds.items() + } + + self.sync_session = Session(bind=bind, binds=binds, **kw) + + def add(self, instance: object) -> None: + """Place an object in this :class:`_asyncio.AsyncSession`. + + .. seealso:: + + :meth:`_orm.Session.add` + + """ + self.sync_session.add(instance) + + def add_all(self, instances: List[object]) -> None: + """Add the given collection of instances to this + :class:`_asyncio.AsyncSession`.""" + + self.sync_session.add_all(instances) + + def expire_all(self): + """Expires all persistent instances within this Session. + + See :meth:`_orm.Session.expire_all` for usage details. + + """ + self.sync_session.expire_all() + + def expire(self, instance, attribute_names=None): + """Expire the attributes on an instance. + + See :meth:`._orm.Session.expire` for usage details. + + """ + self.sync_session.expire() + + async def refresh( + self, instance, attribute_names=None, with_for_update=None + ): + """Expire and refresh the attributes on the given instance. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + This is the async version of the :meth:`_orm.Session.refresh` method. + See that method for a complete description of all options. + + """ + + return await greenlet_spawn( + self.sync_session.refresh, + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + async def run_sync(self, fn: Callable, *arg, **kw) -> Any: + """Invoke the given sync callable passing sync self as the first + argument. + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + E.g.:: + + with AsyncSession(async_engine) as session: + await session.run_sync(some_business_method) + + """ + + return await greenlet_spawn(fn, self.sync_session, *arg, **kw) + + async def execute( + self, + statement: Executable, + params: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + bind_arguments: Optional[Mapping] = None, + **kw + ) -> Result: + """Execute a statement and return a buffered + :class:`_engine.Result` object.""" + + execution_options = execution_options.union({"prebuffer_rows": True}) + + return await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + + async def stream( + self, + statement, + params=None, + execution_options=util.EMPTY_DICT, + bind_arguments=None, + **kw + ): + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object.""" + + execution_options = execution_options.union({"stream_results": True}) + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + return _result.AsyncResult(result) + + async def merge(self, instance, load=True): + """Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + """ + return await greenlet_spawn( + self.sync_session.merge, instance, load=load + ) + + async def flush(self, objects=None): + """Flush all the object changes to the database. + + .. seealso:: + + :meth:`_orm.Session.flush` + + """ + await greenlet_spawn(self.sync_session.flush, objects=objects) + + async def connection(self): + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this + :class:`.Session` object's transactional state. + + """ + sync_connection = await greenlet_spawn(self.sync_session.connection) + return engine.AsyncConnection(sync_connection.engine, sync_connection) + + def begin(self, **kw): + """Return an :class:`_asyncio.AsyncSessionTransaction` object. + + The underlying :class:`_orm.Session` will perform the + "begin" action when the :class:`_asyncio.AsyncSessionTransaction` + object is entered:: + + async with async_session.begin(): + # .. ORM transaction is begun + + Note that database IO will not normally occur when the session-level + transaction is begun, as database transactions begin on an + on-demand basis. However, the begin block is async to accommodate + for a :meth:`_orm.SessionEvents.after_transaction_create` + event hook that may perform IO. + + For a general description of ORM begin, see + :meth:`_orm.Session.begin`. + + """ + + return AsyncSessionTransaction(self) + + def begin_nested(self, **kw): + """Return an :class:`_asyncio.AsyncSessionTransaction` object + which will begin a "nested" transaction, e.g. SAVEPOINT. + + Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`. + + For a general description of ORM begin nested, see + :meth:`_orm.Session.begin_nested`. + + """ + + return AsyncSessionTransaction(self, nested=True) + + async def rollback(self): + return await greenlet_spawn(self.sync_session.rollback) + + async def commit(self): + return await greenlet_spawn(self.sync_session.commit) + + async def close(self): + return await greenlet_spawn(self.sync_session.close) + + async def __aenter__(self): + return self + + async def __aexit__(self, type_, value, traceback): + await self.close() + + +class AsyncSessionTransaction(StartableContext): + """A wrapper for the ORM :class:`_orm.SessionTransaction` object. + + This object is provided so that a transaction-holding object + for the :meth:`_asyncio.AsyncSession.begin` may be returned. + + The object supports both explicit calls to + :meth:`_asyncio.AsyncSessionTransaction.commit` and + :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an + async context manager. + + + .. versionadded:: 1.4 + + """ + + __slots__ = ("session", "sync_transaction", "nested") + + def __init__(self, session, nested=False): + self.session = session + self.nested = nested + self.sync_transaction = None + + @property + def is_active(self): + return ( + self._sync_transaction() is not None + and self._sync_transaction().is_active + ) + + def _sync_transaction(self): + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + async def rollback(self): + """Roll back this :class:`_asyncio.AsyncTransaction`. + + """ + await greenlet_spawn(self._sync_transaction().rollback) + + async def commit(self): + """Commit this :class:`_asyncio.AsyncTransaction`.""" + + await greenlet_spawn(self._sync_transaction().commit) + + async def start(self): + self.sync_transaction = await greenlet_spawn( + self.session.sync_session.begin_nested + if self.nested + else self.session.sync_session.begin + ) + return self + + async def __aexit__(self, type_, value, traceback): + return await greenlet_spawn( + self._sync_transaction().__exit__, type_, value, traceback + ) diff --git a/lib/sqlalchemy/future/__init__.py b/lib/sqlalchemy/future/__init__.py index 37ce46e477..b07b9b040b 100644 --- a/lib/sqlalchemy/future/__init__.py +++ b/lib/sqlalchemy/future/__init__.py @@ -14,4 +14,5 @@ from .engine import Engine # noqa from ..sql.selectable import Select # noqa from ..util.langhelpers import public_factory + select = public_factory(Select._create_future_select, ".future.select") diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py index d5922daa32..dd72360eda 100644 --- a/lib/sqlalchemy/future/engine.py +++ b/lib/sqlalchemy/future/engine.py @@ -359,6 +359,22 @@ class Engine(_LegacyEngine): execution_options=legacy_engine._execution_options, ) + class _trans_ctx(object): + def __init__(self, conn): + self.conn = conn + + def __enter__(self): + self.transaction = self.conn.begin() + return self.conn + + def __exit__(self, type_, value, traceback): + if type_ is not None: + self.transaction.rollback() + else: + if self.transaction.is_active: + self.transaction.commit() + self.conn.close() + def begin(self): """Return a :class:`_future.Connection` object with a transaction begun. @@ -381,7 +397,8 @@ class Engine(_LegacyEngine): :meth:`_future.Connection.begin` """ - return super(Engine, self).begin() + conn = self.connect() + return self._trans_ctx(conn) def connect(self): """Return a new :class:`_future.Connection` object. diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 2eb3e13689..fd3e92055b 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -125,8 +125,21 @@ def instances(cursor, context): if not yield_per: break + if context.execution_options.get("prebuffer_rows", False): + # this is a bit of a hack at the moment. + # I would rather have some option in the result to pre-buffer + # internally. + _prebuffered = list(chunks(None)) + + def chunks(size): + return iter(_prebuffered) + result = ChunkedIteratorResult( - row_metadata, chunks, source_supports_scalars=single_entity, raw=cursor + row_metadata, + chunks, + source_supports_scalars=single_entity, + raw=cursor, + dynamic_yield_per=cursor.context._is_server_side, ) result._attributes = result._attributes.union( diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 7c254c61bd..676dd438c1 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1211,7 +1211,6 @@ def _emit_insert_statements( has_all_pks, has_all_defaults, ) in records: - if value_params: result = connection.execute( statement.values(value_params), params diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index eb0d375173..353f34333c 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -28,6 +28,7 @@ from .base import reset_rollback from .dbapi_proxy import clear_managers from .dbapi_proxy import manage from .impl import AssertionPool +from .impl import AsyncAdaptedQueuePool from .impl import NullPool from .impl import QueuePool from .impl import SingletonThreadPool @@ -44,6 +45,7 @@ __all__ = [ "AssertionPool", "NullPool", "QueuePool", + "AsyncAdaptedQueuePool", "SingletonThreadPool", "StaticPool", ] diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 0fe7612b92..e1a9f00db1 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -33,6 +33,8 @@ class QueuePool(Pool): """ + _queue_class = sqla_queue.Queue + def __init__( self, creator, @@ -95,7 +97,7 @@ class QueuePool(Pool): """ Pool.__init__(self, creator, **kw) - self._pool = sqla_queue.Queue(pool_size, use_lifo=use_lifo) + self._pool = self._queue_class(pool_size, use_lifo=use_lifo) self._overflow = 0 - pool_size self._max_overflow = max_overflow self._timeout = timeout @@ -215,6 +217,10 @@ class QueuePool(Pool): return self._pool.maxsize - self._pool.qsize() + self._overflow +class AsyncAdaptedQueuePool(QueuePool): + _queue_class = sqla_queue.AsyncAdaptedQueue + + class NullPool(Pool): """A Pool which does not pool connections. diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 2fe6f35d2b..8f6dc8e72d 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -5,6 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from .base import Executable # noqa from .compiler import COLLECT_CARTESIAN_PRODUCTS # noqa from .compiler import FROM_LINTING # noqa from .compiler import NO_LINTING # noqa diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 186f885d8a..64663a6b01 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2316,6 +2316,22 @@ class JSON(Indexable, TypeEngine): """ + class JSONIntIndexType(JSONIndexType): + """Placeholder for the datatype of a JSON index value. + + This allows execution-time processing of JSON index values + for special syntaxes. + + """ + + class JSONStrIndexType(JSONIndexType): + """Placeholder for the datatype of a JSON index value. + + This allows execution-time processing of JSON index values + for special syntaxes. + + """ + class JSONPathType(JSONElementType): """Placeholder type for JSON path operations. @@ -2346,7 +2362,9 @@ class JSON(Indexable, TypeEngine): index, expr=self.expr, operator=operators.json_getitem_op, - bindparam_type=JSON.JSONIndexType, + bindparam_type=JSON.JSONIntIndexType + if isinstance(index, int) + else JSON.JSONStrIndexType, ) operator = operators.json_getitem_op diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 79b7f9eb3d..9b1164874a 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -12,7 +12,6 @@ from .assertions import assert_raises # noqa from .assertions import assert_raises_context_ok # noqa from .assertions import assert_raises_message # noqa from .assertions import assert_raises_message_context_ok # noqa -from .assertions import assert_raises_return # noqa from .assertions import AssertsCompiledSQL # noqa from .assertions import AssertsExecutionResults # noqa from .assertions import ComparesTables # noqa @@ -23,6 +22,8 @@ from .assertions import eq_ignore_whitespace # noqa from .assertions import eq_regex # noqa from .assertions import expect_deprecated # noqa from .assertions import expect_deprecated_20 # noqa +from .assertions import expect_raises # noqa +from .assertions import expect_raises_message # noqa from .assertions import expect_warnings # noqa from .assertions import in_ # noqa from .assertions import is_ # noqa @@ -35,6 +36,7 @@ from .assertions import ne_ # noqa from .assertions import not_in_ # noqa from .assertions import startswith_ # noqa from .assertions import uses_deprecated # noqa +from .config import async_test # noqa from .config import combinations # noqa from .config import db # noqa from .config import fixture # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index ecc6a4ab83..fe74be8235 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -298,10 +298,6 @@ def assert_raises_context_ok(except_cls, callable_, *args, **kw): return _assert_raises(except_cls, callable_, args, kw,) -def assert_raises_return(except_cls, callable_, *args, **kw): - return _assert_raises(except_cls, callable_, args, kw, check_context=True) - - def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): return _assert_raises( except_cls, callable_, args, kwargs, msg=msg, check_context=True @@ -317,14 +313,26 @@ def assert_raises_message_context_ok( def _assert_raises( except_cls, callable_, args, kwargs, msg=None, check_context=False ): - ret_err = None + + with _expect_raises(except_cls, msg, check_context) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer(object): + error = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False): + ec = _ErrorContainer() if check_context: are_we_already_in_a_traceback = sys.exc_info()[0] try: - callable_(*args, **kwargs) + yield ec success = False except except_cls as err: - ret_err = err + ec.error = err success = True if msg is not None: assert re.search( @@ -337,7 +345,13 @@ def _assert_raises( # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" - return ret_err + +def expect_raises(except_cls): + return _expect_raises(except_cls, check_context=True) + + +def expect_raises_message(except_cls, msg): + return _expect_raises(except_cls, msg=msg, check_context=True) class AssertsCompiledSQL(object): diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py new file mode 100644 index 0000000000..2e274de16f --- /dev/null +++ b/lib/sqlalchemy/testing/asyncio.py @@ -0,0 +1,14 @@ +from .assertions import assert_raises as _assert_raises +from .assertions import assert_raises_message as _assert_raises_message +from ..util import await_fallback as await_ +from ..util import greenlet_spawn + + +async def assert_raises_async(except_cls, msg, coroutine): + await greenlet_spawn(_assert_raises, except_cls, await_, coroutine) + + +async def assert_raises_message_async(except_cls, msg, coroutine): + await greenlet_spawn( + _assert_raises_message, except_cls, msg, await_, coroutine + ) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index e97821d722..8c232f3198 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -178,3 +178,7 @@ class Config(object): def skip_test(msg): raise _fixture_functions.skip_test_exception(msg) + + +def async_test(fn): + return _fixture_functions.async_test(fn) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 1583147d47..85d3374de1 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -61,6 +61,7 @@ class TestBase(object): @config.fixture() def connection(self): eng = getattr(self, "bind", config.db) + conn = eng.connect() trans = conn.begin() try: diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index b31a4ff3e3..49ff0f9757 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -48,7 +48,6 @@ testing = None util = None file_config = None - logging = None include_tags = set() exclude_tags = set() @@ -193,6 +192,12 @@ def setup_options(make_option): default=False, help="Unconditionally write/update profiling data.", ) + make_option( + "--dump-pyannotate", + type=str, + dest="dump_pyannotate", + help="Run pyannotate and dump json info to given file", + ) def configure_follower(follower_ident): @@ -378,7 +383,6 @@ def _engine_uri(options, file_config): cfg = provision.setup_config( db_url, options, file_config, provision.FOLLOWER_IDENT ) - if not config._current: cfg.set_as_current(cfg, testing) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 015598952d..3df239afa9 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -25,6 +25,11 @@ else: if typing.TYPE_CHECKING: from typing import Sequence +try: + import asyncio +except ImportError: + pass + try: import xdist # noqa @@ -101,6 +106,24 @@ def pytest_configure(config): plugin_base.set_fixture_functions(PytestFixtureFunctions) + if config.option.dump_pyannotate: + global DUMP_PYANNOTATE + DUMP_PYANNOTATE = True + + +DUMP_PYANNOTATE = False + + +@pytest.fixture(autouse=True) +def collect_types_fixture(): + if DUMP_PYANNOTATE: + from pyannotate_runtime import collect_types + + collect_types.start() + yield + if DUMP_PYANNOTATE: + collect_types.stop() + def pytest_sessionstart(session): plugin_base.post_begin() @@ -109,6 +132,31 @@ def pytest_sessionstart(session): def pytest_sessionfinish(session): plugin_base.final_process_cleanup() + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + collect_types.dump_stats(session.config.option.dump_pyannotate) + + +def pytest_collection_finish(session): + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + lib_sqlalchemy = os.path.abspath("lib/sqlalchemy") + + def _filter(filename): + filename = os.path.normpath(os.path.abspath(filename)) + if "lib/sqlalchemy" not in os.path.commonpath( + [filename, lib_sqlalchemy] + ): + return None + if "testing" in filename: + return None + + return filename + + collect_types.init_types_collection(filter_filename=_filter) + if has_xdist: import uuid @@ -518,3 +566,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): def get_current_test_name(self): return os.environ.get("PYTEST_CURRENT_TEST") + + def async_test(self, fn): + @_pytest_fn_decorator + def decorate(fn, *args, **kwargs): + asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs)) + + return decorate(fn) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 25998c07bb..36d0ce4c61 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1193,6 +1193,12 @@ class SuiteRequirements(Requirements): except ImportError: return False + @property + def async_dialect(self): + """dialect makes use of await_() to invoke operations on the DBAPI.""" + + return exclusions.closed() + @property def computed_columns(self): "Supports computed columns" diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 2eb986c74a..e6f6068c89 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -238,6 +238,8 @@ class ServerSideCursorsTest( elif self.engine.dialect.driver == "mysqldb": sscursor = __import__("MySQLdb.cursors").cursors.SSCursor return isinstance(cursor, sscursor) + elif self.engine.dialect.driver == "asyncpg": + return cursor.server_side else: return False @@ -331,29 +333,74 @@ class ServerSideCursorsTest( result.close() @testing.provide_metadata - def test_roundtrip(self): + def test_roundtrip_fetchall(self): md = self.metadata - self._fixture(True) + engine = self._fixture(True) test_table = Table( "test_table", md, Column("id", Integer, primary_key=True), Column("data", String(50)), ) - test_table.create(checkfirst=True) - test_table.insert().execute(data="data1") - test_table.insert().execute(data="data2") - eq_( - test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, "data1"), (2, "data2")], - ) - test_table.update().where(test_table.c.id == 2).values( - data=test_table.c.data + " updated" - ).execute() - eq_( - test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, "data1"), (2, "data2 updated")], + + with engine.connect() as connection: + test_table.create(connection, checkfirst=True) + connection.execute(test_table.insert(), dict(data="data1")) + connection.execute(test_table.insert(), dict(data="data2")) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2")], + ) + connection.execute( + test_table.update() + .where(test_table.c.id == 2) + .values(data=test_table.c.data + " updated") + ) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) + connection.execute(test_table.delete()) + eq_( + connection.scalar( + select([func.count("*")]).select_from(test_table) + ), + 0, + ) + + @testing.provide_metadata + def test_roundtrip_fetchmany(self): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), ) - test_table.delete().execute() - eq_(select([func.count("*")]).select_from(test_table).scalar(), 0) + + with engine.connect() as connection: + test_table.create(connection, checkfirst=True) + connection.execute( + test_table.insert(), + [dict(data="data%d" % i) for i in range(1, 20)], + ) + + result = connection.execute( + test_table.select().order_by(test_table.c.id) + ) + + eq_( + result.fetchmany(5), [(i, "data%d" % i) for i in range(1, 6)], + ) + eq_( + result.fetchmany(10), + [(i, "data%d" % i) for i in range(6, 16)], + ) + eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)]) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 48144f8859..5e6ac1eabd 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -35,6 +35,7 @@ from ... import Text from ... import Time from ... import TIMESTAMP from ... import type_coerce +from ... import TypeDecorator from ... import Unicode from ... import UnicodeText from ... import util @@ -282,6 +283,9 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): @classmethod def define_tables(cls, metadata): + class Decorated(TypeDecorator): + impl = cls.datatype + Table( "date_table", metadata, @@ -289,6 +293,7 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): "id", Integer, primary_key=True, test_needs_autoincrement=True ), Column("date_data", cls.datatype), + Column("decorated_date_data", Decorated), ) def test_round_trip(self, connection): @@ -302,6 +307,21 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): eq_(row, (compare,)) assert isinstance(row[0], type(compare)) + def test_round_trip_decorated(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"decorated_date_data": self.data} + ) + + row = connection.execute( + select(date_table.c.decorated_date_data) + ).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + def test_null(self, connection): date_table = self.tables.date_table @@ -526,6 +546,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): Float(precision=8, asdecimal=True), [15.7563, decimal.Decimal("15.7563"), None], [decimal.Decimal("15.7563"), None], + filter_=lambda n: n is not None and round(n, 4) or None, ) def test_float_as_float(self): @@ -777,6 +798,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): # ("json", {"foo": "bar"}), id_="sa", )(fn) + return fn @_index_fixtures @@ -1139,7 +1161,15 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest): and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6" ) - def test_crit_against_string_coerce_type(self): + def test_crit_against_int_basic(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + self._test_index_criteria( + and_(name == "r6", cast(col["a"], String) == "5"), "r6" + ) + + def _dont_test_crit_against_string_coerce_type(self): name = self.tables.data_table.c.name col = self.tables.data_table.c["data"] @@ -1152,15 +1182,7 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest): test_literal=False, ) - def test_crit_against_int_basic(self): - name = self.tables.data_table.c.name - col = self.tables.data_table.c["data"] - - self._test_index_criteria( - and_(name == "r6", cast(col["a"], String) == "5"), "r6" - ) - - def test_crit_against_int_coerce_type(self): + def _dont_test_crit_against_int_coerce_type(self): name = self.tables.data_table.c.name col = self.tables.data_table.c["data"] diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index ce96027455..1e3eb9a29e 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -90,6 +90,10 @@ from .compat import unquote_plus # noqa from .compat import win32 # noqa from .compat import with_metaclass # noqa from .compat import zip_longest # noqa +from .concurrency import asyncio # noqa +from .concurrency import await_fallback # noqa +from .concurrency import await_only # noqa +from .concurrency import greenlet_spawn # noqa from .deprecations import deprecated # noqa from .deprecations import deprecated_20 # noqa from .deprecations import deprecated_20_cls # noqa diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py new file mode 100644 index 0000000000..3b112ff7db --- /dev/null +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -0,0 +1,110 @@ +import asyncio +import sys +from typing import Any +from typing import Callable +from typing import Coroutine + +from .. import exc + +try: + import greenlet + + # implementation based on snaury gist at + # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef + # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 + + class _AsyncIoGreenlet(greenlet.greenlet): + def __init__(self, fn, driver): + greenlet.greenlet.__init__(self, fn, driver) + self.driver = driver + + def await_only(awaitable: Coroutine) -> Any: + """Awaits an async function in a sync method. + + The sync method must be insice a :func:`greenlet_spawn` context. + :func:`await_` calls cannot be nested. + + :param awaitable: The coroutine to call. + + """ + # this is called in the context greenlet while running fn + current = greenlet.getcurrent() + if not isinstance(current, _AsyncIoGreenlet): + raise exc.InvalidRequestError( + "greenlet_spawn has not been called; can't call await_() here." + ) + + # returns the control to the driver greenlet passing it + # a coroutine to run. Once the awaitable is done, the driver greenlet + # switches back to this greenlet with the result of awaitable that is + # then returned to the caller (or raised as error) + return current.driver.switch(awaitable) + + def await_fallback(awaitable: Coroutine) -> Any: + """Awaits an async function in a sync method. + + The sync method must be insice a :func:`greenlet_spawn` context. + :func:`await_` calls cannot be nested. + + :param awaitable: The coroutine to call. + + """ + # this is called in the context greenlet while running fn + current = greenlet.getcurrent() + if not isinstance(current, _AsyncIoGreenlet): + loop = asyncio.get_event_loop() + if loop.is_running(): + raise exc.InvalidRequestError( + "greenlet_spawn has not been called and asyncio event " + "loop is already running; can't call await_() here." + ) + return loop.run_until_complete(awaitable) + + return current.driver.switch(awaitable) + + async def greenlet_spawn(fn: Callable, *args, **kwargs) -> Any: + """Runs a sync function ``fn`` in a new greenlet. + + The sync function can then use :func:`await_` to wait for async + functions. + + :param fn: The sync callable to call. + :param \\*args: Positional arguments to pass to the ``fn`` callable. + :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable. + """ + context = _AsyncIoGreenlet(fn, greenlet.getcurrent()) + # runs the function synchronously in gl greenlet. If the execution + # is interrupted by await_, context is not dead and result is a + # coroutine to wait. If the context is dead the function has + # returned, and its result can be returned. + try: + result = context.switch(*args, **kwargs) + while not context.dead: + try: + # wait for a coroutine from await_ and then return its + # result back to it. + value = await result + except Exception: + # this allows an exception to be raised within + # the moderated greenlet so that it can continue + # its expected flow. + result = context.throw(*sys.exc_info()) + else: + result = context.switch(value) + finally: + # clean up to avoid cycle resolution by gc + del context.driver + return result + + +except ImportError: # pragma: no cover + greenlet = None + + def await_fallback(awaitable): + return asyncio.get_event_loop().run_until_complete(awaitable) + + def await_only(awaitable): + raise ValueError("Greenlet is required to use this function") + + async def greenlet_spawn(fn, *args, **kw): + raise ValueError("Greenlet is required to use this function") diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py new file mode 100644 index 0000000000..4c4ea20d12 --- /dev/null +++ b/lib/sqlalchemy/util/concurrency.py @@ -0,0 +1,21 @@ +from . import compat + + +if compat.py3k: + import asyncio + from ._concurrency_py3k import await_only + from ._concurrency_py3k import await_fallback + from ._concurrency_py3k import greenlet + from ._concurrency_py3k import greenlet_spawn +else: + asyncio = None + greenlet = None + + def await_only(thing): + return thing + + def await_fallback(thing): + return thing + + def greenlet_spawn(fn, *args, **kw): + raise ValueError("Cannot use this function in py2.") diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 3433657d6b..5f71c7bd6f 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -21,7 +21,10 @@ condition. from collections import deque from time import time as _time +from . import compat from .compat import threading +from .concurrency import asyncio +from .concurrency import await_fallback __all__ = ["Empty", "Full", "Queue"] @@ -196,3 +199,64 @@ class Queue: else: # FIFO return self.queue.popleft() + + +class AsyncAdaptedQueue: + await_ = await_fallback + + def __init__(self, maxsize=0, use_lifo=False): + if use_lifo: + self._queue = asyncio.LifoQueue(maxsize=maxsize) + else: + self._queue = asyncio.Queue(maxsize=maxsize) + self.maxsize = maxsize + self.empty = self._queue.empty + self.full = self._queue.full + self.qsize = self._queue.qsize + + def put_nowait(self, item): + try: + return self._queue.put_nowait(item) + except asyncio.queues.QueueFull as err: + compat.raise_( + Full(), replace_context=err, + ) + + def put(self, item, block=True, timeout=None): + if not block: + return self.put_nowait(item) + + try: + if timeout: + return self.await_( + asyncio.wait_for(self._queue.put(item), timeout) + ) + else: + return self.await_(self._queue.put(item)) + except asyncio.queues.QueueFull as err: + compat.raise_( + Full(), replace_context=err, + ) + + def get_nowait(self): + try: + return self._queue.get_nowait() + except asyncio.queues.QueueEmpty as err: + compat.raise_( + Empty(), replace_context=err, + ) + + def get(self, block=True, timeout=None): + if not block: + return self.get_nowait() + try: + if timeout: + return self.await_( + asyncio.wait_for(self._queue.get(), timeout) + ) + else: + return self.await_(self._queue.get()) + except asyncio.queues.QueueEmpty as err: + compat.raise_( + Empty(), replace_context=err, + ) diff --git a/setup.cfg b/setup.cfg index 9cbdbd838d..387f422efd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,10 +38,12 @@ python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.* package_dir = =lib install_requires = - importlib-metadata;python_version<"3.8" - + importlib-metadata;python_version<"3.8" + greenlet [options.extras_require] +asyncio = + greenlet mssql = pyodbc mssql_pymssql = pymssql mssql_pyodbc = pyodbc @@ -53,6 +55,9 @@ oracle = cx_oracle>=7;python_version>="3" postgresql = psycopg2>=2.7 postgresql_pg8000 = pg8000 +postgresql_asyncpg = + asyncpg;python_version>="3" + greenlet postgresql_psycopg2binary = psycopg2-binary postgresql_psycopg2cffi = psycopg2cffi pymysql = pymysql @@ -110,6 +115,7 @@ default = sqlite:///:memory: sqlite = sqlite:///:memory: sqlite_file = sqlite:///querytest.db postgresql = postgresql://scott:tiger@127.0.0.1:5432/test +asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py new file mode 100644 index 0000000000..10b89291e0 --- /dev/null +++ b/test/base/test_concurrency_py3k.py @@ -0,0 +1,103 @@ +from sqlalchemy import exc +from sqlalchemy.testing import async_test +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.util import await_fallback +from sqlalchemy.util import await_only +from sqlalchemy.util import greenlet_spawn + + +async def run1(): + return 1 + + +async def run2(): + return 2 + + +def go(*fns): + return sum(await_only(fn()) for fn in fns) + + +class TestAsyncioCompat(fixtures.TestBase): + @async_test + async def test_ok(self): + + eq_(await greenlet_spawn(go, run1, run2), 3) + + @async_test + async def test_async_error(self): + async def err(): + raise ValueError("an error") + + with expect_raises_message(ValueError, "an error"): + await greenlet_spawn(go, run1, err) + + @async_test + async def test_sync_error(self): + def go(): + await_only(run1()) + raise ValueError("sync error") + + with expect_raises_message(ValueError, "sync error"): + await greenlet_spawn(go) + + def test_await_fallback_no_greenlet(self): + to_await = run1() + await_fallback(to_await) + + def test_await_only_no_greenlet(self): + to_await = run1() + with expect_raises_message( + exc.InvalidRequestError, + r"greenlet_spawn has not been called; can't call await_\(\) here.", + ): + await_only(to_await) + + # ensure no warning + await_fallback(to_await) + + @async_test + async def test_await_fallback_error(self): + to_await = run1() + + await to_await + + async def inner_await(): + nonlocal to_await + to_await = run1() + await_fallback(to_await) + + def go(): + await_fallback(inner_await()) + + with expect_raises_message( + exc.InvalidRequestError, + "greenlet_spawn has not been called and asyncio event loop", + ): + await greenlet_spawn(go) + + await to_await + + @async_test + async def test_await_only_error(self): + to_await = run1() + + await to_await + + async def inner_await(): + nonlocal to_await + to_await = run1() + await_only(to_await) + + def go(): + await_only(inner_await()) + + with expect_raises_message( + exc.InvalidRequestError, + r"greenlet_spawn has not been called; can't call await_\(\) here.", + ): + await greenlet_spawn(go) + + await to_await diff --git a/test/conftest.py b/test/conftest.py index 5c6b89fde7..92d3e07768 100755 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,6 +11,11 @@ import sys import pytest + +collect_ignore_glob = [] +if sys.version_info[0] < 3: + collect_ignore_glob.append("*_py3k.py") + pytest.register_assert_rewrite("sqlalchemy.testing.assertions") diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index f6aba550ec..57c243442e 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -937,9 +937,7 @@ $$ LANGUAGE plpgsql; stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric) assert_raises(exc.InvalidRequestError, connection.execute, stmt) - @testing.only_if( - "postgresql >= 8.2", "requires standard_conforming_strings" - ) + @testing.only_on("postgresql+psycopg2") def test_serial_integer(self): class BITD(TypeDecorator): impl = Integer diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index ffd32813c0..5ab65f9e34 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -738,17 +738,14 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): def teardown_class(cls): metadata.drop_all() - @testing.fails_on("postgresql+pg8000", "uses positional") + @testing.requires.pyformat_paramstyle def test_expression_pyformat(self): self.assert_compile( matchtable.c.title.match("somstr"), "matchtable.title @@ to_tsquery(%(title_1)s" ")", ) - @testing.fails_on("postgresql+psycopg2", "uses pyformat") - @testing.fails_on("postgresql+pypostgresql", "uses pyformat") - @testing.fails_on("postgresql+pygresql", "uses pyformat") - @testing.fails_on("postgresql+psycopg2cffi", "uses pyformat") + @testing.requires.format_paramstyle def test_expression_positional(self): self.assert_compile( matchtable.c.title.match("somstr"), diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 95486b1979..503477833d 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -27,6 +27,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import Text from sqlalchemy import text +from sqlalchemy import type_coerce from sqlalchemy import TypeDecorator from sqlalchemy import types from sqlalchemy import Unicode @@ -774,7 +775,12 @@ class RegClassTest(fixtures.TestBase): regclass = cast("pg_class", postgresql.REGCLASS) oid = self._scalar(cast(regclass, postgresql.OID)) assert isinstance(oid, int) - eq_(self._scalar(cast(oid, postgresql.REGCLASS)), "pg_class") + eq_( + self._scalar( + cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS) + ), + "pg_class", + ) def test_cast_whereclause(self): pga = Table( @@ -1801,10 +1807,13 @@ class ArrayEnum(fixtures.TestBase): testing.db, ) + @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") @testing.combinations( - sqltypes.ARRAY, postgresql.ARRAY, _ArrayOfEnum, argnames="array_cls" + sqltypes.ARRAY, + postgresql.ARRAY, + (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")), + argnames="array_cls", ) - @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") @testing.provide_metadata def test_array_of_enums(self, array_cls, enum_cls, connection): tbl = Table( @@ -1845,6 +1854,8 @@ class ArrayEnum(fixtures.TestBase): sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc()) eq_(connection.scalar(sel), [MyEnum.a]) + self.metadata.drop_all(connection) + class ArrayJSON(fixtures.TestBase): __backend__ = True diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index af6bc1d369..624fa90053 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -10,8 +10,8 @@ from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import util from sqlalchemy.sql import util as sql_util +from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message -from sqlalchemy.testing import assert_raises_return from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import eq_regex @@ -104,7 +104,7 @@ class LogParamsTest(fixtures.TestBase): def test_log_positional_array(self): with self.eng.connect() as conn: - exc_info = assert_raises_return( + exc_info = assert_raises( tsa.exc.DBAPIError, conn.execute, tsa.text("SELECT * FROM foo WHERE id IN :foo AND bar=:bar"), diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index d91105f411..48eb485cb7 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1356,7 +1356,14 @@ class InvalidateDuringResultTest(fixtures.TestBase): "cx_oracle 6 doesn't allow a close like this due to open cursors", ) @testing.fails_if( - ["+mysqlconnector", "+mysqldb", "+cymysql", "+pymysql", "+pg8000"], + [ + "+mysqlconnector", + "+mysqldb", + "+cymysql", + "+pymysql", + "+pg8000", + "+asyncpg", + ], "Buffers the result set and doesn't check for connection close", ) def test_invalidate_on_results(self): @@ -1365,5 +1372,8 @@ class InvalidateDuringResultTest(fixtures.TestBase): for x in range(20): result.fetchone() self.engine.test_shutdown() - _assert_invalidated(result.fetchone) - assert conn.invalidated + try: + _assert_invalidated(result.fetchone) + assert conn.invalidated + finally: + conn.invalidate() diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 8981028d2c..cd144e45f4 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -461,11 +461,8 @@ class TransactionTest(fixtures.TestBase): assert not savepoint.is_active if util.py3k: - # driver error - assert exc_.__cause__ - - # and that's it, no other context - assert not exc_.__cause__.__context__ + # ensure cause comes from the DBAPI + assert isinstance(exc_.__cause__, testing.db.dialect.dbapi.Error) def test_retains_through_options(self, local_connection): connection = local_connection diff --git a/test/ext/asyncio/__init__.py b/test/ext/asyncio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py new file mode 100644 index 0000000000..ec513cb649 --- /dev/null +++ b/test/ext/asyncio/test_engine_py3k.py @@ -0,0 +1,340 @@ +from sqlalchemy import Column +from sqlalchemy import delete +from sqlalchemy import exc +from sqlalchemy import func +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import union_all +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import exc as asyncio_exc +from sqlalchemy.testing import async_test +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.asyncio import assert_raises_message_async + + +class EngineFixture(fixtures.TablesTest): + __requires__ = ("async_dialect",) + + @testing.fixture + def async_engine(self): + return create_async_engine(testing.db.url) + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column("user_id", Integer, primary_key=True, autoincrement=False), + Column("user_name", String(20)), + ) + + @classmethod + def insert_data(cls, connection): + users = cls.tables.users + with connection.begin(): + connection.execute( + users.insert(), + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(1, 20) + ], + ) + + +class AsyncEngineTest(EngineFixture): + __backend__ = True + + @async_test + async def test_connect_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + result = await conn.execute(select(1)) + eq_(result.scalar(), 1) + + @async_test + async def test_connect_plain(self, async_engine): + conn = await async_engine.connect() + try: + result = await conn.execute(select(1)) + eq_(result.scalar(), 1) + finally: + await conn.close() + + @async_test + async def test_connection_not_started(self, async_engine): + + conn = async_engine.connect() + testing.assert_raises_message( + asyncio_exc.AsyncContextNotStarted, + "AsyncConnection context has not been started and " + "object has not been awaited.", + conn.begin, + ) + + @async_test + async def test_transaction_commit(self, async_engine): + users = self.tables.users + + async with async_engine.begin() as conn: + await conn.execute(delete(users)) + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 0) + + @async_test + async def test_savepoint_rollback_noctx(self, async_engine): + users = self.tables.users + + async with async_engine.begin() as conn: + + savepoint = await conn.begin_nested() + await conn.execute(delete(users)) + await savepoint.rollback() + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 19) + + @async_test + async def test_savepoint_commit_noctx(self, async_engine): + users = self.tables.users + + async with async_engine.begin() as conn: + + savepoint = await conn.begin_nested() + await conn.execute(delete(users)) + await savepoint.commit() + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 0) + + @async_test + async def test_transaction_rollback(self, async_engine): + users = self.tables.users + + async with async_engine.connect() as conn: + trans = conn.begin() + await trans.start() + await conn.execute(delete(users)) + await trans.rollback() + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 19) + + @async_test + async def test_conn_transaction_not_started(self, async_engine): + + async with async_engine.connect() as conn: + trans = conn.begin() + await assert_raises_message_async( + asyncio_exc.AsyncContextNotStarted, + "AsyncTransaction context has not been started " + "and object has not been awaited.", + trans.rollback(), + ) + + +class AsyncResultTest(EngineFixture): + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_all(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars(1) + + all_ = await result.all() + if filter_ == "mappings": + eq_( + all_, + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(1, 20) + ], + ) + elif filter_ == "scalars": + eq_( + all_, ["name%d" % i for i in range(1, 20)], + ) + else: + eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) + + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_aiter(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars(1) + + rows = [] + + async for row in result: + rows.append(row) + + if filter_ == "mappings": + eq_( + rows, + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(1, 20) + ], + ) + elif filter_ == "scalars": + eq_( + rows, ["name%d" % i for i in range(1, 20)], + ) + else: + eq_(rows, [(i, "name%d" % i) for i in range(1, 20)]) + + @testing.combinations((None,), ("mappings",), argnames="filter_") + @async_test + async def test_keys(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + + eq_(result.keys(), ["user_id", "user_name"]) + + @async_test + async def test_unique_all(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + union_all(select(users), select(users)).order_by( + users.c.user_id + ) + ) + + all_ = await result.unique().all() + eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) + + @async_test + async def test_columns_all(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + all_ = await result.columns(1).all() + eq_(all_, [("name%d" % i,) for i in range(1, 20)]) + + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_partitions(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars(1) + + check_result = [] + async for partition in result.partitions(5): + check_result.append(partition) + + if filter_ == "mappings": + eq_( + check_result, + [ + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(a, b) + ] + for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)] + ], + ) + elif filter_ == "scalars": + eq_( + check_result, + [ + ["name%d" % i for i in range(a, b)] + for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)] + ], + ) + else: + eq_( + check_result, + [ + [(i, "name%d" % i) for i in range(a, b)] + for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)] + ], + ) + + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_one_success(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + select(users).limit(1).order_by(users.c.user_name) + ) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars() + u1 = await result.one() + + if filter_ == "mappings": + eq_(u1, {"user_id": 1, "user_name": "name%d" % 1}) + elif filter_ == "scalars": + eq_(u1, 1) + else: + eq_(u1, (1, "name%d" % 1)) + + @async_test + async def test_one_no_result(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + select(users).where(users.c.user_name == "nonexistent") + ) + + async def go(): + await result.one() + + await assert_raises_message_async( + exc.NoResultFound, + "No row was found when one was required", + go(), + ) + + @async_test + async def test_one_multi_result(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + select(users).where(users.c.user_name.in_(["name3", "name5"])) + ) + + async def go(): + await result.one() + + await assert_raises_message_async( + exc.MultipleResultsFound, + "Multiple rows were found when exactly one was required", + go(), + ) diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py new file mode 100644 index 0000000000..e8caaca3e4 --- /dev/null +++ b/test/ext/asyncio/test_session_py3k.py @@ -0,0 +1,200 @@ +from sqlalchemy import exc +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy import testing +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import selectinload +from sqlalchemy.testing import async_test +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import is_ +from ...orm import _fixtures + + +class AsyncFixture(_fixtures.FixtureTest): + __requires__ = ("async_dialect",) + + @classmethod + def setup_mappers(cls): + cls._setup_stock_mapping() + + @testing.fixture + def async_engine(self): + return create_async_engine(testing.db.url) + + @testing.fixture + def async_session(self, async_engine): + return AsyncSession(async_engine) + + +class AsyncSessionTest(AsyncFixture): + def test_requires_async_engine(self, async_engine): + testing.assert_raises_message( + exc.ArgumentError, + "AsyncEngine expected, got Engine", + AsyncSession, + bind=async_engine.sync_engine, + ) + + +class AsyncSessionQueryTest(AsyncFixture): + @async_test + async def test_execute(self, async_session): + User = self.classes.User + + stmt = ( + select(User) + .options(selectinload(User.addresses)) + .order_by(User.id) + ) + + result = await async_session.execute(stmt) + eq_(result.scalars().all(), self.static.user_address_result) + + @async_test + async def test_stream_partitions(self, async_session): + User = self.classes.User + + stmt = ( + select(User) + .options(selectinload(User.addresses)) + .order_by(User.id) + ) + + result = await async_session.stream(stmt) + + assert_result = [] + async for partition in result.scalars().partitions(3): + assert_result.append(partition) + + eq_( + assert_result, + [ + self.static.user_address_result[0:3], + self.static.user_address_result[3:], + ], + ) + + +class AsyncSessionTransactionTest(AsyncFixture): + run_inserts = None + + @async_test + async def test_trans(self, async_session, async_engine): + async with async_engine.connect() as outer_conn: + + User = self.classes.User + + async with async_session.begin(): + + eq_(await outer_conn.scalar(select(func.count(User.id))), 0) + + u1 = User(name="u1") + + async_session.add(u1) + + result = await async_session.execute(select(User)) + eq_(result.scalar(), u1) + + eq_(await outer_conn.scalar(select(func.count(User.id))), 1) + + @async_test + async def test_commit_as_you_go(self, async_session, async_engine): + async with async_engine.connect() as outer_conn: + + User = self.classes.User + + eq_(await outer_conn.scalar(select(func.count(User.id))), 0) + + u1 = User(name="u1") + + async_session.add(u1) + + result = await async_session.execute(select(User)) + eq_(result.scalar(), u1) + + await async_session.commit() + + eq_(await outer_conn.scalar(select(func.count(User.id))), 1) + + @async_test + async def test_trans_noctx(self, async_session, async_engine): + async with async_engine.connect() as outer_conn: + + User = self.classes.User + + trans = await async_session.begin() + try: + eq_(await outer_conn.scalar(select(func.count(User.id))), 0) + + u1 = User(name="u1") + + async_session.add(u1) + + result = await async_session.execute(select(User)) + eq_(result.scalar(), u1) + finally: + await trans.commit() + + eq_(await outer_conn.scalar(select(func.count(User.id))), 1) + + @async_test + async def test_flush(self, async_session): + User = self.classes.User + + async with async_session.begin(): + u1 = User(name="u1") + + async_session.add(u1) + + conn = await async_session.connection() + + eq_(await conn.scalar(select(func.count(User.id))), 0) + + await async_session.flush() + + eq_(await conn.scalar(select(func.count(User.id))), 1) + + @async_test + async def test_refresh(self, async_session): + User = self.classes.User + + async with async_session.begin(): + u1 = User(name="u1") + + async_session.add(u1) + await async_session.flush() + + conn = await async_session.connection() + + await conn.execute( + update(User) + .values(name="u2") + .execution_options(synchronize_session=None) + ) + + eq_(u1.name, "u1") + + await async_session.refresh(u1) + + eq_(u1.name, "u2") + + eq_(await conn.scalar(select(func.count(User.id))), 1) + + @async_test + async def test_merge(self, async_session): + User = self.classes.User + + async with async_session.begin(): + u1 = User(id=1, name="u1") + + async_session.add(u1) + + async with async_session.begin(): + new_u = User(id=1, name="new u1") + + new_u_merged = await async_session.merge(new_u) + + is_(new_u_merged, u1) + eq_(u1.name, "new u1") diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 75dca1c990..047ef25aee 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -1421,6 +1421,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): # this would work with Firebird if you do literal_column('1') # instead case_stmt = case([(Document.title.in_(subq), True)], else_=False) + s.query(Document).update( {"flag": case_stmt}, synchronize_session=False ) diff --git a/test/requirements.py b/test/requirements.py index 28f955fa5f..fdb7c2ff33 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -198,7 +198,7 @@ class DefaultRequirements(SuiteRequirements): "mysql+pymysql", "mysql+cymysql", "mysql+mysqlconnector", - "postgresql", + "postgresql+pg8000", ] ) @@ -1162,20 +1162,6 @@ class DefaultRequirements(SuiteRequirements): "Firebird still has FP inaccuracy even " "with only four decimal places", ), - ( - "mssql+pyodbc", - None, - None, - "mssql+pyodbc has FP inaccuracy even with " - "only four decimal places ", - ), - ( - "mssql+pymssql", - None, - None, - "mssql+pymssql has FP inaccuracy even with " - "only four decimal places ", - ), ( "postgresql+pg8000", None, @@ -1280,6 +1266,12 @@ class DefaultRequirements(SuiteRequirements): return only_if(check_range_types) + @property + def async_dialect(self): + """dialect makes use of await_() to invoke operations on the DBAPI.""" + + return only_on(["postgresql+asyncpg"]) + @property def oracle_test_dblink(self): return skip_if( diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 676c46db65..aa1c0d48d7 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -948,7 +948,7 @@ class PKDefaultTest(fixtures.TablesTest): metadata, Column( "date_id", - DateTime, + DateTime(timezone=True), default=text("current_timestamp"), primary_key=True, ), diff --git a/tox.ini b/tox.ini index 0ce79d7a1a..e3539ce611 100644 --- a/tox.ini +++ b/tox.ini @@ -17,9 +17,11 @@ usedevelop= deps=pytest!=3.9.1,!=3.9.2 pytest-xdist + greenlet mock; python_version < '3.3' importlib_metadata; python_version < '3.8' postgresql: .[postgresql] + postgresql: .[postgresql_asyncpg] mysql: .[mysql] mysql: .[pymysql] oracle: .[oracle] @@ -56,7 +58,7 @@ setenv= cov: COVERAGE={[testenv]cov_args} sqlite: SQLITE={env:TOX_SQLITE:--db sqlite} sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} - postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} + postgresql: POSTGRESQL={env:TOX_POSTGRESQL_W_ASYNCPG:--db postgresql} mysql: MYSQL={env:TOX_MYSQL:--db mysql --db pymysql} oracle: ORACLE={env:TOX_ORACLE:--db oracle} mssql: MSSQL={env:TOX_MSSQL:--db mssql} @@ -68,7 +70,7 @@ setenv= # tox as of 2.0 blocks all environment variables from the # outside, unless they are here (or in TOX_TESTENV_PASSENV, # wildcards OK). Need at least these -passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS +passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL_W_ASYNCPG TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS # for nocext, we rm *.so in lib in case we are doing usedevelop=True commands=