Implement rudimentary asyncio support w/ asyncpg
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Sat, 4 Jul 2020 16:21:36 +0000 (12:21 -0400)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Thu, 13 Aug 2020 22:41:53 +0000 (18:41 -0400)
Using the approach introduced at
https://gist.github.com/zzzeek/6287e28054d3baddc07fa21a7227904e

We can now create asyncio endpoints that are then handled
in "implicit IO" form within the majority of the Core internals.
Then coroutines are re-exposed at the point at which we call
into asyncpg methods.

Patch includes:

* asyncpg dialect

* asyncio package

* engine, result, ORM session classes

* new test fixtures, tests

* some work with pep-484 and a short plugin for the
  pyannotate package, which seems to have so-so results

Change-Id: Idbcc0eff72c4cad572914acdd6f40ddb1aef1a7d
Fixes: #3414
67 files changed:
.gitignore
doc/build/changelog/migration_14.rst
doc/build/changelog/migration_20.rst
doc/build/changelog/unreleased_14/3414.rst [new file with mode: 0644]
doc/build/conf.py
doc/build/core/connections.rst
doc/build/dialects/postgresql.rst
doc/build/index.rst
doc/build/intro.rst
doc/build/orm/examples.rst
doc/build/orm/extensions/asyncio.rst [new file with mode: 0644]
doc/build/orm/extensions/index.rst
examples/asyncio/__init__.py [new file with mode: 0644]
examples/asyncio/async_orm.py [new file with mode: 0644]
examples/asyncio/basic.py [new file with mode: 0644]
examples/asyncio/greenlet_orm.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/ext/asyncio/__init__.py [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/base.py [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/engine.py [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/exc.py [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/result.py [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/session.py [new file with mode: 0644]
lib/sqlalchemy/future/__init__.py
lib/sqlalchemy/future/engine.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/pool/__init__.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/asyncio.py [new file with mode: 0644]
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/testing/suite/test_types.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_concurrency_py3k.py [new file with mode: 0644]
lib/sqlalchemy/util/concurrency.py [new file with mode: 0644]
lib/sqlalchemy/util/queue.py
setup.cfg
test/base/test_concurrency_py3k.py [new file with mode: 0644]
test/conftest.py
test/dialect/postgresql/test_dialect.py
test/dialect/postgresql/test_query.py
test/dialect/postgresql/test_types.py
test/engine/test_logging.py
test/engine/test_reconnect.py
test/engine/test_transaction.py
test/ext/asyncio/__init__.py [new file with mode: 0644]
test/ext/asyncio/test_engine_py3k.py [new file with mode: 0644]
test/ext/asyncio/test_session_py3k.py [new file with mode: 0644]
test/orm/test_update_delete.py
test/requirements.py
test/sql/test_defaults.py
tox.ini

index 4931017b78f32851a2a11bb02b46ca429a42fba9..3916fe299b521f3907396e3969f26ab70ab843b9 100644 (file)
@@ -39,3 +39,5 @@ test/test_schema.db
 /.ipynb_checkpoints/
 *.ipynb
 /querytest.db
+/.mypy_cache
+/.pytest_cache
\ No newline at end of file
index 5753cb089d894eb7f68da12670522ab966b2391b..14584fd430b8e7d8b7fdd0e6163639b5f2e3e636 100644 (file)
@@ -20,8 +20,8 @@ What's New in SQLAlchemy 1.4?
 
     For the current status of SQLAlchemy 2.0, see :ref:`migration_20_toplevel`.
 
-Behavioral Changes - General
-============================
+Major API changes and features - General
+=========================================
 
 .. _change_5159:
 
@@ -224,6 +224,92 @@ driven in order to support this new feature.
 :ticket:`4808`
 :ticket:`5004`
 
+.. _change_3414:
+
+Asynchronous IO Support for Core and ORM
+------------------------------------------
+
+SQLAlchemy now supports Python ``asyncio``-compatible database drivers using an
+all-new asyncio front-end interface to :class:`_engine.Connection` for Core
+usage as well as :class:`_orm.Session` for ORM use, using the
+:class:`_asyncio.AsyncConnection` and :class:`_asyncio.AsyncSession` objects.
+
+.. note::  The new asyncio feature should be considered **alpha level** for
+   the initial releases of SQLAlchemy 1.4.   This is super new stuff that uses
+   some previously unfamiliar programming techniques.
+
+The initial database API supported is the :ref:`dialect-postgresql-asyncpg`
+asyncio driver for PostgreSQL.
+
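+For example, a minimal sketch of creating an asyncpg engine (the database
+URL here is illustrative)::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+
+    engine = create_async_engine(
+        "postgresql+asyncpg://scott:tiger@localhost/test"
+    )
+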
+The internal features of SQLAlchemy are fully integrated by making use of
+the `greenlet <https://greenlet.readthedocs.io/en/latest/>`_ library in order
+to adapt the flow of execution within SQLAlchemy's internals to propagate
+asyncio ``await`` keywords outwards from the database driver to the end-user
+API, which features ``async`` methods.  Using this approach, the asyncpg
+driver is fully operational within SQLAlchemy's own test suite and is
+compatible with most psycopg2 features.   The approach was vetted and
+improved upon by developers of the greenlet project, for which SQLAlchemy
+is appreciative.
+
+.. sidebar:: greenlets are good
+
+  Don't confuse the greenlet_ library with event-based IO libraries that build
+  on top of it such as ``gevent`` and ``eventlet``; while the use of these
+  libraries with SQLAlchemy is common, SQLAlchemy's asyncio integration
+  **does not** make use of these event based systems in any way. The asyncio
+  API integrates with the user-provided event loop, typically Python's own
+  asyncio event loop, without the use of additional threads or event systems.
+  The approach involves a single greenlet context switch per ``await`` call,
+  and the extension which makes it possible is less than 20 lines of code.
+
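+To illustrate, a minimal sketch of such a greenlet bridge is below.  This is
+illustrative only; the actual implementation lives in
+``lib/sqlalchemy/util/_concurrency_py3k.py`` and additionally handles
+exception propagation and safety checks, and the ``_IoGreenlet`` class name
+here is invented for the example::
+
+    import greenlet
+
+    class _IoGreenlet(greenlet.greenlet):
+        # a greenlet that remembers the "driver" greenlet, i.e. the
+        # one executing inside the asyncio event loop
+        def __init__(self, fn, driver):
+            greenlet.greenlet.__init__(self, fn, driver)
+            self.driver = driver
+
+    def await_only(awaitable):
+        # called by sync-style code; ship the awaitable out to the
+        # driver greenlet and suspend until the result is switched back
+        return greenlet.getcurrent().driver.switch(awaitable)
+
+    async def greenlet_spawn(fn, *args):
+        # run the sync-style fn inside a greenlet; every awaitable it
+        # sends out via await_only() is awaited here, on the event loop
+        context = _IoGreenlet(fn, greenlet.getcurrent())
+        result = context.switch(*args)
+        while not context.dead:
+            value = await result            # the real await occurs here
+            result = context.switch(value)  # resume the sync code
+        return result
+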
+The user-facing ``async`` API itself is focused around IO-oriented methods such
+as :meth:`_asyncio.AsyncEngine.connect` and
+:meth:`_asyncio.AsyncConnection.execute`.   The new Core constructs strictly
+support :term:`2.0 style` usage only, which means that all statements must be
+invoked using a connection object, in this case
+:class:`_asyncio.AsyncConnection`.
+
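+For example, a brief sketch (``stmt`` here standing in for any Core
+statement construct)::
+
+    async with engine.connect() as conn:
+        result = await conn.execute(stmt)
+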
+Within the ORM, :term:`2.0 style` query execution is
+supported, using :func:`_sql.select` constructs in conjunction with
+:meth:`_asyncio.AsyncSession.execute`; the legacy :class:`_orm.Query`
+object itself is not supported by the :class:`_asyncio.AsyncSession` class.
+
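+For example, a brief sketch, assuming a mapped class ``A``::
+
+    result = await session.execute(select(A).order_by(A.id))
+    a1 = result.scalars().first()
+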
+ORM features such as lazy loading of related attributes, as well as the
+implicit loading of expired attributes, are by definition disallowed in the
+traditional asyncio programming model, as they indicate IO operations that
+would run implicitly within the scope of a Python ``getattr()`` operation.
+To overcome this, the **traditional** asyncio application should make
+judicious use of :ref:`eager loading <loading_toplevel>` techniques as well
+as forego the use of features such as :ref:`expire on commit
+<session_committing>` so that such loads are not needed.
+
+For the asyncio application developer who **chooses to break** with
+tradition, the new API provides a **strictly optional
+feature** such that applications that wish to make use of such ORM features
+can opt to organize database-related code into functions which can then be
+run within greenlets using the :meth:`_asyncio.AsyncSession.run_sync`
+method. See the ``greenlet_orm.py`` example at :ref:`examples_asyncio`
+for a demonstration.
+
+Support for asynchronous cursors is also provided using new methods
+:meth:`_asyncio.AsyncConnection.stream` and
+:meth:`_asyncio.AsyncSession.stream`, which support a new
+:class:`_asyncio.AsyncResult` object that itself provides awaitable
+versions of common methods like
+:meth:`_asyncio.AsyncResult.all` and
+:meth:`_asyncio.AsyncResult.fetchmany`.   Both Core and ORM are integrated
+with the feature which corresponds to the use of "server side cursors"
+in traditional SQLAlchemy.
+
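+For example, a brief sketch, assuming a Core table ``t1`` and an existing
+:class:`_asyncio.AsyncConnection` named ``conn``::
+
+    async_result = await conn.stream(select(t1))
+
+    # fetch methods on AsyncResult are awaitable
+    some_rows = await async_result.fetchmany(100)
+    remaining_rows = await async_result.all()
+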
+.. seealso::
+
+  :ref:`asyncio_toplevel`
+
+  :ref:`examples_asyncio`
+
+
+
+:ticket:`3414`
 
 .. _change_deferred_construction:
 
index 535756f53e171aab2a465ccf81a0f8a8d381127e..7b3d23c8ca68af8e44096c53d184a8d976b84b72 100644 (file)
@@ -1252,10 +1252,7 @@ Asyncio Support
 
 .. admonition:: Certainty: definite
 
-  A surprising development will allow asyncio support including with the
-  ORM to be fully implemented.   There will even be a **completely optional**
-  path to having lazy loading be available, for those willing to make use of
-  some "controversial" patterns.
+  This is now implemented in 1.4.
 
 There was previously an entire section here detailing how asyncio is a nice to
 have, but not really necessary from a technical standpoint, there are some
@@ -1267,113 +1264,7 @@ an entirely separate version of everything be maintained, therefore this makes
 it feasible to deliver this feature to those users who prefer an all-async
 application style without impact on the traditional blocking architecture.
 
-The proof of concept at https://gist.github.com/zzzeek/4e89ce6226826e7a8df13e1b573ad354
-illustrates how to write an asyncio application that makes use of a pure asyncio
-driver (asyncpg), with part of the code **in between** remaining as sync code
-without the use of any await/async keywords.  The central technique involves
-minimal use of a greenlet (e.g. stackless Python) to perform the necessary
-context switches when an "await" occurs.   The approach has been vetted
-both with asyncio developers as well as greenlet developers, the latter
-of which contributed a great degree of simplification to the already simple recipe
-such that it can context switch async coroutines with no decrease in performance.
-
-The proof of concept has then been expanded to work within SQLAlchemy Core
-and is presently in a Gerrit review.   A SQLAlchemy dialect for the asyncpg
-driver has been written and it passes most tests.
-
-Example ORM use will look similar to the following; this example is already
-runnable with the in-review codebase::
-
-    import asyncio
-
-    from sqlalchemy.asyncio import create_async_engine
-    from sqlalchemy.asyncio import AsyncSession
-    # ... other imports ...
-
-    async def async_main():
-        engine = create_async_engine(
-            "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
-        )
-
-
-        # assume a typical ORM model with classes A and B
-
-        session = AsyncSession(engine)
-        session.add_all(
-            [
-                A(bs=[B(), B()], data="a1"),
-                A(bs=[B()], data="a2"),
-                A(bs=[B(), B()], data="a3"),
-            ]
-        )
-        await session.commit()
-        stmt = select(A).options(selectinload(A.bs))
-        result = await session.execute(stmt)
-        for a1 in result.scalars():
-            print(a1)
-            for b1 in a1.bs:
-                print(b1)
-
-        result = await session.execute(select(A).order_by(A.id))
-
-        a1 = result.scalars().first()
-        a1.data = "new data"
-        await session.commit()
-
-    asyncio.run(async_main())
-
-The "controversial" feature, if provided, would include that the "greenlet"
-context would be supplied as front-facing API.  This would allow an asyncio
-application to spawn a greenlet that contains sync-code, which could use the
-Core and ORM in a fully traditional manner including that lazy loading
-for columns and relationships would be present.  This mode of use is
-somewhat similar to running an application under an event-based
-programming library such as gevent or eventlet, however the underlying
-network calls would be within a pure asyncio context, i.e. like that of the
-asyncpg driver.   An example of this use, which is also runnable with
-the in-review codebase::
-
-    import asyncio
-
-    from sqlalchemy.asyncio import greenlet_spawn
-
-    from sqlalchemy import create_engine
-    from sqlalchemy.orm import Session
-    # ... other imports ...
-
-    def main():
-        # standard "sync" engine with the "async" driver.
-        engine = create_engine(
-            "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
-        )
-
-        # assume a typical ORM model with classes A and B
-
-        session = Session(engine)
-        session.add_all(
-            [
-                A(bs=[B(), B()], data="a1"),
-                A(bs=[B()], data="a2"),
-                A(bs=[B(), B()], data="a3"),
-            ]
-        )
-        session.commit()
-        for a1 in session.query(A).all():
-            print("a: %s" % a1)
-            print("bs: %s" % (a1.bs))  # emits a lazyload.
-
-    asyncio.run(greenlet_spawn(main))
-
-
-Above, we see a ``main()`` function that contains within it a 100% normal
-looking Python program using the SQLAlchemy ORM, using plain ORM imports and
-basically absolutely nothing out of the ordinary.  It just happens to be called
-from inside of an ``asyncio.run()`` call rather than directly, and it uses a
-DBAPI that is only compatible with asyncio.   There is no "monkeypatching" or
-anything else like that involved.    Any asyncio program can opt
-to place its database-related business methods into the above pattern,
-if preferred, rather than using the asyncio SQLAlchemy API directly.  This
-technique is also being adapted to other frameworks such as Flask and will
-hopefully lead to greater interoperability between blocking and non-blocking
-libraries and frameworks.
+SQLAlchemy 1.4 now includes full asyncio capability with initial support
+using the :ref:`dialect-postgresql-asyncpg` Python database driver;
+see :ref:`asyncio_toplevel`.
 
diff --git a/doc/build/changelog/unreleased_14/3414.rst b/doc/build/changelog/unreleased_14/3414.rst
new file mode 100644 (file)
index 0000000..a278244
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: feature, engine, orm
+    :tickets: 3414
+
+    SQLAlchemy now includes support for Python asyncio within both Core and
+    ORM, using the included :ref:`asyncio extension <asyncio_toplevel>`. The
+    extension makes use of the `greenlet
+    <https://greenlet.readthedocs.io/en/latest/>`_ library in order to adapt
+    SQLAlchemy's sync-oriented internals such that an asyncio interface that
+    ultimately interacts with an asyncio database adapter is now feasible.  The
+    single driver supported at the moment is the
+    :ref:`dialect-postgresql-asyncpg` driver for PostgreSQL.
+
+    .. seealso::
+
+        :ref:`change_3414`
+
index 13d573296077061a073eaf35de86b3ad73b2616b..d4fdf58a00a7a462e91a4fc5fd982e5c27f210ec 100644 (file)
@@ -106,6 +106,9 @@ autodocmods_convert_modname = {
     "sqlalchemy.engine.row": "sqlalchemy.engine",
     "sqlalchemy.engine.cursor": "sqlalchemy.engine",
     "sqlalchemy.engine.result": "sqlalchemy.engine",
+    "sqlalchemy.ext.asyncio.result": "sqlalchemy.ext.asyncio",
+    "sqlalchemy.ext.asyncio.engine": "sqlalchemy.ext.asyncio",
+    "sqlalchemy.ext.asyncio.session": "sqlalchemy.ext.asyncio",
     "sqlalchemy.util._collections": "sqlalchemy.util",
     "sqlalchemy.orm.relationships": "sqlalchemy.orm",
     "sqlalchemy.orm.interfaces": "sqlalchemy.orm",
@@ -128,6 +131,7 @@ zzzeeksphinx_module_prefixes = {
     "_row": "sqlalchemy.engine",
     "_schema": "sqlalchemy.schema",
     "_types": "sqlalchemy.types",
+    "_asyncio": "sqlalchemy.ext.asyncio",
     "_expression": "sqlalchemy.sql.expression",
     "_sql": "sqlalchemy.sql.expression",
     "_dml": "sqlalchemy.sql.expression",
index c6186cbaa36334f0494ad090cc10ad0c86b4b8dc..b9605bb498e461b19245642685008cdbd3f59141 100644 (file)
@@ -1225,18 +1225,9 @@ The above will respond to ``create_engine("mysql+foodialect://")`` and load the
 Connection / Engine API
 =======================
 
-.. autoclass:: BaseCursorResult
-    :members:
-
-.. autoclass:: ChunkedIteratorResult
-    :members:
-
 .. autoclass:: Connection
    :members:
 
-.. autoclass:: Connectable
-   :members:
-
 .. autoclass:: CreateEnginePlugin
    :members:
 
@@ -1246,6 +1237,25 @@ Connection / Engine API
 .. autoclass:: ExceptionContext
    :members:
 
+.. autoclass:: NestedTransaction
+    :members:
+
+.. autoclass:: Transaction
+    :members:
+
+.. autoclass:: TwoPhaseTransaction
+    :members:
+
+
+Result Set API
+=================
+
+.. autoclass:: BaseCursorResult
+    :members:
+
+.. autoclass:: ChunkedIteratorResult
+    :members:
+
 .. autoclass:: FrozenResult
     :members:
 
@@ -1258,9 +1268,6 @@ Connection / Engine API
 .. autoclass:: MergedResult
     :members:
 
-.. autoclass:: NestedTransaction
-    :members:
-
 .. autoclass:: Result
     :members:
     :inherited-members:
@@ -1291,9 +1298,3 @@ Connection / Engine API
 .. autoclass:: RowMapping
     :members:
 
-.. autoclass:: Transaction
-    :members:
-
-.. autoclass:: TwoPhaseTransaction
-    :members:
-
index 35ed285eb2fecc2684072a643cd00ec569449e42..6c36e581470dca3ff87d433cb1a87c7a508ea978 100644 (file)
@@ -196,6 +196,13 @@ pg8000
 
 .. automodule:: sqlalchemy.dialects.postgresql.pg8000
 
+.. _dialect-postgresql-asyncpg:
+
+asyncpg
+-------
+
+.. automodule:: sqlalchemy.dialects.postgresql.asyncpg
+
 psycopg2cffi
 ------------
 
index 6afef508336effbb838386226e768b144baf610c..bee062f89dc557e1ff340cc17f3061eb1342f2fc 100644 (file)
@@ -44,7 +44,8 @@ of Python objects, proceed first to the tutorial.
 
 * **ORM Usage:**
   :doc:`Session Usage and Guidelines <orm/session>` |
-  :doc:`Loading Objects <orm/loading_objects>`
+  :doc:`Loading Objects <orm/loading_objects>` |
+  :doc:`AsyncIO Support <orm/extensions/asyncio>`
 
 * **Extending the ORM:**
   :doc:`ORM Events and Internals <orm/extending>`
@@ -68,6 +69,7 @@ are documented here.  In contrast to the ORM's domain-centric mode of usage, the
 * **Engines, Connections, Pools:**
   :doc:`Engine Configuration <core/engines>` |
   :doc:`Connections, Transactions <core/connections>` |
+  :doc:`AsyncIO Support <orm/extensions/asyncio>` |
   :doc:`Connection Pooling <core/pooling>`
 
 * **Schema Definition:**
index 828ba31b318c77b429dc60139198005a100133d8..4b9376ab0ff132ef20ab66cb14370c2a866486d2 100644 (file)
@@ -146,7 +146,6 @@ mechanism::
    setuptools.
 
 
-
 Installing a Database API
 ----------------------------------
 
index 7a79104b9bedd50e1252d7b0d86e350ded6da2ab..10cafb2d2a209fca382675f38deb2f1510c65bae 100644 (file)
@@ -30,6 +30,13 @@ Associations
 
 .. automodule:: examples.association
 
+.. _examples_asyncio:
+
+Asyncio Integration
+-------------------
+
+.. automodule:: examples.asyncio
+
 Directed Graphs
 ---------------
 
diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst
new file mode 100644 (file)
index 0000000..388dee9
--- /dev/null
@@ -0,0 +1,292 @@
+.. _asyncio_toplevel:
+
+asyncio
+=======
+
+Support for Python asyncio.    Support for Core and ORM usage is
+included, using asyncio-compatible dialects.
+
+.. versionadded:: 1.4
+
+
+.. note:: The asyncio extension should be regarded as **alpha level** for the
+   1.4 release of SQLAlchemy.  API details are **subject to change** at
+   any time.
+
+
+.. seealso::
+
+    :ref:`change_3414` - initial feature announcement
+
+    :ref:`examples_asyncio` - example scripts illustrating working examples
+    of Core and ORM use within the asyncio extension.
+
+Synopsis - Core
+---------------
+
+For Core use, the :func:`_asyncio.create_async_engine` function creates an
+instance of :class:`_asyncio.AsyncEngine` which then offers an async version of
+the traditional :class:`_engine.Engine` API.   The
+:class:`_asyncio.AsyncEngine` delivers an :class:`_asyncio.AsyncConnection` via
+its :meth:`_asyncio.AsyncEngine.connect` and :meth:`_asyncio.AsyncEngine.begin`
+methods which both deliver asynchronous context managers.   The
+:class:`_asyncio.AsyncConnection` can then invoke statements using either the
+:meth:`_asyncio.AsyncConnection.execute` method to deliver a buffered
+:class:`_engine.Result`, or the :meth:`_asyncio.AsyncConnection.stream` method
+to deliver a streaming server-side :class:`_asyncio.AsyncResult`::
+
+    import asyncio
+
+    from sqlalchemy import Column
+    from sqlalchemy import Integer
+    from sqlalchemy import MetaData
+    from sqlalchemy import String
+    from sqlalchemy import Table
+    from sqlalchemy.ext.asyncio import create_async_engine
+    from sqlalchemy.future import select
+
+    meta = MetaData()
+
+    t1 = Table(
+        "t1", meta, Column("id", Integer, primary_key=True), Column("name", String)
+    )
+
+    async def async_main():
+        engine = create_async_engine(
+            "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
+        )
+
+        async with engine.begin() as conn:
+            await conn.run_sync(meta.drop_all)
+            await conn.run_sync(meta.create_all)
+
+            await conn.execute(
+                t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}]
+            )
+
+        async with engine.connect() as conn:
+
+            # select a Result, which will be delivered with buffered
+            # results
+            result = await conn.execute(select(t1).where(t1.c.name == "some name 1"))
+
+            print(result.fetchall())
+
+
+    asyncio.run(async_main())
+
+Above, the :meth:`_asyncio.AsyncConnection.run_sync` method may be used to
+invoke special DDL functions such as :meth:`_schema.MetaData.create_all` that
+don't include an awaitable hook.
+
+The :class:`_asyncio.AsyncConnection` also features a "streaming" API via
+the :meth:`_asyncio.AsyncConnection.stream` method that returns an
+:class:`_asyncio.AsyncResult` object.  This result object uses a server-side
+cursor and provides an async/await API, such as an async iterator::
+
+    async with engine.connect() as conn:
+        async_result = await conn.stream(select(t1))
+
+        async for row in async_result:
+            print("row: %s" % (row, ))
+
+
+Synopsis - ORM
+---------------
+
+Using :term:`2.0 style` querying, the :class:`_asyncio.AsyncSession` class
+provides full ORM functionality.   Within the default mode of use, special care
+must be taken to avoid :term:`lazy loading` of ORM relationships and column
+attributes, as below where the :func:`_orm.selectinload` eager loading strategy
+is used to ensure the ``A.bs`` collection on each ``A`` object is loaded::
+
+    import asyncio
+
+    from sqlalchemy.ext.asyncio import AsyncSession
+    from sqlalchemy.ext.asyncio import create_async_engine
+    from sqlalchemy.future import select
+    from sqlalchemy.orm import selectinload
+
+    # assume a declarative Base with mapped classes A and B, where A.bs
+    # is a relationship() to B, as in examples/asyncio/async_orm.py
+
+    async def async_main():
+        engine = create_async_engine(
+            "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
+        )
+        async with engine.begin() as conn:
+            await conn.run_sync(Base.metadata.drop_all)
+            await conn.run_sync(Base.metadata.create_all)
+
+        async with AsyncSession(engine) as session:
+            async with session.begin():
+                session.add_all(
+                    [
+                        A(bs=[B(), B()], data="a1"),
+                        A(bs=[B()], data="a2"),
+                        A(bs=[B(), B()], data="a3"),
+                    ]
+                )
+
+            stmt = select(A).options(selectinload(A.bs))
+
+            result = await session.execute(stmt)
+
+            for a1 in result.scalars():
+                print(a1)
+                for b1 in a1.bs:
+                    print(b1)
+
+            result = await session.execute(select(A).order_by(A.id))
+
+            a1 = result.scalars().first()
+
+            a1.data = "new data"
+
+            await session.commit()
+
+    asyncio.run(async_main())
+
+Above, the :func:`_orm.selectinload` eager loader is employed in order
+to eagerly load the ``A.bs`` collection within the scope of the
+``await session.execute()`` call.   If the default loader strategy of
+"lazyload" were left in place, the access of the ``A.bs`` attribute would
+raise an asyncio exception.  Using traditional asyncio, the application
+needs to avoid any points at which IO-on-attribute access may occur.
+This also means that methods such as :meth:`_orm.Session.expire` should be
+avoided in favor of :meth:`_asyncio.AsyncSession.refresh`, and that
+appropriate loader options should be employed for :func:`_orm.deferred`
+columns as well as for :func:`_orm.relationship` constructs.
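+
+For example, a brief sketch of both techniques, assuming a mapped class
+``A`` with a :func:`_orm.deferred` column ``A.data``::
+
+    from sqlalchemy.orm import undefer
+
+    # explicitly re-load attributes under an await, rather than expiring
+    # them and triggering implicit IO on attribute access
+    await session.refresh(a1)
+
+    # apply a loader option so the deferred column is loaded up front
+    result = await session.execute(select(A).options(undefer(A.data)))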
+
+Adapting ORM Lazy loads to asyncio
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. deepalchemy::  This approach essentially exposes publicly the
+   mechanism by which SQLAlchemy is able to provide the asyncio interface
+   in the first place.   While there is no technical issue with doing so, overall
+   the approach can probably be considered "controversial" as it works against
+   some of the central philosophies of the asyncio programming model, which
+   is essentially that any programming statement that can potentially result
+   in IO being invoked **must** have an ``await`` call, so that the program
+   makes it explicitly clear on every line where IO may occur.
+   a series of synchronous IO instructions to be exempted from this rule
+   within the scope of a function call, essentially bundled up into a single
+   awaitable.
+
+As an alternative means of integrating traditional SQLAlchemy "lazy loading"
+within an asyncio event loop, an **optional** method known as
+:meth:`_asyncio.AsyncSession.run_sync` is provided which will run any
+Python function inside of a greenlet, where traditional synchronous
+programming concepts will be translated to use ``await`` when they reach the
+database driver.   A hypothetical approach here is that an asyncio-oriented
+application can package up database-related methods into functions that are
+invoked using :meth:`_asyncio.AsyncSession.run_sync`.
+
+Altering the above example, if we didn't use :func:`_orm.selectinload`
+for the ``A.bs`` collection, we could accomplish our treatment of these
+attribute accesses within a separate function::
+
+    import asyncio
+
+    from sqlalchemy.ext.asyncio import AsyncSession
+    from sqlalchemy.ext.asyncio import create_async_engine
+    from sqlalchemy.future import select
+
+    # assume a declarative Base with mapped classes A and B, where A.bs
+    # is a relationship() to B, as in the previous example
+
+    def fetch_and_update_objects(session):
+        """run traditional sync-style ORM code in a function that will be
+        invoked within an awaitable.
+
+        """
+
+        # the session object here is a traditional ORM Session.
+        # all features are available here including legacy Query use.
+
+        stmt = select(A)
+
+        result = session.execute(stmt)
+        for a1 in result.scalars():
+            print(a1)
+
+            # lazy loads
+            for b1 in a1.bs:
+                print(b1)
+
+        # legacy Query use
+        a1 = session.query(A).order_by(A.id).first()
+
+        a1.data = "new data"
+
+
+    async def async_main():
+        engine = create_async_engine(
+            "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
+        )
+        async with engine.begin() as conn:
+            await conn.run_sync(Base.metadata.drop_all)
+            await conn.run_sync(Base.metadata.create_all)
+
+        async with AsyncSession(engine) as session:
+            async with session.begin():
+                session.add_all(
+                    [
+                        A(bs=[B(), B()], data="a1"),
+                        A(bs=[B()], data="a2"),
+                        A(bs=[B(), B()], data="a3"),
+                    ]
+                )
+
+            await session.run_sync(fetch_and_update_objects)
+
+            await session.commit()
+
+    asyncio.run(async_main())
+
+The above approach of running certain functions within a "sync" runner
+has some parallels to an application that runs a SQLAlchemy application
+on top of an event-based programming library such as ``gevent``.  The
+differences are as follows:
+
+1. Unlike when using ``gevent``, we can continue to use the standard Python
+   asyncio event loop, or any custom event loop, without the need to integrate
+   into the ``gevent`` event loop.
+
+2. There is no "monkeypatching" whatsoever.   The above example makes use of
+   a real asyncio driver and the underlying SQLAlchemy connection pool is also
+   using the Python built-in ``asyncio.Queue`` for pooling connections.
+
+3. The program can freely switch between async/await code and contained
+   functions that use sync code with virtually no performance penalty.  There
+   is no "thread executor" or any additional waiters or synchronization in use.
+
+4. The underlying network drivers are also using pure Python asyncio
+   concepts; no third party networking libraries such as ``gevent`` or
+   ``eventlet`` are in use.
+
+.. currentmodule:: sqlalchemy.ext.asyncio
+
+Engine API Documentation
+-------------------------
+
+.. autofunction:: create_async_engine
+
+.. autoclass:: AsyncEngine
+   :members:
+
+.. autoclass:: AsyncConnection
+   :members:
+
+.. autoclass:: AsyncTransaction
+   :members:
+
+Result Set API Documentation
+----------------------------------
+
+The :class:`_asyncio.AsyncResult` object is an async-adapted version of the
+:class:`_result.Result` object.  It is only returned when using the
+:meth:`_asyncio.AsyncConnection.stream` or :meth:`_asyncio.AsyncSession.stream`
+methods, which return a result object that is on top of an active database
+cursor.
+
+.. autoclass:: AsyncResult
+   :members:
+
+.. autoclass:: AsyncScalarResult
+   :members:
+
+.. autoclass:: AsyncMappingResult
+   :members:
+
+ORM Session API Documentation
+-----------------------------
+
+.. autoclass:: AsyncSession
+   :members:
+
+.. autoclass:: AsyncSessionTransaction
+   :members:
+
+
+
index e23fd55ee720c9ececa6187330bd11ce6432e520..ba040b9f65f84d608b03080ff9e7379477a83814 100644 (file)
@@ -15,6 +15,7 @@ behavior.   In particular the "Horizontal Sharding", "Hybrid Attributes", and
 .. toctree::
     :maxdepth: 1
 
+    asyncio
     associationproxy
     automap
     baked
diff --git a/examples/asyncio/__init__.py b/examples/asyncio/__init__.py
new file mode 100644 (file)
index 0000000..c53120f
--- /dev/null
@@ -0,0 +1,6 @@
+"""
+Examples illustrating the asyncio engine feature of SQLAlchemy.
+
+.. autosource::
+
+"""
diff --git a/examples/asyncio/async_orm.py b/examples/asyncio/async_orm.py
new file mode 100644 (file)
index 0000000..b1054a2
--- /dev/null
@@ -0,0 +1,89 @@
+"""Illustrates use of the sqlalchemy.ext.asyncio.AsyncSession object
+for asynchronous ORM use.
+
+"""
+
+import asyncio
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.future import select
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
+
+Base = declarative_base()
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+    bs = relationship("B")
+
+
+class B(Base):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    a_id = Column(ForeignKey("a.id"))
+    data = Column(String)
+
+
+async def async_main():
+    """Main program function."""
+
+    engine = create_async_engine(
+        "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
+    )
+
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.drop_all)
+        await conn.run_sync(Base.metadata.create_all)
+
+    async with AsyncSession(engine) as session:
+        async with session.begin():
+            session.add_all(
+                [
+                    A(bs=[B(), B()], data="a1"),
+                    A(bs=[B()], data="a2"),
+                    A(bs=[B(), B()], data="a3"),
+                ]
+            )
+
+        # for relationship loading, eager loading should be applied.
+        stmt = select(A).options(selectinload(A.bs))
+
+        # AsyncSession.execute() is used for 2.0 style ORM execution
+        # (same as the synchronous API).
+        result = await session.execute(stmt)
+
+        # result is a buffered Result object.
+        for a1 in result.scalars():
+            print(a1)
+            for b1 in a1.bs:
+                print(b1)
+
+        # for streaming ORM results, AsyncSession.stream() may be used.
+        result = await session.stream(stmt)
+
+        # result is a streaming AsyncResult object.
+        async for a1 in result.scalars():
+            print(a1)
+            for b1 in a1.bs:
+                print(b1)
+
+        result = await session.execute(select(A).order_by(A.id))
+
+        a1 = result.scalars().first()
+
+        a1.data = "new data"
+
+        await session.commit()
+
+
+asyncio.run(async_main())
diff --git a/examples/asyncio/basic.py b/examples/asyncio/basic.py
new file mode 100644 (file)
index 0000000..05cdd8a
--- /dev/null
@@ -0,0 +1,71 @@
+"""Illustrates the asyncio engine / connection interface.
+
+In this example, we have an async engine created by
+:func:`_asyncio.create_async_engine`.   We then use it within a
+coroutine using await.
+
+"""
+
+
+import asyncio
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.ext.asyncio import create_async_engine
+
+
+meta = MetaData()
+
+t1 = Table(
+    "t1", meta, Column("id", Integer, primary_key=True), Column("name", String)
+)
+
+
+async def async_main():
+    # engine is an instance of AsyncEngine
+    engine = create_async_engine(
+        "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
+    )
+
+    # conn is an instance of AsyncConnection
+    async with engine.begin() as conn:
+
+        # to support SQLAlchemy DDL methods as well as legacy functions, the
+        # AsyncConnection.run_sync() awaitable method will pass a "sync"
+        # version of the AsyncConnection object to any synchronous method,
+        # where synchronous IO calls will be transparently translated to
+        # use await.
+        await conn.run_sync(meta.drop_all)
+        await conn.run_sync(meta.create_all)
+
+        # for normal statement execution, a traditional "await execute()"
+        # pattern is used.
+        await conn.execute(
+            t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}]
+        )
+
+    async with engine.connect() as conn:
+
+        # the default result object is the
+        # sqlalchemy.engine.Result object
+        result = await conn.execute(t1.select())
+
+        # the results are buffered so no await call is necessary
+        # for this case.
+        print(result.fetchall())
+
+        # for a streaming result that buffers only segments of the
+        # result at a time, the AsyncConnection.stream() method is used.
+        # this returns a sqlalchemy.ext.asyncio.AsyncResult object.
+        async_result = await conn.stream(t1.select())
+
+        # this object supports async iteration and awaitable
+        # versions of methods like .all(), fetchmany(), etc.
+        async for row in async_result:
+            print(row)
+
+
+asyncio.run(async_main())
diff --git a/examples/asyncio/greenlet_orm.py b/examples/asyncio/greenlet_orm.py
new file mode 100644 (file)
index 0000000..e0b568c
--- /dev/null
@@ -0,0 +1,92 @@
+"""Illustrates use of the sqlalchemy.ext.asyncio.AsyncSession object
+for asynchronous ORM use, including the optional run_sync() method.
+
+
+"""
+
+import asyncio
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.future import select
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+    bs = relationship("B")
+
+
+class B(Base):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    a_id = Column(ForeignKey("a.id"))
+    data = Column(String)
+
+
+def run_queries(session):
+    """A function written in "synchronous" style that will be invoked
+    within the asyncio event loop.
+
+    The session object passed is a traditional orm.Session object with
+    synchronous interface.
+
+    """
+
+    stmt = select(A)
+
+    result = session.execute(stmt)
+
+    for a1 in result.scalars():
+        print(a1)
+        # lazy loads
+        for b1 in a1.bs:
+            print(b1)
+
+    result = session.execute(select(A).order_by(A.id))
+
+    a1 = result.scalars().first()
+
+    a1.data = "new data"
+
+
+async def async_main():
+    """Main program function."""
+
+    engine = create_async_engine(
+        "postgresql+asyncpg://scott:tiger@localhost/test", echo=True,
+    )
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.drop_all)
+        await conn.run_sync(Base.metadata.create_all)
+
+    async with AsyncSession(engine) as session:
+        async with session.begin():
+            session.add_all(
+                [
+                    A(bs=[B(), B()], data="a1"),
+                    A(bs=[B()], data="a2"),
+                    A(bs=[B(), B()], data="a3"),
+                ]
+            )
+
+        # we have the option to run a function written in sync style
+        # within the AsyncSession.run_sync() method.  The function will
+        # be passed a synchronous-style Session object and the function
+        # can use traditional ORM patterns.
+        await session.run_sync(run_queries)
+
+        await session.commit()
+
+
+asyncio.run(async_main())
index 06d22872a980838b554fc1420f213f390f236d7b..2762a9971b60d8bf94ddb038553e81aa8e633a9f 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
-
 from . import base
 from . import pg8000  # noqa
 from . import psycopg2  # noqa
@@ -58,7 +57,10 @@ from .ranges import INT8RANGE
 from .ranges import NUMRANGE
 from .ranges import TSRANGE
 from .ranges import TSTZRANGE
+from ...util import compat
 
+if compat.py3k:
+    from . import asyncpg  # noqa
 
 base.dialect = dialect = psycopg2.dialect
 
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
new file mode 100644 (file)
index 0000000..515ef6e
--- /dev/null
@@ -0,0 +1,786 @@
+# postgresql/asyncpg.py
+# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+asyncpg
+    :name: asyncpg
+    :dbapi: asyncpg
+    :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
+    :url: https://magicstack.github.io/asyncpg/
+
+The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.
+
+Using a special asyncio mediation layer, the asyncpg dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")
+
+The dialect can also be run as a "synchronous" dialect within the
+:func:`_sa.create_engine` function, which will pass "await" calls into
+an ad-hoc event loop.  This mode of operation is of **limited use**
+and is for special testing scenarios only.  The mode can be enabled by
+adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
+in conjunction with :func:`_sa.create_engine`::
+
+    # for testing purposes only; do not use in production!
+    engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")
+
+
+.. versionadded:: 1.4
+
+"""  # noqa
+
+import collections
+import decimal
+import itertools
+import re
+
+from . import json
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import ENUM
+from .base import INTERVAL
+from .base import OID
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import PGIdentifierPreparer
+from .base import REGCLASS
+from .base import UUID
+from ... import exc
+from ... import pool
+from ... import processors
+from ... import util
+from ...sql import sqltypes
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+try:
+    from uuid import UUID as _python_UUID  # noqa
+except ImportError:
+    _python_UUID = None
+
+
+class AsyncpgTime(sqltypes.Time):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.TIME
+
+
+class AsyncpgDate(sqltypes.Date):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.DATE
+
+
+class AsyncpgDateTime(sqltypes.DateTime):
+    def get_dbapi_type(self, dbapi):
+        if self.timezone:
+            return dbapi.TIMESTAMP_W_TZ
+        else:
+            return dbapi.TIMESTAMP
+
+
+class AsyncpgBoolean(sqltypes.Boolean):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.BOOLEAN
+
+
+class AsyncPgInterval(INTERVAL):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTERVAL
+
+    @classmethod
+    def adapt_emulated_to_native(cls, interval, **kw):
+
+        return AsyncPgInterval(precision=interval.second_precision)
+
+
+class AsyncPgEnum(ENUM):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.ENUM
+
+
+class AsyncpgInteger(sqltypes.Integer):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTEGER
+
+
+class AsyncpgBigInteger(sqltypes.BigInteger):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.BIGINTEGER
+
+
+class AsyncpgJSON(json.JSON):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.JSON
+
+
+class AsyncpgJSONB(json.JSONB):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.JSONB
+
+
+class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
+    def get_dbapi_type(self, dbapi):
+        raise NotImplementedError("should not be here")
+
+
+class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTEGER
+
+
+class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.STRING
+
+
+class AsyncpgJSONPathType(json.JSONPathType):
+    def bind_processor(self, dialect):
+        def process(value):
+            assert isinstance(value, util.collections_abc.Sequence)
+            tokens = [util.text_type(elem) for elem in value]
+            return tokens
+
+        return process
+
+
+class AsyncpgUUID(UUID):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.UUID
+
+    def bind_processor(self, dialect):
+        if not self.as_uuid and dialect.use_native_uuid:
+
+            def process(value):
+                if value is not None:
+                    value = _python_UUID(value)
+                return value
+
+            return process
+
+    def result_processor(self, dialect, coltype):
+        if not self.as_uuid and dialect.use_native_uuid:
+
+            def process(value):
+                if value is not None:
+                    value = str(value)
+                return value
+
+            return process
+
+
+class AsyncpgNumeric(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect, coltype):
+        if self.asdecimal:
+            if coltype in _FLOAT_TYPES:
+                return processors.to_decimal_processor_factory(
+                    decimal.Decimal, self._effective_decimal_return_scale
+                )
+            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+                # asyncpg returns Decimal natively for 1700 (numeric)
+                return None
+            else:
+                raise exc.InvalidRequestError(
+                    "Unknown PG numeric type: %d" % coltype
+                )
+        else:
+            if coltype in _FLOAT_TYPES:
+                # asyncpg returns float natively for 701 (float8)
+                return None
+            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+                return processors.to_float
+            else:
+                raise exc.InvalidRequestError(
+                    "Unknown PG numeric type: %d" % coltype
+                )
+
+
+class AsyncpgREGCLASS(REGCLASS):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.STRING
+
+
+class AsyncpgOID(OID):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTEGER
+
+
+class PGExecutionContext_asyncpg(PGExecutionContext):
+    def pre_exec(self):
+        if self.isddl:
+            self._dbapi_connection.reset_schema_state()
+
+        if not self.compiled:
+            return
+
+        # we have to exclude ENUM because "enum" is not really a "type"
+        # we can cast to; it has to be the name of the type itself.
+        # for now we just omit it from casting
+        self.set_input_sizes(exclude_types={AsyncAdapt_asyncpg_dbapi.ENUM})
+
+    def create_server_side_cursor(self):
+        return self._dbapi_connection.cursor(server_side=True)
+
+
+class PGCompiler_asyncpg(PGCompiler):
+    pass
+
+
+class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
+    pass
+
+
+class AsyncAdapt_asyncpg_cursor:
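+    # a synchronous-facing DBAPI cursor adapter; in the default
+    # (non server side) case, rows are fully pre-fetched by
+    # _prepare_and_execute() and served from the self._rows buffer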
+    __slots__ = (
+        "_adapt_connection",
+        "_connection",
+        "_rows",
+        "description",
+        "arraysize",
+        "rowcount",
+        "_inputsizes",
+        "_cursor",
+    )
+
+    server_side = False
+
+    def __init__(self, adapt_connection):
+        self._adapt_connection = adapt_connection
+        self._connection = adapt_connection._connection
+        self._rows = []
+        self._cursor = None
+        self.description = None
+        self.arraysize = 1
+        self.rowcount = -1
+        self._inputsizes = None
+
+    def close(self):
+        self._rows[:] = []
+
+    def _handle_exception(self, error):
+        self._adapt_connection._handle_exception(error)
+
+    def _parameters(self):
+        if not self._inputsizes:
+            return ("$%d" % idx for idx in itertools.count(1))
+        else:
+
+            return (
+                "$%d::%s" % (idx, typ) if typ else "$%d" % idx
+                for idx, typ in enumerate(
+                    (_pg_types.get(typ) for typ in self._inputsizes), 1
+                )
+            )
+
+    async def _prepare_and_execute(self, operation, parameters):
+        # TODO: I guess cache these in an LRU cache, or see if we can
+        # use some asyncpg concept
+
+        # TODO: would be nice to support the dollar numeric thing
+        # directly, this is much easier for now
+
+        if not self._adapt_connection._started:
+            await self._adapt_connection._start_transaction()
+
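+        # replace the compiler's qmark placeholders with asyncpg's
+        # positional $n parameters, adding a ::type cast where
+        # setinputsizes() provided a known type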
+        params = self._parameters()
+        operation = re.sub(r"\?", lambda m: next(params), operation)
+
+        try:
+            prepared_stmt = await self._connection.prepare(operation)
+
+            attributes = prepared_stmt.get_attributes()
+            if attributes:
+                self.description = [
+                    (attr.name, attr.type.oid, None, None, None, None, None)
+                    for attr in prepared_stmt.get_attributes()
+                ]
+            else:
+                self.description = None
+
+            if self.server_side:
+                self._cursor = await prepared_stmt.cursor(*parameters)
+                self.rowcount = -1
+            else:
+                self._rows = await prepared_stmt.fetch(*parameters)
+                status = prepared_stmt.get_statusmsg()
+
+                reg = re.match(r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status)
+                if reg:
+                    self.rowcount = int(reg.group(1))
+                else:
+                    self.rowcount = -1
+
+        except Exception as error:
+            self._handle_exception(error)
+
+    def execute(self, operation, parameters=()):
+        try:
+            self._adapt_connection.await_(
+                self._prepare_and_execute(operation, parameters)
+            )
+        except Exception as error:
+            self._handle_exception(error)
+
+    def executemany(self, operation, seq_of_parameters):
+        if not self._adapt_connection._started:
+            self._adapt_connection.await_(
+                self._adapt_connection._start_transaction()
+            )
+
+        params = self._parameters()
+        operation = re.sub(r"\?", lambda m: next(params), operation)
+        try:
+            return self._adapt_connection.await_(
+                self._connection.executemany(operation, seq_of_parameters)
+            )
+        except Exception as error:
+            self._handle_exception(error)
+
+    def setinputsizes(self, *inputsizes):
+        self._inputsizes = inputsizes
+
+    def __iter__(self):
+        while self._rows:
+            yield self._rows.pop(0)
+
+    def fetchone(self):
+        if self._rows:
+            return self._rows.pop(0)
+        else:
+            return None
+
+    def fetchmany(self, size=None):
+        if size is None:
+            size = self.arraysize
+
+        retval = self._rows[0:size]
+        self._rows[:] = self._rows[size:]
+        return retval
+
+    def fetchall(self):
+        retval = self._rows[:]
+        self._rows[:] = []
+        return retval
+
+
+class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
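+    # the server side version maintains a live asyncpg cursor, pulling
+    # rows from it in batches into a deque-based row buffer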
+
+    server_side = True
+    __slots__ = ("_rowbuffer",)
+
+    def __init__(self, adapt_connection):
+        super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection)
+        self._rowbuffer = None
+
+    def close(self):
+        self._cursor = None
+        self._rowbuffer = None
+
+    def _buffer_rows(self):
+        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
+        self._rowbuffer = collections.deque(new_rows)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        # note: a ``yield`` inside ``__anext__`` would make it an async
+        # generator function rather than a coroutine, which breaks
+        # ``async for``; return one row at a time instead, refilling
+        # the buffer from the server side cursor as needed
+        if not self._rowbuffer:
+            self._buffer_rows()
+            if not self._rowbuffer:
+                raise StopAsyncIteration()
+        return self._rowbuffer.popleft()
+
+    def fetchone(self):
+        if not self._rowbuffer:
+            self._buffer_rows()
+            if not self._rowbuffer:
+                return None
+        return self._rowbuffer.popleft()
+
+    def fetchmany(self, size=None):
+        if size is None:
+            return self.fetchall()
+
+        if not self._rowbuffer:
+            self._buffer_rows()
+
+        buf = list(self._rowbuffer)
+        lb = len(buf)
+        if size > lb:
+            buf.extend(
+                self._adapt_connection.await_(self._cursor.fetch(size - lb))
+            )
+
+        result = buf[0:size]
+        self._rowbuffer = collections.deque(buf[size:])
+        return result
+
+    def fetchall(self):
+        ret = list(self._rowbuffer) + list(
+            self._adapt_connection.await_(self._all())
+        )
+        self._rowbuffer.clear()
+        return ret
+
+    async def _all(self):
+        rows = []
+
+        # TODO: looks like we have to hand-roll some kind of batching here.
+        # hardcoding for the moment but this should be improved.
+        while True:
+            batch = await self._cursor.fetch(1000)
+            if batch:
+                rows.extend(batch)
+                continue
+            else:
+                break
+        return rows
+
+    def executemany(self, operation, seq_of_parameters):
+        raise NotImplementedError(
+            "server side cursor doesn't support executemany yet"
+        )
+
+
+class AsyncAdapt_asyncpg_connection:
+    __slots__ = (
+        "dbapi",
+        "_connection",
+        "isolation_level",
+        "_transaction",
+        "_started",
+    )
+
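+    # resolve awaitables via the greenlet mediation layer; the
+    # AsyncAdaptFallback variant below substitutes await_fallback(),
+    # which runs an ad-hoc event loop for non-async callers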
+    await_ = staticmethod(await_only)
+
+    def __init__(self, dbapi, connection):
+        self.dbapi = dbapi
+        self._connection = connection
+        self.isolation_level = "read_committed"
+        self._transaction = None
+        self._started = False
+        self.await_(self._setup_type_codecs())
+
+    async def _setup_type_codecs(self):
+        """set up type decoders at the asyncpg level.
+
+        this is done first to accommodate the "char" value of
+        pg_catalog.pg_attribute.attgenerated being returned as bytes.
+        Even though the doc at
+        https://magicstack.github.io/asyncpg/current/usage.html#type-conversion
+        claims "char" is returned as "str", it looks like this is actually
+        the 'bpchar' datatype, blank padded.  'char' seems to be some
+        more obscure type (oid 18) and asyncpg codes this to bytea:
+        https://github.com/MagicStack/asyncpg/blob/master/asyncpg/protocol/
+        codecs/pgproto.pyx#L26
+
+        all the other drivers treat this as a string.
+
+        """
+
+        await self._connection.set_type_codec(
+            "char",
+            schema="pg_catalog",
+            encoder=lambda value: value,
+            decoder=lambda value: value,
+            format="text",
+        )
+
+    def _handle_exception(self, error):
+        if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
+            exception_mapping = self.dbapi._asyncpg_error_translate
+
+            for super_ in type(error).__mro__:
+                if super_ in exception_mapping:
+                    translated_error = exception_mapping[super_](
+                        "%s: %s" % (type(error), error)
+                    )
+                    raise translated_error from error
+            else:
+                raise error
+        else:
+            raise error
+
+    def set_isolation_level(self, level):
+        if self._started:
+            self.rollback()
+        self.isolation_level = level
+
+    async def _start_transaction(self):
+        if self.isolation_level == "autocommit":
+            return
+
+        try:
+            self._transaction = self._connection.transaction(
+                isolation=self.isolation_level
+            )
+            await self._transaction.start()
+        except Exception as error:
+            self._handle_exception(error)
+        else:
+            self._started = True
+
+    def cursor(self, server_side=False):
+        if server_side:
+            return AsyncAdapt_asyncpg_ss_cursor(self)
+        else:
+            return AsyncAdapt_asyncpg_cursor(self)
+
+    def reset_schema_state(self):
+        self.await_(self._connection.reload_schema_state())
+
+    def rollback(self):
+        if self._started:
+            self.await_(self._transaction.rollback())
+
+            self._transaction = None
+            self._started = False
+
+    def commit(self):
+        if self._started:
+            self.await_(self._transaction.commit())
+            self._transaction = None
+            self._started = False
+
+    def close(self):
+        self.rollback()
+
+        self.await_(self._connection.close())
+
+
+class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
+    await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_asyncpg_dbapi:
+    def __init__(self, asyncpg):
+        self.asyncpg = asyncpg
+        self.paramstyle = "qmark"
+
+    def connect(self, *arg, **kw):
+        async_fallback = kw.pop("async_fallback", False)
+
+        if async_fallback:
+            return AsyncAdaptFallback_asyncpg_connection(
+                self, await_fallback(self.asyncpg.connect(*arg, **kw)),
+            )
+        else:
+            return AsyncAdapt_asyncpg_connection(
+                self, await_only(self.asyncpg.connect(*arg, **kw)),
+            )
+
+    class Error(Exception):
+        pass
+
+    class Warning(Exception):  # noqa
+        pass
+
+    class InterfaceError(Error):
+        pass
+
+    class DatabaseError(Error):
+        pass
+
+    class InternalError(DatabaseError):
+        pass
+
+    class OperationalError(DatabaseError):
+        pass
+
+    class ProgrammingError(DatabaseError):
+        pass
+
+    class IntegrityError(DatabaseError):
+        pass
+
+    class DataError(DatabaseError):
+        pass
+
+    class NotSupportedError(DatabaseError):
+        pass
+
+    @util.memoized_property
+    def _asyncpg_error_translate(self):
+        import asyncpg
+
+        return {
+            asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError,  # noqa
+            asyncpg.exceptions.PostgresError: self.Error,
+            asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError,
+            asyncpg.exceptions.InterfaceError: self.InterfaceError,
+        }
+
+    def Binary(self, value):
+        return value
+
+    STRING = util.symbol("STRING")
+    TIMESTAMP = util.symbol("TIMESTAMP")
+    TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
+    TIME = util.symbol("TIME")
+    DATE = util.symbol("DATE")
+    INTERVAL = util.symbol("INTERVAL")
+    NUMBER = util.symbol("NUMBER")
+    FLOAT = util.symbol("FLOAT")
+    BOOLEAN = util.symbol("BOOLEAN")
+    INTEGER = util.symbol("INTEGER")
+    BIGINTEGER = util.symbol("BIGINTEGER")
+    BYTES = util.symbol("BYTES")
+    DECIMAL = util.symbol("DECIMAL")
+    JSON = util.symbol("JSON")
+    JSONB = util.symbol("JSONB")
+    ENUM = util.symbol("ENUM")
+    UUID = util.symbol("UUID")
+    BYTEA = util.symbol("BYTEA")
+
+    DATETIME = TIMESTAMP
+    BINARY = BYTEA
+
+
+_pg_types = {
+    AsyncAdapt_asyncpg_dbapi.STRING: "varchar",
+    AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp",
+    AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
+    AsyncAdapt_asyncpg_dbapi.DATE: "date",
+    AsyncAdapt_asyncpg_dbapi.TIME: "time",
+    AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
+    AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
+    AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
+    AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool",
+    AsyncAdapt_asyncpg_dbapi.INTEGER: "integer",
+    AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint",
+    AsyncAdapt_asyncpg_dbapi.BYTES: "bytes",
+    AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal",
+    AsyncAdapt_asyncpg_dbapi.JSON: "json",
+    AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb",
+    AsyncAdapt_asyncpg_dbapi.ENUM: "enum",
+    AsyncAdapt_asyncpg_dbapi.UUID: "uuid",
+    AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea",
+}
+
+
+class PGDialect_asyncpg(PGDialect):
+    driver = "asyncpg"
+
+    supports_unicode_statements = True
+    supports_server_side_cursors = True
+
+    supports_unicode_binds = True
+
+    default_paramstyle = "qmark"
+    supports_sane_multi_rowcount = False
+    execution_ctx_cls = PGExecutionContext_asyncpg
+    statement_compiler = PGCompiler_asyncpg
+    preparer = PGIdentifierPreparer_asyncpg
+
+    use_native_uuid = True
+
+    colspecs = util.update_copy(
+        PGDialect.colspecs,
+        {
+            sqltypes.Time: AsyncpgTime,
+            sqltypes.Date: AsyncpgDate,
+            sqltypes.DateTime: AsyncpgDateTime,
+            sqltypes.Interval: AsyncPgInterval,
+            INTERVAL: AsyncPgInterval,
+            UUID: AsyncpgUUID,
+            sqltypes.Boolean: AsyncpgBoolean,
+            sqltypes.Integer: AsyncpgInteger,
+            sqltypes.BigInteger: AsyncpgBigInteger,
+            sqltypes.Numeric: AsyncpgNumeric,
+            sqltypes.JSON: AsyncpgJSON,
+            json.JSONB: AsyncpgJSONB,
+            sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
+            sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
+            sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
+            sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
+            sqltypes.Enum: AsyncPgEnum,
+            OID: AsyncpgOID,
+            REGCLASS: AsyncpgREGCLASS,
+        },
+    )
+
+    def __init__(self, server_side_cursors=False, **kwargs):
+        PGDialect.__init__(self, **kwargs)
+        self.server_side_cursors = server_side_cursors
+
+    @util.memoized_property
+    def _dbapi_version(self):
+        if self.dbapi and hasattr(self.dbapi, "__version__"):
+            return tuple(
+                [
+                    int(x)
+                    for x in re.findall(
+                        r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+                    )
+                ]
+            )
+        else:
+            return (99, 99, 99)
+
+    @classmethod
+    def dbapi(cls):
+        return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))
+
+    @util.memoized_property
+    def _isolation_lookup(self):
+        return {
+            "AUTOCOMMIT": "autocommit",
+            "READ COMMITTED": "read_committed",
+            "REPEATABLE READ": "repeatable_read",
+            "SERIALIZABLE": "serializable",
+        }
+
+    def set_isolation_level(self, connection, level):
+        try:
+            level = self._isolation_lookup[level.replace("_", " ")]
+        except KeyError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "Invalid value '%s' for isolation_level. "
+                    "Valid isolation levels for %s are %s"
+                    % (level, self.name, ", ".join(self._isolation_lookup))
+                ),
+                replace_context=err,
+            )
+
+        connection.set_isolation_level(level)
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username="user")
+        if "port" in opts:
+            opts["port"] = int(opts["port"])
+        opts.update(url.query)
+        return ([], opts)
+
+    @classmethod
+    def get_pool_class(cls, url):
+        return pool.AsyncAdaptedQueuePool
+
+    def is_disconnect(self, e, connection, cursor):
+        if connection:
+            return connection._connection.is_closed()
+        else:
+            return isinstance(
+                e, self.dbapi.InterfaceError
+            ) and "connection is closed" in str(e)
+
+
+dialect = PGDialect_asyncpg
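
A rough usage sketch of the dialect assembled above; the DSN, credentials
and query are illustrative placeholders, not part of this patch::

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        # a "postgresql+asyncpg" URL selects PGDialect_asyncpg;
        # create_async_engine() wraps the sync Engine it produces
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )
        async with engine.connect() as conn:
            result = await conn.execute(text("select 1"))
            print(result.scalar())

    asyncio.run(main())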
index 3bd7e62d53eebcbd6dbb6d3bea1ebdacba11d434..7717a2526bf0b428d383c3b97f01b0b027731966 100644 (file)
@@ -1299,6 +1299,14 @@ class UUID(sqltypes.TypeEngine):
          """
         self.as_uuid = as_uuid
 
+    def coerce_compared_value(self, op, value):
+        """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+
+        if isinstance(value, util.string_types):
+            return self
+        else:
+            return super(UUID, self).coerce_compared_value(op, value)
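
The hook added above has the effect sketched below (the table and value
are illustrative): a string compared against a UUID column is bound using
the UUID type rather than VARCHAR, so a driver with a native uuid codec
such as asyncpg receives a correctly typed parameter::

    import uuid

    from sqlalchemy import Column, MetaData, Table, select
    from sqlalchemy.dialects.postgresql import UUID

    t = Table("t", MetaData(), Column("id", UUID(as_uuid=True)))

    # the right-hand string is coerced to the column's UUID type
    stmt = select(t).where(t.c.id == str(uuid.uuid4()))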
+
     def bind_processor(self, dialect):
         if self.as_uuid:
 
index d60f14f315507a216c6db335325f5dcd32e9abe3..34bf720b7842d156627cc39a93941897c45316cc 100644 (file)
@@ -89,6 +89,7 @@ class Connection(Connectable):
                 if connection is not None
                 else engine.raw_connection()
             )
+
             self._transaction = self._nested_transaction = None
             self.__savepoint_seq = 0
             self.__in_begin = False
@@ -623,6 +624,9 @@ class Connection(Connectable):
 
         self._dbapi_connection.detach()
 
+    def _autobegin(self):
+        self.begin()
+
     def begin(self):
         """Begin a transaction and return a transaction handle.
 
@@ -1433,7 +1437,7 @@ class Connection(Connectable):
             self._invalid_transaction()
 
         if self._is_future and self._transaction is None:
-            self.begin()
+            self._autobegin()
 
         context.pre_exec()
 
@@ -2592,6 +2596,7 @@ class Engine(Connectable, log.Identified):
             return self.conn
 
         def __exit__(self, type_, value, traceback):
+
             if type_ is not None:
                 self.transaction.rollback()
             else:
index dc895ee15d93c64bbf3f5a26b820a0ef7a04ba3c..66173d9b038b1b1635cb7dfea7fa7faa997cb427 100644 (file)
@@ -553,7 +553,7 @@ def create_engine(url, **kwargs):
 
         poolclass = pop_kwarg("poolclass", None)
         if poolclass is None:
-            poolclass = dialect_cls.get_pool_class(u)
+            poolclass = dialect.get_dialect_pool_class(u)
         pool_args = {"dialect": dialect}
 
         # consume pool arguments from kwargs, translating a few of
index c76f820f9b496696b126b9cb3edaebb6cd944a18..4fb20a3d509317d9c780a4e287b37a637ca7cad5 100644 (file)
@@ -317,6 +317,9 @@ class DefaultDialect(interfaces.Dialect):
     def get_pool_class(cls, url):
         return getattr(cls, "poolclass", pool.QueuePool)
 
+    def get_dialect_pool_class(self, url):
+        return self.get_pool_class(url)
+
     @classmethod
     def load_provisioning(cls):
         package = ".".join(cls.__module__.split(".")[0:-1])
index 9badbffc3ca71059a9bd85e123895bdb03c2513a..10a88c7d880e85563471190babbc336120e4fe99 100644 (file)
@@ -20,6 +20,7 @@ from ..sql.base import _generative
 from ..sql.base import HasMemoized
 from ..sql.base import InPlaceGenerative
 from ..util import collections_abc
+from ..util import py2k
 
 if util.TYPE_CHECKING:
     from typing import Any
@@ -616,6 +617,16 @@ class ResultInternal(InPlaceGenerative):
         else:
             return row
 
+    def _iter_impl(self):
+        return self._iterator_getter(self)
+
+    def _next_impl(self):
+        row = self._onerow_getter(self)
+        if row is _NO_ROW:
+            raise StopIteration()
+        else:
+            return row
+
     @_generative
     def _column_slices(self, indexes):
         real_result = self._real_result if self._real_result else self
@@ -892,16 +903,15 @@ class Result(ResultInternal):
         raise NotImplementedError()
 
     def __iter__(self):
-        return self._iterator_getter(self)
+        return self._iter_impl()
 
     def __next__(self):
-        row = self._onerow_getter(self)
-        if row is _NO_ROW:
-            raise StopIteration()
-        else:
-            return row
+        return self._next_impl()
+
+    if py2k:
 
-    next = __next__
+        def next(self):  # noqa
+            return self._next_impl()
 
     def partitions(self, size=None):
         # type: (Optional[Int]) -> Iterator[List[Row]]
@@ -1015,12 +1025,10 @@ class Result(ResultInternal):
            column of the first row, use the :meth:`.Result.scalar` method,
            or combine :meth:`.Result.scalars` and :meth:`.Result.first`.
 
-        .. comment: A warning is emitted if additional rows remain.
-
         :return: a :class:`.Row` object, or None
          if no rows remain.
 
-         .. seealso::
+        .. seealso::
 
             :meth:`_result.Result.scalar`
 
@@ -1186,18 +1194,6 @@ class FilterResult(ResultInternal):
     def _attributes(self):
         return self._real_result._attributes
 
-    def __iter__(self):
-        return self._iterator_getter(self)
-
-    def __next__(self):
-        row = self._onerow_getter(self)
-        if row is _NO_ROW:
-            raise StopIteration()
-        else:
-            return row
-
-    next = __next__
-
     def _fetchiter_impl(self):
         return self._real_result._fetchiter_impl()
 
@@ -1299,6 +1295,17 @@ class ScalarResult(FilterResult):
         """
         return self._allrows()
 
+    def __iter__(self):
+        return self._iter_impl()
+
+    def __next__(self):
+        return self._next_impl()
+
+    if py2k:
+
+        def next(self):  # noqa
+            return self._next_impl()
+
     def first(self):
         # type: () -> Optional[Any]
         """Fetch the first object or None if no object is present.
@@ -1409,7 +1416,7 @@ class MappingResult(FilterResult):
 
     def fetchall(self):
         # type: () -> List[Mapping]
-        """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+        """A synonym for the :meth:`_engine.MappingResult.all` method."""
 
         return self._allrows()
 
@@ -1453,6 +1460,17 @@ class MappingResult(FilterResult):
 
         return self._allrows()
 
+    def __iter__(self):
+        return self._iter_impl()
+
+    def __next__(self):
+        return self._next_impl()
+
+    if py2k:
+
+        def next(self):  # noqa
+            return self._next_impl()
+
     def first(self):
         # type: () -> Optional[Mapping]
         """Fetch the first object or None if no object is present.
@@ -1519,13 +1537,11 @@ class FrozenResult(object):
 
     .. seealso::
 
-        .. seealso::
-
-            :ref:`do_orm_execute_re_executing` - example usage within the
-            ORM to implement a result-set cache.
+        :ref:`do_orm_execute_re_executing` - example usage within the
+        ORM to implement a result-set cache.
 
-            :func:`_orm.loading.merge_frozen_result` - ORM function to merge
-            a frozen result back into a :class:`_orm.Session`.
+        :func:`_orm.loading.merge_frozen_result` - ORM function to merge
+        a frozen result back into a :class:`_orm.Session`.
 
     """
 
@@ -1624,21 +1640,36 @@ class ChunkedIteratorResult(IteratorResult):
     """
 
     def __init__(
-        self, cursor_metadata, chunks, source_supports_scalars=False, raw=None
+        self,
+        cursor_metadata,
+        chunks,
+        source_supports_scalars=False,
+        raw=None,
+        dynamic_yield_per=False,
     ):
         self._metadata = cursor_metadata
         self.chunks = chunks
         self._source_supports_scalars = source_supports_scalars
         self.raw = raw
         self.iterator = itertools.chain.from_iterable(self.chunks(None))
+        self.dynamic_yield_per = dynamic_yield_per
 
     @_generative
     def yield_per(self, num):
+        # TODO: this throws away the iterator which may be holding
+        # onto a chunk.   the yield_per cannot be changed once any
+        # rows have been fetched.   either find a way to enforce this,
+        # or we can't use itertools.chain and will instead have to
+        # keep track.
+
         self._yield_per = num
-        # TODO: this should raise if the iterator has already been started.
-        # we can't change the yield mid-stream like this
         self.iterator = itertools.chain.from_iterable(self.chunks(num))
 
+    def _fetchmany_impl(self, size=None):
+        if self.dynamic_yield_per:
+            self.iterator = itertools.chain.from_iterable(self.chunks(size))
+        return super(ChunkedIteratorResult, self)._fetchmany_impl(size=size)
+
 
 class MergedResult(IteratorResult):
     """A :class:`_engine.Result` that is merged from any number of
@@ -1677,6 +1708,5 @@ class MergedResult(IteratorResult):
     def _soft_close(self, hard=False):
         for r in self._results:
             r._soft_close(hard=hard)
-
         if hard:
             self.closed = True
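
The ``chunks`` callable consumed by ``ChunkedIteratorResult`` above returns
an iterable of row lists; the sketch below (standalone Python, not
SQLAlchemy API) illustrates the flattening, and what the new
``dynamic_yield_per`` re-chunking in ``_fetchmany_impl()`` relies upon::

    import itertools

    def chunks(size):
        # stand-in for a server-side cursor: yield lists of rows
        rows = list(range(10))
        step = size or len(rows)
        for i in range(0, len(rows), step):
            yield rows[i : i + step]

    # ChunkedIteratorResult flattens the chunks into one row stream
    iterator = itertools.chain.from_iterable(chunks(3))
    print(list(iterator))  # 0 through 9, in order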
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
new file mode 100644 (file)
index 0000000..fbbc958
--- /dev/null
@@ -0,0 +1,9 @@
+from .engine import AsyncConnection  # noqa
+from .engine import AsyncEngine  # noqa
+from .engine import AsyncTransaction  # noqa
+from .engine import create_async_engine  # noqa
+from .result import AsyncMappingResult  # noqa
+from .result import AsyncResult  # noqa
+from .result import AsyncScalarResult  # noqa
+from .session import AsyncSession  # noqa
+from .session import AsyncSessionTransaction  # noqa
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
new file mode 100644 (file)
index 0000000..051f9e2
--- /dev/null
@@ -0,0 +1,25 @@
+import abc
+
+from . import exc as async_exc
+
+
+class StartableContext(abc.ABC):
+    @abc.abstractmethod
+    async def start(self) -> "StartableContext":
+        pass
+
+    def __await__(self):
+        return self.start().__await__()
+
+    async def __aenter__(self):
+        return await self.start()
+
+    @abc.abstractmethod
+    async def __aexit__(self, type_, value, traceback):
+        pass
+
+    def _raise_for_not_started(self):
+        raise async_exc.AsyncContextNotStarted(
+            "%s context has not been started and object has not been awaited."
+            % (self.__class__.__name__)
+        )
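
``StartableContext`` gives the asyncio objects introduced below their dual
calling convention: ``await obj`` and ``async with obj:`` both route
through ``start()``.  A minimal sketch, assuming the module is importable
at the path shown::

    import asyncio

    from sqlalchemy.ext.asyncio.base import StartableContext

    class Demo(StartableContext):
        async def start(self):
            self.started = True
            return self

        async def __aexit__(self, type_, value, traceback):
            self.started = False

    async def main():
        d1 = await Demo()          # __await__() -> start()
        async with Demo() as d2:   # __aenter__() -> start()
            assert d1.started and d2.started

    asyncio.run(main())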
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
new file mode 100644 (file)
index 0000000..2d9198d
--- /dev/null
@@ -0,0 +1,461 @@
+from typing import Any
+from typing import Callable
+from typing import Mapping
+from typing import Optional
+
+from . import exc as async_exc
+from .base import StartableContext
+from .result import AsyncResult
+from ... import exc
+from ... import util
+from ...engine import Connection
+from ...engine import create_engine as _create_engine
+from ...engine import Engine
+from ...engine import Result
+from ...engine import Transaction
+from ...engine.base import OptionEngineMixin
+from ...sql import Executable
+from ...util.concurrency import greenlet_spawn
+
+
+def create_async_engine(*arg, **kw):
+    """Create a new async engine instance.
+
+    Arguments passed to :func:`_asyncio.create_async_engine` are mostly
+    identical to those passed to the :func:`_sa.create_engine` function.
+    The specified dialect must be an asyncio-compatible dialect
+    such as :ref:`dialect-postgresql-asyncpg`.
+
+    .. versionadded:: 1.4
+
+    """
+
+    if kw.get("server_side_cursors", False):
+        raise async_exc.AsyncMethodRequired(
+            "Can't set server_side_cursors for async engine globally; "
+            "use the connection.stream() method for an async "
+            "streaming result set"
+        )
+    kw["future"] = True
+    sync_engine = _create_engine(*arg, **kw)
+    return AsyncEngine(sync_engine)
+
+
+class AsyncConnection(StartableContext):
+    """An asyncio proxy for a :class:`_engine.Connection`.
+
+    :class:`_asyncio.AsyncConnection` is acquired using the
+    :meth:`_asyncio.AsyncEngine.connect`
+    method of :class:`_asyncio.AsyncEngine`::
+
+        from sqlalchemy.ext.asyncio import create_async_engine
+        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+        async with engine.connect() as conn:
+            result = await conn.execute(select(table))
+
+    .. versionadded:: 1.4
+
+    """  # noqa
+
+    __slots__ = (
+        "sync_engine",
+        "sync_connection",
+    )
+
+    def __init__(
+        self, sync_engine: Engine, sync_connection: Optional[Connection] = None
+    ):
+        self.sync_engine = sync_engine
+        self.sync_connection = sync_connection
+
+    async def start(self):
+        """Start this :class:`_asyncio.AsyncConnection` object's context
+        outside of using a Python ``with:`` block.
+
+        """
+        if self.sync_connection:
+            raise exc.InvalidRequestError("connection is already started")
+        self.sync_connection = await greenlet_spawn(self.sync_engine.connect)
+        return self
+
+    def _sync_connection(self):
+        if not self.sync_connection:
+            self._raise_for_not_started()
+        return self.sync_connection
+
+    def begin(self) -> "AsyncTransaction":
+        """Begin a transaction prior to autobegin occurring.
+
+        """
+        self._sync_connection()
+        return AsyncTransaction(self)
+
+    def begin_nested(self) -> "AsyncTransaction":
+        """Begin a nested transaction and return a transaction handle.
+
+        """
+        self._sync_connection()
+        return AsyncTransaction(self, nested=True)
+
+    async def commit(self):
+        """Commit the transaction that is currently in progress.
+
+        This method commits the current transaction if one has been started.
+        If no transaction was started, the method has no effect, assuming
+        the connection is in a non-invalidated state.
+
+        A transaction is begun on a :class:`_future.Connection` automatically
+        whenever a statement is first executed, or when the
+        :meth:`_future.Connection.begin` method is called.
+
+        """
+        conn = self._sync_connection()
+        await greenlet_spawn(conn.commit)
+
+    async def rollback(self):
+        """Roll back the transaction that is currently in progress.
+
+        This method rolls back the current transaction if one has been started.
+        If no transaction was started, the method has no effect.  If a
+        transaction was started and the connection is in an invalidated state,
+        the transaction is cleared using this method.
+
+        A transaction is begun on a :class:`_future.Connection` automatically
+        whenever a statement is first executed, or when the
+        :meth:`_future.Connection.begin` method is called.
+
+
+        """
+        conn = self._sync_connection()
+        await greenlet_spawn(conn.rollback)
+
+    async def close(self):
+        """Close this :class:`_asyncio.AsyncConnection`.
+
+        This has the effect of also rolling back the transaction if one
+        is in place.
+
+        """
+        conn = self._sync_connection()
+        await greenlet_spawn(conn.close)
+
+    async def exec_driver_sql(
+        self,
+        statement: str,
+        parameters: Optional[Mapping] = None,
+        execution_options: Mapping = util.EMPTY_DICT,
+    ) -> Result:
+        r"""Executes a driver-level SQL string and return buffered
+        :class:`_engine.Result`.
+
+        """
+
+        conn = self._sync_connection()
+
+        result = await greenlet_spawn(
+            conn.exec_driver_sql, statement, parameters, execution_options,
+        )
+        if result.context._is_server_side:
+            raise async_exc.AsyncMethodRequired(
+                "Can't use the connection.exec_driver_sql() method with a "
+                "server-side cursor."
+                "Use the connection.stream() method for an async "
+                "streaming result set."
+            )
+
+        return result
+
+    async def stream(
+        self,
+        statement: Executable,
+        parameters: Optional[Mapping] = None,
+        execution_options: Mapping = util.EMPTY_DICT,
+    ) -> AsyncResult:
+        """Execute a statement and return a streaming
+        :class:`_asyncio.AsyncResult` object."""
+
+        conn = self._sync_connection()
+
+        result = await greenlet_spawn(
+            conn._execute_20,
+            statement,
+            parameters,
+            util.EMPTY_DICT.merge_with(
+                execution_options, {"stream_results": True}
+            ),
+        )
+        if not result.context._is_server_side:
+            # TODO: real exception here
+            assert False, "server side result expected"
+        return AsyncResult(result)
+
+    async def execute(
+        self,
+        statement: Executable,
+        parameters: Optional[Mapping] = None,
+        execution_options: Mapping = util.EMPTY_DICT,
+    ) -> Result:
+        r"""Executes a SQL statement construct and return a buffered
+        :class:`_engine.Result`.
+
+        :param statement: The statement to be executed.  This is always
+         an object that is in both the :class:`_expression.ClauseElement` and
+         :class:`_expression.Executable` hierarchies, including:
+
+         * :class:`_expression.Select`
+         * :class:`_expression.Insert`, :class:`_expression.Update`,
+           :class:`_expression.Delete`
+         * :class:`_expression.TextClause` and
+           :class:`_expression.TextualSelect`
+         * :class:`_schema.DDL` and objects which inherit from
+           :class:`_schema.DDLElement`
+
+        :param parameters: parameters which will be bound into the statement.
+         This may be either a dictionary of parameter names to values,
+         or a mutable sequence (e.g. a list) of dictionaries.  When a
+         list of dictionaries is passed, the underlying statement execution
+         will make use of the DBAPI ``cursor.executemany()`` method.
+         When a single dictionary is passed, the DBAPI ``cursor.execute()``
+         method will be used.
+
+        :param execution_options: optional dictionary of execution options,
+         which will be associated with the statement execution.  This
+         dictionary can provide a subset of the options that are accepted
+         by :meth:`_future.Connection.execution_options`.
+
+        :return: a :class:`_engine.Result` object.
+
+        """
+        conn = self._sync_connection()
+
+        result = await greenlet_spawn(
+            conn._execute_20, statement, parameters, execution_options,
+        )
+        if result.context._is_server_side:
+            raise async_exc.AsyncMethodRequired(
+                "Can't use the connection.execute() method with a "
+                "server-side cursor."
+                "Use the connection.stream() method for an async "
+                "streaming result set."
+            )
+        return result
+
+    async def scalar(
+        self,
+        statement: Executable,
+        parameters: Optional[Mapping] = None,
+        execution_options: Mapping = util.EMPTY_DICT,
+    ) -> Any:
+        r"""Executes a SQL statement construct and returns a scalar object.
+
+        This method is shorthand for invoking the
+        :meth:`_engine.Result.scalar` method after invoking the
+        :meth:`_future.Connection.execute` method.  Parameters are equivalent.
+
+        :return: a scalar Python value representing the first column of the
+         first row returned.
+
+        """
+        result = await self.execute(statement, parameters, execution_options)
+        return result.scalar()
+
+    async def run_sync(self, fn: Callable, *arg, **kw) -> Any:
+        """"Invoke the given sync callable passing self as the first argument.
+
+        This method maintains the asyncio event loop all the way through
+        to the database connection by running the given callable in a
+        specially instrumented greenlet.
+
+        E.g.::
+
+            with async_engine.begin() as conn:
+                await conn.run_sync(metadata.create_all)
+
+        """
+
+        conn = self._sync_connection()
+
+        return await greenlet_spawn(fn, conn, *arg, **kw)
+
+    def __await__(self):
+        return self.start().__await__()
+
+    async def __aexit__(self, type_, value, traceback):
+        await self.close()
+
+
+class AsyncEngine:
+    """An asyncio proxy for a :class:`_engine.Engine`.
+
+    :class:`_asyncio.AsyncEngine` is acquired using the
+    :func:`_asyncio.create_async_engine` function::
+
+        from sqlalchemy.ext.asyncio import create_async_engine
+        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+    .. versionadded:: 1.4
+
+
+    """  # noqa
+
+    __slots__ = ("sync_engine",)
+
+    _connection_cls = AsyncConnection
+
+    _option_cls: type
+
+    class _trans_ctx(StartableContext):
+        def __init__(self, conn):
+            self.conn = conn
+
+        async def start(self):
+            await self.conn.start()
+            self.transaction = self.conn.begin()
+            await self.transaction.__aenter__()
+
+            return self.conn
+
+        async def __aexit__(self, type_, value, traceback):
+            if type_ is not None:
+                await self.transaction.rollback()
+            else:
+                if self.transaction.is_active:
+                    await self.transaction.commit()
+            await self.conn.close()
+
+    def __init__(self, sync_engine: Engine):
+        self.sync_engine = sync_engine
+
+    def begin(self):
+        """Return a context manager which when entered will deliver an
+        :class:`_asyncio.AsyncConnection` with an
+        :class:`_asyncio.AsyncTransaction` established.
+
+        E.g.::
+
+            async with async_engine.begin() as conn:
+                await conn.execute(
+                    text("insert into table (x, y, z) values (1, 2, 3)")
+                )
+                await conn.execute(text("my_special_procedure(5)"))
+
+
+        """
+        conn = self.connect()
+        return self._trans_ctx(conn)
+
+    def connect(self) -> AsyncConnection:
+        """Return an :class:`_asyncio.AsyncConnection` object.
+
+        The :class:`_asyncio.AsyncConnection` will procure a database
+        connection from the underlying connection pool when it is entered
+        as an async context manager::
+
+            async with async_engine.connect() as conn:
+                result = await conn.execute(select(user_table))
+
+        The :class:`_asyncio.AsyncConnection` may also be started outside of a
+        context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
+        method.
+
+        """
+
+        return self._connection_cls(self.sync_engine)
+
+    async def raw_connection(self) -> Any:
+        """Return a "raw" DBAPI connection from the connection pool.
+
+        .. seealso::
+
+            :ref:`dbapi_connections`
+
+        """
+        return await greenlet_spawn(self.sync_engine.raw_connection)
+
+
+class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
+    pass
+
+
+AsyncEngine._option_cls = AsyncOptionEngine
+
+
+class AsyncTransaction(StartableContext):
+    """An asyncio proxy for a :class:`_engine.Transaction`."""
+
+    __slots__ = ("connection", "sync_transaction", "nested")
+
+    def __init__(self, connection: AsyncConnection, nested: bool = False):
+        self.connection = connection
+        self.sync_transaction: Optional[Transaction] = None
+        self.nested = nested
+
+    def _sync_transaction(self):
+        if not self.sync_transaction:
+            self._raise_for_not_started()
+        return self.sync_transaction
+
+    @property
+    def is_valid(self) -> bool:
+        return self._sync_transaction().is_valid
+
+    @property
+    def is_active(self) -> bool:
+        return self._sync_transaction().is_active
+
+    async def close(self):
+        """Close this :class:`.Transaction`.
+
+        If this transaction is the base transaction in a begin/commit
+        nesting, the transaction will rollback().  Otherwise, the
+        method returns.
+
+        This is used to cancel a Transaction without affecting the scope of
+        an enclosing transaction.
+
+        """
+        await greenlet_spawn(self._sync_transaction().close)
+
+    async def rollback(self):
+        """Roll back this :class:`.Transaction`.
+
+        """
+        await greenlet_spawn(self._sync_transaction().rollback)
+
+    async def commit(self):
+        """Commit this :class:`.Transaction`."""
+
+        await greenlet_spawn(self._sync_transaction().commit)
+
+    async def start(self):
+        """Start this :class:`_asyncio.AsyncTransaction` object's context
+        outside of using a Python ``with:`` block.
+
+        """
+
+        self.sync_transaction = await greenlet_spawn(
+            self.connection._sync_connection().begin_nested
+            if self.nested
+            else self.connection._sync_connection().begin
+        )
+        return self
+
+    async def __aexit__(self, type_, value, traceback):
+        if type_ is None and self.is_active:
+            try:
+                await self.commit()
+            except:
+                with util.safe_reraise():
+                    await self.rollback()
+        else:
+            await self.rollback()
+
+
+def _get_sync_engine(async_engine):
+    try:
+        return async_engine.sync_engine
+    except AttributeError as e:
+        raise exc.ArgumentError(
+            "AsyncEngine expected, got %r" % async_engine
+        ) from e
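
Putting the engine pieces above together, both transaction styles are used
roughly as follows (the URL and SQL are illustrative)::

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )

        # engine.begin() commits on success, rolls back on error
        async with engine.begin() as conn:
            await conn.execute(
                text("insert into t (x) values (:x)"), {"x": 1}
            )

        # explicit AsyncTransaction on a connection; awaiting
        # conn.begin() runs AsyncTransaction.start()
        async with engine.connect() as conn:
            trans = await conn.begin()
            await conn.execute(text("update t set x = x + 1"))
            await trans.commit()

    asyncio.run(main())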
diff --git a/lib/sqlalchemy/ext/asyncio/exc.py b/lib/sqlalchemy/ext/asyncio/exc.py
new file mode 100644 (file)
index 0000000..6137bf6
--- /dev/null
@@ -0,0 +1,14 @@
+from ... import exc
+
+
+class AsyncMethodRequired(exc.InvalidRequestError):
+    """an API can't be used because its result would not be
+    compatible with async"""
+
+
+class AsyncContextNotStarted(exc.InvalidRequestError):
+    """a startable context manager has not been started."""
+
+
+class AsyncContextAlreadyStarted(exc.InvalidRequestError):
+    """a startable context manager is already started."""
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
new file mode 100644 (file)
index 0000000..52b40ac
--- /dev/null
@@ -0,0 +1,669 @@
+import operator
+
+from ... import util
+from ...engine.result import _NO_ROW
+from ...engine.result import FilterResult
+from ...engine.result import FrozenResult
+from ...engine.result import MergedResult
+from ...util.concurrency import greenlet_spawn
+
+if util.TYPE_CHECKING:
+    from typing import Any
+    from typing import List
+    from typing import Optional
+    from typing import Int
+    from typing import Iterator
+    from typing import Mapping
+    from ...engine.result import Row
+
+
+class AsyncResult(FilterResult):
+    """An asyncio wrapper around a :class:`_result.Result` object.
+
+    The :class:`_asyncio.AsyncResult` only applies to statement executions that
+    use a server-side cursor.  It is returned only from the
+    :meth:`_asyncio.AsyncConnection.stream` and
+    :meth:`_asyncio.AsyncSession.stream` methods.
+
+    .. versionadded:: 1.4
+
+    """
+
+    def __init__(self, real_result):
+        self._real_result = real_result
+
+        self._metadata = real_result._metadata
+        self._unique_filter_state = real_result._unique_filter_state
+
+        # BaseCursorResult pre-generates the "_row_getter".  Use that
+        # if available rather than building a second one
+        if "_row_getter" in real_result.__dict__:
+            self._set_memoized_attribute(
+                "_row_getter", real_result.__dict__["_row_getter"]
+            )
+
+    def keys(self):
+        """Return the :meth:`_engine.Result.keys` collection from the
+        underlying :class:`_engine.Result`.
+
+        """
+        return self._metadata.keys
+
+    def unique(self, strategy=None):
+        """Apply unique filtering to the objects returned by this
+        :class:`_asyncio.AsyncResult`.
+
+        Refer to :meth:`_engine.Result.unique` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+
+        """
+        self._unique_filter_state = (set(), strategy)
+        return self
+
+    def columns(self, *col_expressions):
+        # type: (*object) -> AsyncResult
+        r"""Establish the columns that should be returned in each row.
+
+        Refer to :meth:`_engine.Result.columns` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+
+        """
+        return self._column_slices(col_expressions)
+
+    async def partitions(self, size=None):
+        # type: (Optional[Int]) -> Iterator[List[Any]]
+        """Iterate through sub-lists of rows of the size given.
+
+        An async iterator is returned::
+
+            async def scroll_results(connection):
+                result = await connection.stream(select(users_table))
+
+                async for partition in result.partitions(100):
+                    print("list of rows: %s" % partition)
+
+        .. seealso::
+
+            :meth:`_engine.Result.partitions`
+
+        """
+
+        getter = self._manyrow_getter
+
+        while True:
+            partition = await greenlet_spawn(getter, self, size)
+            if partition:
+                yield partition
+            else:
+                break
+
+    async def fetchone(self):
+        # type: () -> Optional[Row]
+        """Fetch one row.
+
+        When all rows are exhausted, returns None.
+
+        This method is provided for backwards compatibility with
+        SQLAlchemy 1.x.x.
+
+        To fetch the first row of a result only, use the
+        :meth:`_asyncio.AsyncResult.first` method.  To iterate through
+        all rows, iterate the :class:`_asyncio.AsyncResult` object
+        directly.
+
+        :return: a :class:`.Row` object if no filters are applied, or None
+         if no rows remain.
+
+        """
+        row = await greenlet_spawn(self._onerow_getter, self)
+        if row is _NO_ROW:
+            return None
+        else:
+            return row
+
+    async def fetchmany(self, size=None):
+        # type: (Optional[Int]) -> List[Row]
+        """Fetch many rows.
+
+        When all rows are exhausted, returns an empty list.
+
+        This method is provided for backwards compatibility with
+        SQLAlchemy 1.x.x.
+
+        To fetch rows in groups, use the
+        :meth:`._asyncio.AsyncResult.partitions` method.
+
+        :return: a list of :class:`.Row` objects.
+
+        .. seealso::
+
+            :meth:`_asyncio.AsyncResult.partitions`
+
+        """
+
+        return await greenlet_spawn(self._manyrow_getter, self, size)
+
+    async def all(self):
+        # type: () -> List[Row]
+        """Return all rows in a list.
+
+        Closes the result set after invocation.   Subsequent invocations
+        will return an empty list.
+
+        :return: a list of :class:`.Row` objects.
+
+        """
+
+        return await greenlet_spawn(self._allrows)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        row = await greenlet_spawn(self._onerow_getter, self)
+        if row is _NO_ROW:
+            raise StopAsyncIteration()
+        else:
+            return row
+
+    async def first(self):
+        # type: () -> Optional[Row]
+        """Fetch the first row or None if no row is present.
+
+        Closes the result set and discards remaining rows.
+
+        .. note::  This method returns one **row**, e.g. tuple, by default. To
+           return exactly one single scalar value, that is, the first column of
+           the first row, use the :meth:`_asyncio.AsyncResult.scalar` method,
+           or combine :meth:`_asyncio.AsyncResult.scalars` and
+           :meth:`_asyncio.AsyncResult.first`.
+
+        :return: a :class:`.Row` object, or None
+         if no rows remain.
+
+        .. seealso::
+
+            :meth:`_asyncio.AsyncResult.scalar`
+
+            :meth:`_asyncio.AsyncResult.one`
+
+        """
+        return await greenlet_spawn(self._only_one_row, False, False, False)
+
+    async def one_or_none(self):
+        # type: () -> Optional[Row]
+        """Return at most one result or raise an exception.
+
+        Returns ``None`` if the result has no rows.
+        Raises :class:`.MultipleResultsFound`
+        if multiple rows are returned.
+
+        .. versionadded:: 1.4
+
+        :return: The first :class:`.Row` or None if no row is available.
+
+        :raises: :class:`.MultipleResultsFound`
+
+        .. seealso::
+
+            :meth:`_asyncio.AsyncResult.first`
+
+            :meth:`_asyncio.AsyncResult.one`
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, False, False)
+
+    async def scalar_one(self):
+        # type: () -> Any
+        """Return exactly one scalar result or raise an exception.
+
+        This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+        then :meth:`_asyncio.AsyncResult.one`.
+
+        .. seealso::
+
+            :meth:`_asyncio.AsyncResult.one`
+
+            :meth:`_asyncio.AsyncResult.scalars`
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, True, True)
+
+    async def scalar_one_or_none(self):
+        # type: () -> Optional[Any]
+        """Return exactly one or no scalar result.
+
+        This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+        then :meth:`_asyncio.AsyncResult.one_or_none`.
+
+        .. seealso::
+
+            :meth:`_asyncio.AsyncResult.one_or_none`
+
+            :meth:`_asyncio.AsyncResult.scalars`
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, False, True)
+
+    async def one(self):
+        # type: () -> Row
+        """Return exactly one row or raise an exception.
+
+        Raises :class:`.NoResultFound` if the result returns no
+        rows, or :class:`.MultipleResultsFound` if multiple rows
+        would be returned.
+
+        .. note::  This method returns one **row**, e.g. tuple, by default.
+           To return exactly one single scalar value, that is, the first
+           column of the first row, use the
+           :meth:`_asyncio.AsyncResult.scalar_one` method, or combine
+           :meth:`_asyncio.AsyncResult.scalars` and
+           :meth:`_asyncio.AsyncResult.one`.
+
+        .. versionadded:: 1.4
+
+        :return: The first :class:`.Row`.
+
+        :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
+
+        .. seealso::
+
+            :meth:`_asyncio.AsyncResult.first`
+
+            :meth:`_asyncio.AsyncResult.one_or_none`
+
+            :meth:`_asyncio.AsyncResult.scalar_one`
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, True, False)
+
+    async def scalar(self):
+        # type: () -> Optional[Any]
+        """Fetch the first column of the first row, and close the result set.
+
+        Returns None if there are no rows to fetch.
+
+        No validation is performed to test if additional rows remain.
+
+        After calling this method, the object is fully closed,
+        e.g. the :meth:`_engine.CursorResult.close`
+        method will have been called.
+
+        :return: a Python scalar value, or None if no rows remain.
+
+        """
+        return await greenlet_spawn(self._only_one_row, False, False, True)
+
+    async def freeze(self):
+        """Return a callable object that will produce copies of this
+        :class:`_asyncio.AsyncResult` when invoked.
+
+        The callable object returned is an instance of
+        :class:`_engine.FrozenResult`.
+
+        This is used for result set caching.  The method must be called
+        on the result when it has been unconsumed, and calling the method
+        will consume the result fully.   When the :class:`_engine.FrozenResult`
+        is retrieved from a cache, it can be called any number of times where
+        it will produce a new :class:`_engine.Result` object each time
+        against its stored set of rows.
+
+        .. seealso::
+
+            :ref:`do_orm_execute_re_executing` - example usage within the
+            ORM to implement a result-set cache.
+
+        """
+
+        return await greenlet_spawn(FrozenResult, self)
+
+    def merge(self, *others):
+        """Merge this :class:`_asyncio.AsyncResult` with other compatible result
+        objects.
+
+        The object returned is an instance of :class:`_engine.MergedResult`,
+        which will be composed of iterators from the given result
+        objects.
+
+        The new result will use the metadata from this result object.
+        The subsequent result objects must be against an identical
+        set of result / cursor metadata, otherwise the behavior is
+        undefined.
+
+        """
+        return MergedResult(self._metadata, (self,) + others)
+
+    def scalars(self, index=0):
+        # type: (Int) -> AsyncScalarResult
+        """Return an :class:`_asyncio.AsyncScalarResult` filtering object which
+        will return single elements rather than :class:`_row.Row` objects.
+
+        Refer to :meth:`_result.Result.scalars` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+        :param index: integer or row key indicating the column to be fetched
+         from each row, defaults to ``0`` indicating the first column.
+
+        :return: a new :class:`_asyncio.AsyncScalarResult` filtering object
+         referring to this :class:`_asyncio.AsyncResult` object.
+
+        """
+        return AsyncScalarResult(self._real_result, index)
+
+    def mappings(self):
+        # type: () -> AsyncMappingResult
+        """Apply a mappings filter to returned rows, returning an instance of
+        :class:`_asyncio.AsyncMappingResult`.
+
+        When this filter is applied, fetching rows will return
+        :class:`.RowMapping` objects instead of :class:`.Row` objects.
+
+        Refer to :meth:`_result.Result.mappings` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+        :return: a new :class:`_asyncio.AsyncMappingResult` filtering object
+         referring to the underlying :class:`_result.Result` object.
+
+        """
+
+        return AsyncMappingResult(self._real_result)
+
+
+class AsyncScalarResult(FilterResult):
+    """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
+    rather than :class:`_row.Row` values.
+
+    The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
+    :meth:`_asyncio.AsyncResult.scalars` method.
+
+    Refer to the :class:`_result.ScalarResult` object in the synchronous
+    SQLAlchemy API for a complete behavioral description.
+
+    .. versionadded:: 1.4
+
+    """
+
+    _generate_rows = False
+
+    def __init__(self, real_result, index):
+        self._real_result = real_result
+
+        if real_result._source_supports_scalars:
+            self._metadata = real_result._metadata
+            self._post_creational_filter = None
+        else:
+            self._metadata = real_result._metadata._reduce([index])
+            self._post_creational_filter = operator.itemgetter(0)
+
+        self._unique_filter_state = real_result._unique_filter_state
+
+    def unique(self, strategy=None):
+        # type: () -> AsyncScalarResult
+        """Apply unique filtering to the objects returned by this
+        :class:`_asyncio.AsyncScalarResult`.
+
+        See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+        """
+        self._unique_filter_state = (set(), strategy)
+        return self
+
+    async def partitions(self, size=None):
+        # type: (Optional[Int]) -> Iterator[List[Any]]
+        """Iterate through sub-lists of elements of the size given.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+        scalar values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        getter = self._manyrow_getter
+
+        while True:
+            partition = await greenlet_spawn(getter, self, size)
+            if partition:
+                yield partition
+            else:
+                break
+
+    async def fetchall(self):
+        # type: () -> List[Any]
+        """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
+
+        return await greenlet_spawn(self._allrows)
+
+    async def fetchmany(self, size=None):
+        # type: (Optional[Int]) -> List[Any]
+        """Fetch many objects.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+        scalar values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._manyrow_getter, self, size)
+
+    async def all(self):
+        # type: () -> List[Any]
+        """Return all scalar values in a list.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+        scalar values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._allrows)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        row = await greenlet_spawn(self._onerow_getter, self)
+        if row is _NO_ROW:
+            raise StopAsyncIteration()
+        else:
+            return row
+
+    async def first(self):
+        # type: () -> Optional[Any]
+        """Fetch the first object or None if no object is present.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+        scalar values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._only_one_row, False, False, False)
+
+    async def one_or_none(self):
+        # type: () -> Optional[Any]
+        """Return at most one object or raise an exception.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+        scalar values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, False, False)
+
+    async def one(self):
+        # type: () -> Any
+        """Return exactly one object or raise an exception.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+        scalar values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+class AsyncMappingResult(FilterResult):
+    """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values
+    rather than :class:`_engine.Row` values.
+
+    The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
+    :meth:`_asyncio.AsyncResult.mappings` method.
+
+    Refer to the :class:`_result.MappingResult` object in the synchronous
+    SQLAlchemy API for a complete behavioral description.
+
+    .. versionadded:: 1.4
+
+    """
+
+    _generate_rows = True
+
+    _post_creational_filter = operator.attrgetter("_mapping")
+
+    def __init__(self, result):
+        self._real_result = result
+        self._unique_filter_state = result._unique_filter_state
+        self._metadata = result._metadata
+        if result._source_supports_scalars:
+            self._metadata = self._metadata._reduce([0])
+
+    def keys(self):
+        """Return an iterable view which yields the string keys that would
+        be represented by each :class:`.Row`.
+
+        The view also can be tested for key containment using the Python
+        ``in`` operator, which will test both for the string keys represented
+        in the view, as well as for alternate keys such as column objects.
+
+        .. versionchanged:: 1.4 a key view object is returned rather than a
+           plain list.
+
+
+        """
+        return self._metadata.keys
+
+    def unique(self, strategy=None):
+        # type: () -> AsyncMappingResult
+        """Apply unique filtering to the objects returned by this
+        :class:`_asyncio.AsyncMappingResult`.
+
+        See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+        """
+        self._unique_filter_state = (set(), strategy)
+        return self
+
+    def columns(self, *col_expressions):
+        # type: (*object) -> AsyncMappingResult
+        r"""Establish the columns that should be returned in each row.
+
+
+        """
+        return self._column_slices(col_expressions)
+
+    async def partitions(self, size=None):
+        # type: (Optional[Int]) -> Iterator[List[Mapping]]
+        """Iterate through sub-lists of elements of the size given.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        getter = self._manyrow_getter
+
+        while True:
+            partition = await greenlet_spawn(getter, self, size)
+            if partition:
+                yield partition
+            else:
+                break
+
+    async def fetchall(self):
+        # type: () -> List[Mapping]
+        """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
+
+        return await greenlet_spawn(self._allrows)
+
+    async def fetchone(self):
+        # type: () -> Optional[Mapping]
+        """Fetch one object.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        row = await greenlet_spawn(self._onerow_getter, self)
+        if row is _NO_ROW:
+            return None
+        else:
+            return row
+
+    async def fetchmany(self, size=None):
+        # type: (Optional[Int]) -> List[Mapping]
+        """Fetch many objects.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        return await greenlet_spawn(self._manyrow_getter, self, size)
+
+    async def all(self):
+        # type: () -> List[Mapping]
+        """Return all scalar values in a list.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        return await greenlet_spawn(self._allrows)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        row = await greenlet_spawn(self._onerow_getter, self)
+        if row is _NO_ROW:
+            raise StopAsyncIteration()
+        else:
+            return row
+
+    async def first(self):
+        # type: () -> Optional[Mapping]
+        """Fetch the first object or None if no object is present.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+
+        """
+        return await greenlet_spawn(self._only_one_row, False, False, False)
+
+    async def one_or_none(self):
+        # type: () -> Optional[Mapping]
+        """Return at most one object or raise an exception.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, False, False)
+
+    async def one(self):
+        # type: () -> Mapping
+        """Return exactly one object or raise an exception.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, True, False)
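
A usage sketch tying the result classes above to
:meth:`_asyncio.AsyncConnection.stream`; the URL and table are
illustrative::

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )
        async with engine.connect() as conn:
            # stream() uses a server-side cursor and returns AsyncResult
            result = await conn.stream(text("select * from big_table"))
            async for partition in result.partitions(100):
                print("fetched %d rows" % len(partition))

            # the scalars() filter supports async iteration as well
            result = await conn.stream(text("select id from big_table"))
            async for id_ in result.scalars():
                pass

    asyncio.run(main())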
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
new file mode 100644 (file)
index 0000000..1673017
--- /dev/null
@@ -0,0 +1,293 @@
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Mapping
+from typing import Optional
+
+from . import engine
+from . import result as _result
+from .base import StartableContext
+from .engine import AsyncEngine
+from ... import util
+from ...engine import Result
+from ...orm import Session
+from ...sql import Executable
+from ...util.concurrency import greenlet_spawn
+
+
+class AsyncSession:
+    """Asyncio version of :class:`_orm.Session`.
+
+
+    .. versionadded:: 1.4
+
+    """
+
+    def __init__(
+        self,
+        bind: Optional[AsyncEngine] = None,
+        binds: Optional[Mapping[object, AsyncEngine]] = None,
+        **kw
+    ):
+        kw["future"] = True
+        if bind:
+            bind = engine._get_sync_engine(bind)
+
+        if binds:
+            binds = {
+                key: engine._get_sync_engine(b) for key, b in binds.items()
+            }
+
+        self.sync_session = Session(bind=bind, binds=binds, **kw)
+
+    def add(self, instance: object) -> None:
+        """Place an object in this :class:`_asyncio.AsyncSession`.
+
+        .. seealso::
+
+            :meth:`_orm.Session.add`
+
+        """
+        self.sync_session.add(instance)
+
+    def add_all(self, instances: List[object]) -> None:
+        """Add the given collection of instances to this
+        :class:`_asyncio.AsyncSession`."""
+
+        self.sync_session.add_all(instances)
+
+    def expire_all(self):
+        """Expires all persistent instances within this Session.
+
+        See :meth:`_orm.Session.expire_all` for usage details.
+
+        """
+        self.sync_session.expire_all()
+
+    def expire(self, instance, attribute_names=None):
+        """Expire the attributes on an instance.
+
+        See :meth:`._orm.Session.expire` for usage details.
+
+        """
+        self.sync_session.expire(instance, attribute_names=attribute_names)
+
+    async def refresh(
+        self, instance, attribute_names=None, with_for_update=None
+    ):
+        """Expire and refresh the attributes on the given instance.
+
+        A query will be issued to the database and all attributes will be
+        refreshed with their current database value.
+
+        This is the async version of the :meth:`_orm.Session.refresh` method.
+        See that method for a complete description of all options.
+
+        """
+
+        return await greenlet_spawn(
+            self.sync_session.refresh,
+            instance,
+            attribute_names=attribute_names,
+            with_for_update=with_for_update,
+        )
+
+    async def run_sync(self, fn: Callable, *arg, **kw) -> Any:
+        """Invoke the given sync callable passing sync self as the first
+        argument.
+
+        This method maintains the asyncio event loop all the way through
+        to the database connection by running the given callable in a
+        specially instrumented greenlet.
+
+        E.g.::
+
+            session = AsyncSession(async_engine)
+            await session.run_sync(some_business_method)
+
+        """
+
+        return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+
+    async def execute(
+        self,
+        statement: Executable,
+        params: Optional[Mapping] = None,
+        execution_options: Mapping = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping] = None,
+        **kw
+    ) -> Result:
+        """Execute a statement and return a buffered
+        :class:`_engine.Result` object."""
+
+        execution_options = execution_options.union({"prebuffer_rows": True})
+
+        return await greenlet_spawn(
+            self.sync_session.execute,
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        )
+
+    async def stream(
+        self,
+        statement,
+        params=None,
+        execution_options=util.EMPTY_DICT,
+        bind_arguments=None,
+        **kw
+    ):
+        """Execute a statement and return a streaming
+        :class:`_asyncio.AsyncResult` object."""
+
+        execution_options = execution_options.union({"stream_results": True})
+
+        result = await greenlet_spawn(
+            self.sync_session.execute,
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        )
+        return _result.AsyncResult(result)
+
+    async def merge(self, instance, load=True):
+        """Copy the state of a given instance into a corresponding instance
+        within this :class:`_asyncio.AsyncSession`.
+
+        """
+        return await greenlet_spawn(
+            self.sync_session.merge, instance, load=load
+        )
+
+    async def flush(self, objects=None):
+        """Flush all the object changes to the database.
+
+        .. seealso::
+
+            :meth:`_orm.Session.flush`
+
+        """
+        await greenlet_spawn(self.sync_session.flush, objects=objects)
+
+    async def connection(self):
+        r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this
+        :class:`.Session` object's transactional state.
+
+        """
+        sync_connection = await greenlet_spawn(self.sync_session.connection)
+        return engine.AsyncConnection(sync_connection.engine, sync_connection)
+
+    def begin(self, **kw):
+        """Return an :class:`_asyncio.AsyncSessionTransaction` object.
+
+        The underlying :class:`_orm.Session` will perform the
+        "begin" action when the :class:`_asyncio.AsyncSessionTransaction`
+        object is entered::
+
+            async with async_session.begin():
+                # .. ORM transaction is begun
+
+        Note that database IO will not normally occur when the session-level
+        transaction is begun, as database transactions begin on demand.
+        However, the begin block is async in order to accommodate a
+        :meth:`_orm.SessionEvents.after_transaction_create`
+        event hook that may perform IO.
+
+        For a general description of ORM begin, see
+        :meth:`_orm.Session.begin`.
+
+        """
+
+        return AsyncSessionTransaction(self)
+
+    def begin_nested(self, **kw):
+        """Return an :class:`_asyncio.AsyncSessionTransaction` object
+        which will begin a "nested" transaction, e.g. SAVEPOINT.
+
+        Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.
+
+        For a general description of ORM begin nested, see
+        :meth:`_orm.Session.begin_nested`.
+
+        """
+
+        return AsyncSessionTransaction(self, nested=True)
+
+    async def rollback(self):
+        return await greenlet_spawn(self.sync_session.rollback)
+
+    async def commit(self):
+        return await greenlet_spawn(self.sync_session.commit)
+
+    async def close(self):
+        return await greenlet_spawn(self.sync_session.close)
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, type_, value, traceback):
+        await self.close()
+
+
+class AsyncSessionTransaction(StartableContext):
+    """A wrapper for the ORM :class:`_orm.SessionTransaction` object.
+
+    This object is provided so that :meth:`_asyncio.AsyncSession.begin`
+    may return a transaction-holding object.
+
+    The object supports both explicit calls to
+    :meth:`_asyncio.AsyncSessionTransaction.commit` and
+    :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an
+    async context manager.
+
+
+    .. versionadded:: 1.4
+
+    """
+
+    __slots__ = ("session", "sync_transaction", "nested")
+
+    def __init__(self, session, nested=False):
+        self.session = session
+        self.nested = nested
+        self.sync_transaction = None
+
+    @property
+    def is_active(self):
+        return (
+            self._sync_transaction() is not None
+            and self._sync_transaction().is_active
+        )
+
+    def _sync_transaction(self):
+        if not self.sync_transaction:
+            self._raise_for_not_started()
+        return self.sync_transaction
+
+    async def rollback(self):
+        """Roll back this :class:`_asyncio.AsyncSessionTransaction`."""
+
+        await greenlet_spawn(self._sync_transaction().rollback)
+
+    async def commit(self):
+        """Commit this :class:`_asyncio.AsyncTransaction`."""
+
+        await greenlet_spawn(self._sync_transaction().commit)
+
+    async def start(self):
+        self.sync_transaction = await greenlet_spawn(
+            self.session.sync_session.begin_nested
+            if self.nested
+            else self.session.sync_session.begin
+        )
+        return self
+
+    async def __aexit__(self, type_, value, traceback):
+        return await greenlet_spawn(
+            self._sync_transaction().__exit__, type_, value, traceback
+        )
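As a rough usage sketch of the AsyncSession API defined above (the User mapping and database URL are illustrative, and the "users" table is assumed to already exist):

    import asyncio
    from sqlalchemy import Column, Integer, String, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    async def main():
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )
        session = AsyncSession(engine)
        # AsyncSessionTransaction commits on __aexit__, rolls back on error
        async with session.begin():
            session.add(User(name="u1"))
        result = await session.execute(select(User))
        print(result.scalars().all())
        await session.close()

    asyncio.run(main())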
index 37ce46e4770106c05d1907d0350a1acd2a8ddd1c..b07b9b040b822ed5e6c26727d83c34e6365be604 100644 (file)
@@ -14,4 +14,5 @@ from .engine import Engine  # noqa
 from ..sql.selectable import Select  # noqa
 from ..util.langhelpers import public_factory
 
+
 select = public_factory(Select._create_future_select, ".future.select")
index d5922daa3226d8746612b382706152375d1e6f61..dd72360eda46ea5afd33644351027c8b69336582 100644 (file)
@@ -359,6 +359,22 @@ class Engine(_LegacyEngine):
             execution_options=legacy_engine._execution_options,
         )
 
+    class _trans_ctx(object):
+        def __init__(self, conn):
+            self.conn = conn
+
+        def __enter__(self):
+            self.transaction = self.conn.begin()
+            return self.conn
+
+        def __exit__(self, type_, value, traceback):
+            if type_ is not None:
+                self.transaction.rollback()
+            else:
+                if self.transaction.is_active:
+                    self.transaction.commit()
+            self.conn.close()
+
     def begin(self):
         """Return a :class:`_future.Connection` object with a transaction
         begun.
@@ -381,7 +397,8 @@ class Engine(_LegacyEngine):
             :meth:`_future.Connection.begin`
 
         """
-        return super(Engine, self).begin()
+        conn = self.connect()
+        return self._trans_ctx(conn)
 
     def connect(self):
         """Return a new :class:`_future.Connection` object.
index 2eb3e1368992c1238dcf28d2ff597fccde791b88..fd3e92055b8bf439b0926a3b4aebb50815ed381e 100644 (file)
@@ -125,8 +125,21 @@ def instances(cursor, context):
             if not yield_per:
                 break
 
+    if context.execution_options.get("prebuffer_rows", False):
+        # this is a bit of a hack at the moment.
+        # I would rather have some option in the result to pre-buffer
+        # internally.
+        _prebuffered = list(chunks(None))
+
+        def chunks(size):
+            return iter(_prebuffered)
+
     result = ChunkedIteratorResult(
-        row_metadata, chunks, source_supports_scalars=single_entity, raw=cursor
+        row_metadata,
+        chunks,
+        source_supports_scalars=single_entity,
+        raw=cursor,
+        dynamic_yield_per=cursor.context._is_server_side,
     )
 
     result._attributes = result._attributes.union(
index 7c254c61bd9a0305b620d76371cafc7253177797..676dd438c189dd9d0a797c12a93c3e9df7bcdaa0 100644 (file)
@@ -1211,7 +1211,6 @@ def _emit_insert_statements(
                     has_all_pks,
                     has_all_defaults,
                 ) in records:
-
                     if value_params:
                         result = connection.execute(
                             statement.values(value_params), params
index eb0d3751733e5b7c32ebeeba463393533d0481c1..353f34333c94172c889798b38206e2d442e47573 100644 (file)
@@ -28,6 +28,7 @@ from .base import reset_rollback
 from .dbapi_proxy import clear_managers
 from .dbapi_proxy import manage
 from .impl import AssertionPool
+from .impl import AsyncAdaptedQueuePool
 from .impl import NullPool
 from .impl import QueuePool
 from .impl import SingletonThreadPool
@@ -44,6 +45,7 @@ __all__ = [
     "AssertionPool",
     "NullPool",
     "QueuePool",
+    "AsyncAdaptedQueuePool",
     "SingletonThreadPool",
     "StaticPool",
 ]
index 0fe7612b92ff7834dd51d43548c546db0857e7f4..e1a9f00db186e687de442d551c63372281aeeb98 100644 (file)
@@ -33,6 +33,8 @@ class QueuePool(Pool):
 
     """
 
+    _queue_class = sqla_queue.Queue
+
     def __init__(
         self,
         creator,
@@ -95,7 +97,7 @@ class QueuePool(Pool):
 
         """
         Pool.__init__(self, creator, **kw)
-        self._pool = sqla_queue.Queue(pool_size, use_lifo=use_lifo)
+        self._pool = self._queue_class(pool_size, use_lifo=use_lifo)
         self._overflow = 0 - pool_size
         self._max_overflow = max_overflow
         self._timeout = timeout
@@ -215,6 +217,10 @@ class QueuePool(Pool):
         return self._pool.maxsize - self._pool.qsize() + self._overflow
 
 
+class AsyncAdaptedQueuePool(QueuePool):
+    _queue_class = sqla_queue.AsyncAdaptedQueue
+
+
 class NullPool(Pool):
 
     """A Pool which does not pool connections.
index 2fe6f35d2b138cd08eec62b3507b69ec5fb512ca..8f6dc8e72d99fadeb3b0a3684971bf7e81d7240f 100644 (file)
@@ -5,6 +5,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+from .base import Executable  # noqa
 from .compiler import COLLECT_CARTESIAN_PRODUCTS  # noqa
 from .compiler import FROM_LINTING  # noqa
 from .compiler import NO_LINTING  # noqa
index 186f885d8aac9c31ee3878047b1e745fada175c8..64663a6b01724075cd3024ac556cded85791f68b 100644 (file)
@@ -2316,6 +2316,22 @@ class JSON(Indexable, TypeEngine):
 
         """
 
+    class JSONIntIndexType(JSONIndexType):
+        """Placeholder for the datatype of a JSON index value.
+
+        This allows execution-time processing of JSON index values
+        for special syntaxes.
+
+        """
+
+    class JSONStrIndexType(JSONIndexType):
+        """Placeholder for the datatype of a JSON index value.
+
+        This allows execution-time processing of JSON index values
+        for special syntaxes.
+
+        """
+
     class JSONPathType(JSONElementType):
         """Placeholder type for JSON path operations.
 
@@ -2346,7 +2362,9 @@ class JSON(Indexable, TypeEngine):
                     index,
                     expr=self.expr,
                     operator=operators.json_getitem_op,
-                    bindparam_type=JSON.JSONIndexType,
+                    bindparam_type=JSON.JSONIntIndexType
+                    if isinstance(index, int)
+                    else JSON.JSONStrIndexType,
                 )
                 operator = operators.json_getitem_op
 
index 79b7f9eb3d6226ff6b43a71339dc789f2a1d97cd..9b1164874aab671c546b0bde2993576738440bbb 100644 (file)
@@ -12,7 +12,6 @@ from .assertions import assert_raises  # noqa
 from .assertions import assert_raises_context_ok  # noqa
 from .assertions import assert_raises_message  # noqa
 from .assertions import assert_raises_message_context_ok  # noqa
-from .assertions import assert_raises_return  # noqa
 from .assertions import AssertsCompiledSQL  # noqa
 from .assertions import AssertsExecutionResults  # noqa
 from .assertions import ComparesTables  # noqa
@@ -23,6 +22,8 @@ from .assertions import eq_ignore_whitespace  # noqa
 from .assertions import eq_regex  # noqa
 from .assertions import expect_deprecated  # noqa
 from .assertions import expect_deprecated_20  # noqa
+from .assertions import expect_raises  # noqa
+from .assertions import expect_raises_message  # noqa
 from .assertions import expect_warnings  # noqa
 from .assertions import in_  # noqa
 from .assertions import is_  # noqa
@@ -35,6 +36,7 @@ from .assertions import ne_  # noqa
 from .assertions import not_in_  # noqa
 from .assertions import startswith_  # noqa
 from .assertions import uses_deprecated  # noqa
+from .config import async_test  # noqa
 from .config import combinations  # noqa
 from .config import db  # noqa
 from .config import fixture  # noqa
index ecc6a4ab830045f2894ba7156041505450910cb0..fe74be8235feee67a24be13ed74c12c22cc05461 100644 (file)
@@ -298,10 +298,6 @@ def assert_raises_context_ok(except_cls, callable_, *args, **kw):
     return _assert_raises(except_cls, callable_, args, kw,)
 
 
-def assert_raises_return(except_cls, callable_, *args, **kw):
-    return _assert_raises(except_cls, callable_, args, kw, check_context=True)
-
-
 def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
     return _assert_raises(
         except_cls, callable_, args, kwargs, msg=msg, check_context=True
@@ -317,14 +313,26 @@ def assert_raises_message_context_ok(
 def _assert_raises(
     except_cls, callable_, args, kwargs, msg=None, check_context=False
 ):
-    ret_err = None
+
+    with _expect_raises(except_cls, msg, check_context) as ec:
+        callable_(*args, **kwargs)
+    return ec.error
+
+
+class _ErrorContainer(object):
+    error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+    ec = _ErrorContainer()
     if check_context:
         are_we_already_in_a_traceback = sys.exc_info()[0]
     try:
-        callable_(*args, **kwargs)
+        yield ec
         success = False
     except except_cls as err:
-        ret_err = err
+        ec.error = err
         success = True
         if msg is not None:
             assert re.search(
@@ -337,7 +345,13 @@ def _assert_raises(
     # assert outside the block so it works for AssertionError too !
     assert success, "Callable did not raise an exception"
 
-    return ret_err
+
+def expect_raises(except_cls):
+    return _expect_raises(except_cls, check_context=True)
+
+
+def expect_raises_message(except_cls, msg):
+    return _expect_raises(except_cls, msg=msg, check_context=True)
 
 
 class AssertsCompiledSQL(object):
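The new expect_raises helpers behave as context managers, along the lines of pytest.raises; a quick sketch:

    from sqlalchemy.testing import expect_raises_message

    with expect_raises_message(ValueError, "bad value"):
        raise ValueError("bad value")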
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
new file mode 100644 (file)
index 0000000..2e274de
--- /dev/null
@@ -0,0 +1,14 @@
+from .assertions import assert_raises as _assert_raises
+from .assertions import assert_raises_message as _assert_raises_message
+from ..util import await_fallback as await_
+from ..util import greenlet_spawn
+
+
+async def assert_raises_async(except_cls, coroutine):
+    await greenlet_spawn(_assert_raises, except_cls, await_, coroutine)
+
+
+async def assert_raises_message_async(except_cls, msg, coroutine):
+    await greenlet_spawn(
+        _assert_raises_message, except_cls, msg, await_, coroutine
+    )
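These helpers run the assertion machinery inside greenlet_spawn so that the coroutine under test can be driven by await_ on the sync side; a usage sketch:

    import asyncio
    from sqlalchemy.testing.asyncio import assert_raises_message_async

    async def boom():
        raise ValueError("boom happened")

    async def demo():
        await assert_raises_message_async(
            ValueError, "boom happened", boom()
        )

    asyncio.run(demo())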
index e97821d722950f0bcbf5dc4d07aecfabaadd799f..8c232f3198d5105d2613c65a2344232f3e37e485 100644 (file)
@@ -178,3 +178,7 @@ class Config(object):
 
 def skip_test(msg):
     raise _fixture_functions.skip_test_exception(msg)
+
+
+def async_test(fn):
+    return _fixture_functions.async_test(fn)
index 1583147d47dfcc2063c0b2a41664a4a38d2f6059..85d3374de182482558be4d69c6a8cd14000699d0 100644 (file)
@@ -61,6 +61,7 @@ class TestBase(object):
     @config.fixture()
     def connection(self):
         eng = getattr(self, "bind", config.db)
+
         conn = eng.connect()
         trans = conn.begin()
         try:
index b31a4ff3e378fdb4bbfc49758b08c33de4645ea2..49ff0f9757fb3705316a6bcd412a720db106159e 100644 (file)
@@ -48,7 +48,6 @@ testing = None
 util = None
 file_config = None
 
-
 logging = None
 include_tags = set()
 exclude_tags = set()
@@ -193,6 +192,12 @@ def setup_options(make_option):
         default=False,
         help="Unconditionally write/update profiling data.",
     )
+    make_option(
+        "--dump-pyannotate",
+        type=str,
+        dest="dump_pyannotate",
+        help="Run pyannotate and dump json info to given file",
+    )
 
 
 def configure_follower(follower_ident):
@@ -378,7 +383,6 @@ def _engine_uri(options, file_config):
         cfg = provision.setup_config(
             db_url, options, file_config, provision.FOLLOWER_IDENT
         )
-
         if not config._current:
             cfg.set_as_current(cfg, testing)
 
index 015598952db6b82c5055121db34593978fe18327..3df239afa950028f8a2766e2ed854d3f9c34d775 100644 (file)
@@ -25,6 +25,11 @@ else:
     if typing.TYPE_CHECKING:
         from typing import Sequence
 
+try:
+    import asyncio
+except ImportError:
+    pass
+
 try:
     import xdist  # noqa
 
@@ -101,6 +106,24 @@ def pytest_configure(config):
 
     plugin_base.set_fixture_functions(PytestFixtureFunctions)
 
+    if config.option.dump_pyannotate:
+        global DUMP_PYANNOTATE
+        DUMP_PYANNOTATE = True
+
+
+DUMP_PYANNOTATE = False
+
+
+@pytest.fixture(autouse=True)
+def collect_types_fixture():
+    if DUMP_PYANNOTATE:
+        from pyannotate_runtime import collect_types
+
+        collect_types.start()
+    yield
+    if DUMP_PYANNOTATE:
+        collect_types.stop()
+
 
 def pytest_sessionstart(session):
     plugin_base.post_begin()
@@ -109,6 +132,31 @@ def pytest_sessionstart(session):
 def pytest_sessionfinish(session):
     plugin_base.final_process_cleanup()
 
+    if session.config.option.dump_pyannotate:
+        from pyannotate_runtime import collect_types
+
+        collect_types.dump_stats(session.config.option.dump_pyannotate)
+
+
+def pytest_collection_finish(session):
+    if session.config.option.dump_pyannotate:
+        from pyannotate_runtime import collect_types
+
+        lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
+
+        def _filter(filename):
+            filename = os.path.normpath(os.path.abspath(filename))
+            if "lib/sqlalchemy" not in os.path.commonpath(
+                [filename, lib_sqlalchemy]
+            ):
+                return None
+            if "testing" in filename:
+                return None
+
+            return filename
+
+        collect_types.init_types_collection(filter_filename=_filter)
+
 
 if has_xdist:
     import uuid
@@ -518,3 +566,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
 
     def get_current_test_name(self):
         return os.environ.get("PYTEST_CURRENT_TEST")
+
+    def async_test(self, fn):
+        @_pytest_fn_decorator
+        def decorate(fn, *args, **kwargs):
+            asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs))
+
+        return decorate(fn)
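The decorator simply drives the wrapped coroutine to completion on the current event loop; within a test run it is used as below (this only works under SQLAlchemy's own pytest plugin, which installs the fixture functions backing the decorator):

    from sqlalchemy.testing import async_test, eq_

    @async_test
    async def test_addition():
        eq_(1 + 1, 2)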
index 25998c07bbea1c634ac0e155a282ceb00697e747..36d0ce4c61deb6c8967e7645e35cf6d49f9a75a5 100644 (file)
@@ -1193,6 +1193,12 @@ class SuiteRequirements(Requirements):
         except ImportError:
             return False
 
+    @property
+    def async_dialect(self):
+        """dialect makes use of await_() to invoke operations on the DBAPI."""
+
+        return exclusions.closed()
+
     @property
     def computed_columns(self):
         "Supports computed columns"
index 2eb986c74aaca7124d8aa1dfda199eb831dc3507..e6f6068c8933c4386c3a6a8cfa210fa0de356b0b 100644 (file)
@@ -238,6 +238,8 @@ class ServerSideCursorsTest(
         elif self.engine.dialect.driver == "mysqldb":
             sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
             return isinstance(cursor, sscursor)
+        elif self.engine.dialect.driver == "asyncpg":
+            return cursor.server_side
         else:
             return False
 
@@ -331,29 +333,74 @@ class ServerSideCursorsTest(
             result.close()
 
     @testing.provide_metadata
-    def test_roundtrip(self):
+    def test_roundtrip_fetchall(self):
         md = self.metadata
 
-        self._fixture(True)
+        engine = self._fixture(True)
         test_table = Table(
             "test_table",
             md,
             Column("id", Integer, primary_key=True),
             Column("data", String(50)),
         )
-        test_table.create(checkfirst=True)
-        test_table.insert().execute(data="data1")
-        test_table.insert().execute(data="data2")
-        eq_(
-            test_table.select().order_by(test_table.c.id).execute().fetchall(),
-            [(1, "data1"), (2, "data2")],
-        )
-        test_table.update().where(test_table.c.id == 2).values(
-            data=test_table.c.data + " updated"
-        ).execute()
-        eq_(
-            test_table.select().order_by(test_table.c.id).execute().fetchall(),
-            [(1, "data1"), (2, "data2 updated")],
+
+        with engine.connect() as connection:
+            test_table.create(connection, checkfirst=True)
+            connection.execute(test_table.insert(), dict(data="data1"))
+            connection.execute(test_table.insert(), dict(data="data2"))
+            eq_(
+                connection.execute(
+                    test_table.select().order_by(test_table.c.id)
+                ).fetchall(),
+                [(1, "data1"), (2, "data2")],
+            )
+            connection.execute(
+                test_table.update()
+                .where(test_table.c.id == 2)
+                .values(data=test_table.c.data + " updated")
+            )
+            eq_(
+                connection.execute(
+                    test_table.select().order_by(test_table.c.id)
+                ).fetchall(),
+                [(1, "data1"), (2, "data2 updated")],
+            )
+            connection.execute(test_table.delete())
+            eq_(
+                connection.scalar(
+                    select([func.count("*")]).select_from(test_table)
+                ),
+                0,
+            )
+
+    @testing.provide_metadata
+    def test_roundtrip_fetchmany(self):
+        md = self.metadata
+
+        engine = self._fixture(True)
+        test_table = Table(
+            "test_table",
+            md,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
         )
-        test_table.delete().execute()
-        eq_(select([func.count("*")]).select_from(test_table).scalar(), 0)
+
+        with engine.connect() as connection:
+            test_table.create(connection, checkfirst=True)
+            connection.execute(
+                test_table.insert(),
+                [dict(data="data%d" % i) for i in range(1, 20)],
+            )
+
+            result = connection.execute(
+                test_table.select().order_by(test_table.c.id)
+            )
+
+            eq_(
+                result.fetchmany(5), [(i, "data%d" % i) for i in range(1, 6)],
+            )
+            eq_(
+                result.fetchmany(10),
+                [(i, "data%d" % i) for i in range(6, 16)],
+            )
+            eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
index 48144f8859bb427f2e5faab9418b705f2232ecf4..5e6ac1eabd913b216326b10ad5649aa08d805aa9 100644 (file)
@@ -35,6 +35,7 @@ from ... import Text
 from ... import Time
 from ... import TIMESTAMP
 from ... import type_coerce
+from ... import TypeDecorator
 from ... import Unicode
 from ... import UnicodeText
 from ... import util
@@ -282,6 +283,9 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
 
     @classmethod
     def define_tables(cls, metadata):
+        class Decorated(TypeDecorator):
+            impl = cls.datatype
+
         Table(
             "date_table",
             metadata,
@@ -289,6 +293,7 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
             ),
             Column("date_data", cls.datatype),
+            Column("decorated_date_data", Decorated),
         )
 
     def test_round_trip(self, connection):
@@ -302,6 +307,21 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
         eq_(row, (compare,))
         assert isinstance(row[0], type(compare))
 
+    def test_round_trip_decorated(self, connection):
+        date_table = self.tables.date_table
+
+        connection.execute(
+            date_table.insert(), {"decorated_date_data": self.data}
+        )
+
+        row = connection.execute(
+            select(date_table.c.decorated_date_data)
+        ).first()
+
+        compare = self.compare or self.data
+        eq_(row, (compare,))
+        assert isinstance(row[0], type(compare))
+
     def test_null(self, connection):
         date_table = self.tables.date_table
 
@@ -526,6 +546,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
             Float(precision=8, asdecimal=True),
             [15.7563, decimal.Decimal("15.7563"), None],
             [decimal.Decimal("15.7563"), None],
+            filter_=lambda n: n is not None and round(n, 4) or None,
         )
 
     def test_float_as_float(self):
@@ -777,6 +798,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             #        ("json", {"foo": "bar"}),
             id_="sa",
         )(fn)
+
         return fn
 
     @_index_fixtures
@@ -1139,7 +1161,15 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6"
         )
 
-    def test_crit_against_string_coerce_type(self):
+    def test_crit_against_int_basic(self):
+        name = self.tables.data_table.c.name
+        col = self.tables.data_table.c["data"]
+
+        self._test_index_criteria(
+            and_(name == "r6", cast(col["a"], String) == "5"), "r6"
+        )
+
+    def _dont_test_crit_against_string_coerce_type(self):
         name = self.tables.data_table.c.name
         col = self.tables.data_table.c["data"]
 
@@ -1152,15 +1182,7 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             test_literal=False,
         )
 
-    def test_crit_against_int_basic(self):
-        name = self.tables.data_table.c.name
-        col = self.tables.data_table.c["data"]
-
-        self._test_index_criteria(
-            and_(name == "r6", cast(col["a"], String) == "5"), "r6"
-        )
-
-    def test_crit_against_int_coerce_type(self):
+    def _dont_test_crit_against_int_coerce_type(self):
         name = self.tables.data_table.c.name
         col = self.tables.data_table.c["data"]
 
index ce96027455261db603ec7cd1a47772e20e40ed4d..1e3eb9a29e5d5c8521ce5fb842190ae1b70ddcb9 100644 (file)
@@ -90,6 +90,10 @@ from .compat import unquote_plus  # noqa
 from .compat import win32  # noqa
 from .compat import with_metaclass  # noqa
 from .compat import zip_longest  # noqa
+from .concurrency import asyncio  # noqa
+from .concurrency import await_fallback  # noqa
+from .concurrency import await_only  # noqa
+from .concurrency import greenlet_spawn  # noqa
 from .deprecations import deprecated  # noqa
 from .deprecations import deprecated_20  # noqa
 from .deprecations import deprecated_20_cls  # noqa
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py
new file mode 100644 (file)
index 0000000..3b112ff
--- /dev/null
@@ -0,0 +1,110 @@
+import asyncio
+import sys
+from typing import Any
+from typing import Callable
+from typing import Coroutine
+
+from .. import exc
+
+try:
+    import greenlet
+
+    # implementation based on snaury gist at
+    # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
+    # Issue for context: https://github.com/python-greenlet/greenlet/issues/173
+
+    class _AsyncIoGreenlet(greenlet.greenlet):
+        def __init__(self, fn, driver):
+            greenlet.greenlet.__init__(self, fn, driver)
+            self.driver = driver
+
+    def await_only(awaitable: Coroutine) -> Any:
+        """Awaits an async function in a sync method.
+
+        The sync method must be inside a :func:`greenlet_spawn` context.
+        :func:`await_only` calls cannot be nested.
+
+        :param awaitable: The coroutine to call.
+
+        """
+        # this is called in the context greenlet while running fn
+        current = greenlet.getcurrent()
+        if not isinstance(current, _AsyncIoGreenlet):
+            raise exc.InvalidRequestError(
+                "greenlet_spawn has not been called; can't call await_() here."
+            )
+
+        # returns the control to the driver greenlet passing it
+        # a coroutine to run. Once the awaitable is done, the driver greenlet
+        # switches back to this greenlet with the result of awaitable that is
+        # then returned to the caller (or raised as error)
+        return current.driver.switch(awaitable)
+
+    def await_fallback(awaitable: Coroutine) -> Any:
+        """Awaits an async function in a sync method.
+
+        The sync method must be inside a :func:`greenlet_spawn` context;
+        if it is not, the awaitable is instead run on a new event loop,
+        provided one is not already running.
+
+        :param awaitable: The coroutine to call.
+
+        """
+        # this is called in the context greenlet while running fn
+        current = greenlet.getcurrent()
+        if not isinstance(current, _AsyncIoGreenlet):
+            loop = asyncio.get_event_loop()
+            if loop.is_running():
+                raise exc.InvalidRequestError(
+                    "greenlet_spawn has not been called and asyncio event "
+                    "loop is already running; can't call await_() here."
+                )
+            return loop.run_until_complete(awaitable)
+
+        return current.driver.switch(awaitable)
+
+    async def greenlet_spawn(fn: Callable, *args, **kwargs) -> Any:
+        """Runs a sync function ``fn`` in a new greenlet.
+
+        The sync function can then use :func:`await_` to wait for async
+        functions.
+
+        :param fn: The sync callable to call.
+        :param \\*args: Positional arguments to pass to the ``fn`` callable.
+        :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
+        """
+        context = _AsyncIoGreenlet(fn, greenlet.getcurrent())
+        # runs the function synchronously in the "context" greenlet.
+        # If the execution is interrupted by await_, the context greenlet
+        # is not dead and result is a coroutine to wait for.  If the
+        # context greenlet is dead, the function has returned, and its
+        # result can be returned.
+        try:
+            result = context.switch(*args, **kwargs)
+            while not context.dead:
+                try:
+                    # wait for a coroutine from await_ and then return its
+                    # result back to it.
+                    value = await result
+                except Exception:
+                    # this allows an exception to be raised within
+                    # the moderated greenlet so that it can continue
+                    # its expected flow.
+                    result = context.throw(*sys.exc_info())
+                else:
+                    result = context.switch(value)
+        finally:
+            # clean up to avoid cycle resolution by gc
+            del context.driver
+        return result
+
+
+except ImportError:  # pragma: no cover
+    greenlet = None
+
+    def await_fallback(awaitable):
+        return asyncio.get_event_loop().run_until_complete(awaitable)
+
+    def await_only(awaitable):
+        raise ValueError("Greenlet is required to use this function")
+
+    async def greenlet_spawn(fn, *args, **kw):
+        raise ValueError("Greenlet is required to use this function")
diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py
new file mode 100644 (file)
index 0000000..4c4ea20
--- /dev/null
@@ -0,0 +1,21 @@
+from . import compat
+
+
+if compat.py3k:
+    import asyncio
+    from ._concurrency_py3k import await_only
+    from ._concurrency_py3k import await_fallback
+    from ._concurrency_py3k import greenlet
+    from ._concurrency_py3k import greenlet_spawn
+else:
+    asyncio = None
+    greenlet = None
+
+    def await_only(thing):
+        return thing
+
+    def await_fallback(thing):
+        return thing
+
+    def greenlet_spawn(fn, *args, **kw):
+        raise ValueError("Cannot use this function in py2.")
index 3433657d6bc3eb3a24415488432b5b7a209c0c81..5f71c7bd6f29fd35ba55776ed86ae5ef8b0a41e8 100644 (file)
@@ -21,7 +21,10 @@ condition.
 from collections import deque
 from time import time as _time
 
+from . import compat
 from .compat import threading
+from .concurrency import asyncio
+from .concurrency import await_fallback
 
 
 __all__ = ["Empty", "Full", "Queue"]
@@ -196,3 +199,64 @@ class Queue:
         else:
             # FIFO
             return self.queue.popleft()
+
+
+class AsyncAdaptedQueue:
+    # staticmethod() prevents the plain function from binding as an
+    # instance method when looked up via self.await_()
+    await_ = staticmethod(await_fallback)
+
+    def __init__(self, maxsize=0, use_lifo=False):
+        if use_lifo:
+            self._queue = asyncio.LifoQueue(maxsize=maxsize)
+        else:
+            self._queue = asyncio.Queue(maxsize=maxsize)
+        self.maxsize = maxsize
+        self.empty = self._queue.empty
+        self.full = self._queue.full
+        self.qsize = self._queue.qsize
+
+    def put_nowait(self, item):
+        try:
+            return self._queue.put_nowait(item)
+        except asyncio.queues.QueueFull as err:
+            compat.raise_(
+                Full(), replace_context=err,
+            )
+
+    def put(self, item, block=True, timeout=None):
+        if not block:
+            return self.put_nowait(item)
+
+        try:
+            if timeout:
+                return self.await_(
+                    asyncio.wait_for(self._queue.put(item), timeout)
+                )
+            else:
+                return self.await_(self._queue.put(item))
+        except asyncio.queues.QueueFull as err:
+            compat.raise_(
+                Full(), replace_context=err,
+            )
+
+    def get_nowait(self):
+        try:
+            return self._queue.get_nowait()
+        except asyncio.queues.QueueEmpty as err:
+            compat.raise_(
+                Empty(), replace_context=err,
+            )
+
+    def get(self, block=True, timeout=None):
+        if not block:
+            return self.get_nowait()
+        try:
+            if timeout:
+                return self.await_(
+                    asyncio.wait_for(self._queue.get(), timeout)
+                )
+            else:
+                return self.await_(self._queue.get())
+        except asyncio.queues.QueueEmpty as err:
+            compat.raise_(
+                Empty(), replace_context=err,
+            )
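A direct sketch of the adapted queue, exercised through greenlet_spawn so that its internal await_ calls have an event loop to run on:

    import asyncio
    from sqlalchemy.util import greenlet_spawn
    from sqlalchemy.util.queue import AsyncAdaptedQueue

    def use_queue():
        q = AsyncAdaptedQueue(maxsize=2)
        q.put("work")          # awaits asyncio.Queue.put via await_
        return q.get()         # awaits asyncio.Queue.get via await_

    print(asyncio.run(greenlet_spawn(use_queue)))  # "work"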
index 9cbdbd838db4b89cf8b972abe7bf6d6685602efe..387f422efd30c26ca503bb50b831627291c1545b 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -38,10 +38,12 @@ python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*
 package_dir =
     =lib
 install_requires =
-  importlib-metadata;python_version<"3.8"
-
+    importlib-metadata;python_version<"3.8"
+    greenlet
 
 [options.extras_require]
+asyncio =
+    greenlet
 mssql = pyodbc
 mssql_pymssql = pymssql
 mssql_pyodbc = pyodbc
@@ -53,6 +55,9 @@ oracle =
     cx_oracle>=7;python_version>="3"
 postgresql = psycopg2>=2.7
 postgresql_pg8000 = pg8000
+postgresql_asyncpg =
+    asyncpg;python_version>="3"
+    greenlet
 postgresql_psycopg2binary = psycopg2-binary
 postgresql_psycopg2cffi = psycopg2cffi
 pymysql = pymysql
@@ -110,6 +115,7 @@ default = sqlite:///:memory:
 sqlite = sqlite:///:memory:
 sqlite_file = sqlite:///querytest.db
 postgresql = postgresql://scott:tiger@127.0.0.1:5432/test
+asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
 pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
 postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test
 mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py
new file mode 100644 (file)
index 0000000..10b8929
--- /dev/null
@@ -0,0 +1,103 @@
+from sqlalchemy import exc
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.util import await_fallback
+from sqlalchemy.util import await_only
+from sqlalchemy.util import greenlet_spawn
+
+
+async def run1():
+    return 1
+
+
+async def run2():
+    return 2
+
+
+def go(*fns):
+    return sum(await_only(fn()) for fn in fns)
+
+
+class TestAsyncioCompat(fixtures.TestBase):
+    @async_test
+    async def test_ok(self):
+
+        eq_(await greenlet_spawn(go, run1, run2), 3)
+
+    @async_test
+    async def test_async_error(self):
+        async def err():
+            raise ValueError("an error")
+
+        with expect_raises_message(ValueError, "an error"):
+            await greenlet_spawn(go, run1, err)
+
+    @async_test
+    async def test_sync_error(self):
+        def go():
+            await_only(run1())
+            raise ValueError("sync error")
+
+        with expect_raises_message(ValueError, "sync error"):
+            await greenlet_spawn(go)
+
+    def test_await_fallback_no_greenlet(self):
+        to_await = run1()
+        await_fallback(to_await)
+
+    def test_await_only_no_greenlet(self):
+        to_await = run1()
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"greenlet_spawn has not been called; can't call await_\(\) here.",
+        ):
+            await_only(to_await)
+
+        # ensure no warning
+        await_fallback(to_await)
+
+    @async_test
+    async def test_await_fallback_error(self):
+        to_await = run1()
+
+        await to_await
+
+        async def inner_await():
+            nonlocal to_await
+            to_await = run1()
+            await_fallback(to_await)
+
+        def go():
+            await_fallback(inner_await())
+
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            "greenlet_spawn has not been called and asyncio event loop",
+        ):
+            await greenlet_spawn(go)
+
+        await to_await
+
+    @async_test
+    async def test_await_only_error(self):
+        to_await = run1()
+
+        await to_await
+
+        async def inner_await():
+            nonlocal to_await
+            to_await = run1()
+            await_only(to_await)
+
+        def go():
+            await_only(inner_await())
+
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"greenlet_spawn has not been called; can't call await_\(\) here.",
+        ):
+            await greenlet_spawn(go)
+
+        await to_await
index 5c6b89fde7ce98eabdc63843ac0d8b599c9aa7bc..92d3e07768572a1d15233dfc28ca28bda534c7c4 100755 (executable)
@@ -11,6 +11,11 @@ import sys
 
 import pytest
 
+
+collect_ignore_glob = []
+if sys.version_info[0] < 3:
+    collect_ignore_glob.append("*_py3k.py")
+
 pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
 
 
index f6aba550ecc06b82e321c0c76cfd26324de43a24..57c243442e619ace82a479fe900bef146d35be67 100644 (file)
@@ -937,9 +937,7 @@ $$ LANGUAGE plpgsql;
         stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric)
         assert_raises(exc.InvalidRequestError, connection.execute, stmt)
 
-    @testing.only_if(
-        "postgresql >= 8.2", "requires standard_conforming_strings"
-    )
+    @testing.only_on("postgresql+psycopg2")
     def test_serial_integer(self):
         class BITD(TypeDecorator):
             impl = Integer
index ffd32813c0cf73509c7b8450e337aa1a97437c84..5ab65f9e34ec4dfe4c84d5f14c3c1cd0e1074322 100644 (file)
@@ -738,17 +738,14 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
     def teardown_class(cls):
         metadata.drop_all()
 
-    @testing.fails_on("postgresql+pg8000", "uses positional")
+    @testing.requires.pyformat_paramstyle
     def test_expression_pyformat(self):
         self.assert_compile(
             matchtable.c.title.match("somstr"),
             "matchtable.title @@ to_tsquery(%(title_1)s" ")",
         )
 
-    @testing.fails_on("postgresql+psycopg2", "uses pyformat")
-    @testing.fails_on("postgresql+pypostgresql", "uses pyformat")
-    @testing.fails_on("postgresql+pygresql", "uses pyformat")
-    @testing.fails_on("postgresql+psycopg2cffi", "uses pyformat")
+    @testing.requires.format_paramstyle
     def test_expression_positional(self):
         self.assert_compile(
             matchtable.c.title.match("somstr"),
index 95486b19799cab8997c10dba4ae13768ca3f05f6..503477833d2b9704daf819e3172d63d0e9bd59fe 100644 (file)
@@ -27,6 +27,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import Text
 from sqlalchemy import text
+from sqlalchemy import type_coerce
 from sqlalchemy import TypeDecorator
 from sqlalchemy import types
 from sqlalchemy import Unicode
@@ -774,7 +775,12 @@ class RegClassTest(fixtures.TestBase):
         regclass = cast("pg_class", postgresql.REGCLASS)
         oid = self._scalar(cast(regclass, postgresql.OID))
         assert isinstance(oid, int)
-        eq_(self._scalar(cast(oid, postgresql.REGCLASS)), "pg_class")
+        eq_(
+            self._scalar(
+                cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS)
+            ),
+            "pg_class",
+        )
 
     def test_cast_whereclause(self):
         pga = Table(
@@ -1801,10 +1807,13 @@ class ArrayEnum(fixtures.TestBase):
             testing.db,
         )
 
+    @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
     @testing.combinations(
-        sqltypes.ARRAY, postgresql.ARRAY, _ArrayOfEnum, argnames="array_cls"
+        sqltypes.ARRAY,
+        postgresql.ARRAY,
+        (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+        argnames="array_cls",
     )
-    @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
     @testing.provide_metadata
     def test_array_of_enums(self, array_cls, enum_cls, connection):
         tbl = Table(
@@ -1845,6 +1854,8 @@ class ArrayEnum(fixtures.TestBase):
             sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
             eq_(connection.scalar(sel), [MyEnum.a])
 
+        self.metadata.drop_all(connection)
+
 
 class ArrayJSON(fixtures.TestBase):
     __backend__ = True
index af6bc1d369a4d7f2740684f9fc8e5cdf281e32e0..624fa90053d7f68be7d17d358f6aba6eda91832a 100644 (file)
@@ -10,8 +10,8 @@ from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import util
 from sqlalchemy.sql import util as sql_util
+from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import assert_raises_return
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import eq_regex
@@ -104,7 +104,7 @@ class LogParamsTest(fixtures.TestBase):
 
     def test_log_positional_array(self):
         with self.eng.connect() as conn:
-            exc_info = assert_raises_return(
+            exc_info = assert_raises(
                 tsa.exc.DBAPIError,
                 conn.execute,
                 tsa.text("SELECT * FROM foo WHERE id IN :foo AND bar=:bar"),
index d91105f41100ecb01c812ee147c895e36c40c81d..48eb485cb7a76f8e2d632af137284c953c049f22 100644 (file)
@@ -1356,7 +1356,14 @@ class InvalidateDuringResultTest(fixtures.TestBase):
         "cx_oracle 6 doesn't allow a close like this due to open cursors",
     )
     @testing.fails_if(
-        ["+mysqlconnector", "+mysqldb", "+cymysql", "+pymysql", "+pg8000"],
+        [
+            "+mysqlconnector",
+            "+mysqldb",
+            "+cymysql",
+            "+pymysql",
+            "+pg8000",
+            "+asyncpg",
+        ],
         "Buffers the result set and doesn't check for connection close",
     )
     def test_invalidate_on_results(self):
@@ -1365,5 +1372,8 @@ class InvalidateDuringResultTest(fixtures.TestBase):
         for x in range(20):
             result.fetchone()
         self.engine.test_shutdown()
-        _assert_invalidated(result.fetchone)
-        assert conn.invalidated
+        try:
+            _assert_invalidated(result.fetchone)
+            assert conn.invalidated
+        finally:
+            conn.invalidate()
index 8981028d2cac797f298976937056db5bd61b926e..cd144e45f49f5a7a11326463ed78168fbdb13e33 100644 (file)
@@ -461,11 +461,8 @@ class TransactionTest(fixtures.TestBase):
         assert not savepoint.is_active
 
         if util.py3k:
-            # driver error
-            assert exc_.__cause__
-
-            # and that's it, no other context
-            assert not exc_.__cause__.__context__
+            # ensure cause comes from the DBAPI
+            assert isinstance(exc_.__cause__, testing.db.dialect.dbapi.Error)
 
     def test_retains_through_options(self, local_connection):
         connection = local_connection
diff --git a/test/ext/asyncio/__init__.py b/test/ext/asyncio/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
new file mode 100644 (file)
index 0000000..ec513cb
--- /dev/null
@@ -0,0 +1,340 @@
+from sqlalchemy import Column
+from sqlalchemy import delete
+from sqlalchemy import exc
+from sqlalchemy import func
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import union_all
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import exc as asyncio_exc
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.asyncio import assert_raises_message_async
+
+
+class EngineFixture(fixtures.TablesTest):
+    __requires__ = ("async_dialect",)
+
+    @testing.fixture
+    def async_engine(self):
+        return create_async_engine(testing.db.url)
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "users",
+            metadata,
+            Column("user_id", Integer, primary_key=True, autoincrement=False),
+            Column("user_name", String(20)),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        users = cls.tables.users
+        with connection.begin():
+            connection.execute(
+                users.insert(),
+                [
+                    {"user_id": i, "user_name": "name%d" % i}
+                    for i in range(1, 20)
+                ],
+            )
+
+
+class AsyncEngineTest(EngineFixture):
+    __backend__ = True
+
+    @async_test
+    async def test_connect_ctxmanager(self, async_engine):
+        async with async_engine.connect() as conn:
+            result = await conn.execute(select(1))
+            eq_(result.scalar(), 1)
+
+    @async_test
+    async def test_connect_plain(self, async_engine):
+        conn = await async_engine.connect()
+        try:
+            result = await conn.execute(select(1))
+            eq_(result.scalar(), 1)
+        finally:
+            await conn.close()
+
+    @async_test
+    async def test_connection_not_started(self, async_engine):
+
+        conn = async_engine.connect()
+        testing.assert_raises_message(
+            asyncio_exc.AsyncContextNotStarted,
+            "AsyncConnection context has not been started and "
+            "object has not been awaited.",
+            conn.begin,
+        )
+
+    @async_test
+    async def test_transaction_commit(self, async_engine):
+        users = self.tables.users
+
+        async with async_engine.begin() as conn:
+            await conn.execute(delete(users))
+
+        async with async_engine.connect() as conn:
+            eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)
+
+    @async_test
+    async def test_savepoint_rollback_noctx(self, async_engine):
+        users = self.tables.users
+
+        async with async_engine.begin() as conn:
+
+            savepoint = await conn.begin_nested()
+            await conn.execute(delete(users))
+            await savepoint.rollback()
+
+        async with async_engine.connect() as conn:
+            eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)
+
+    @async_test
+    async def test_savepoint_commit_noctx(self, async_engine):
+        users = self.tables.users
+
+        async with async_engine.begin() as conn:
+
+            savepoint = await conn.begin_nested()
+            await conn.execute(delete(users))
+            await savepoint.commit()
+
+        async with async_engine.connect() as conn:
+            eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)
+
+    @async_test
+    async def test_transaction_rollback(self, async_engine):
+        users = self.tables.users
+
+        async with async_engine.connect() as conn:
+            trans = conn.begin()
+            await trans.start()
+            await conn.execute(delete(users))
+            await trans.rollback()
+
+        async with async_engine.connect() as conn:
+            eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)
+
+    @async_test
+    async def test_conn_transaction_not_started(self, async_engine):
+
+        async with async_engine.connect() as conn:
+            trans = conn.begin()
+            await assert_raises_message_async(
+                asyncio_exc.AsyncContextNotStarted,
+                "AsyncTransaction context has not been started "
+                "and object has not been awaited.",
+                trans.rollback(),
+            )
+
+
+class AsyncResultTest(EngineFixture):
+    @testing.combinations(
+        (None,), ("scalars",), ("mappings",), argnames="filter_"
+    )
+    @async_test
+    async def test_all(self, async_engine, filter_):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(select(users))
+
+            if filter_ == "mappings":
+                result = result.mappings()
+            elif filter_ == "scalars":
+                result = result.scalars(1)
+
+            all_ = await result.all()
+            if filter_ == "mappings":
+                eq_(
+                    all_,
+                    [
+                        {"user_id": i, "user_name": "name%d" % i}
+                        for i in range(1, 20)
+                    ],
+                )
+            elif filter_ == "scalars":
+                eq_(
+                    all_, ["name%d" % i for i in range(1, 20)],
+                )
+            else:
+                eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
+
+    @testing.combinations(
+        (None,), ("scalars",), ("mappings",), argnames="filter_"
+    )
+    @async_test
+    async def test_aiter(self, async_engine, filter_):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(select(users))
+
+            if filter_ == "mappings":
+                result = result.mappings()
+            elif filter_ == "scalars":
+                result = result.scalars(1)
+
+            rows = []
+
+            async for row in result:
+                rows.append(row)
+
+            if filter_ == "mappings":
+                eq_(
+                    rows,
+                    [
+                        {"user_id": i, "user_name": "name%d" % i}
+                        for i in range(1, 20)
+                    ],
+                )
+            elif filter_ == "scalars":
+                eq_(
+                    rows, ["name%d" % i for i in range(1, 20)],
+                )
+            else:
+                eq_(rows, [(i, "name%d" % i) for i in range(1, 20)])
+
+    @testing.combinations((None,), ("mappings",), argnames="filter_")
+    @async_test
+    async def test_keys(self, async_engine, filter_):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(select(users))
+
+            if filter_ == "mappings":
+                result = result.mappings()
+
+            eq_(result.keys(), ["user_id", "user_name"])
+
+    @async_test
+    async def test_unique_all(self, async_engine):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(
+                union_all(select(users), select(users)).order_by(
+                    users.c.user_id
+                )
+            )
+
+            all_ = await result.unique().all()
+            eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
+
+    @async_test
+    async def test_columns_all(self, async_engine):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(select(users))
+
+            all_ = await result.columns(1).all()
+            eq_(all_, [("name%d" % i,) for i in range(1, 20)])
+
+    @testing.combinations(
+        (None,), ("scalars",), ("mappings",), argnames="filter_"
+    )
+    @async_test
+    async def test_partitions(self, async_engine, filter_):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(select(users))
+
+            if filter_ == "mappings":
+                result = result.mappings()
+            elif filter_ == "scalars":
+                result = result.scalars(1)
+
+            check_result = []
+            async for partition in result.partitions(5):
+                check_result.append(partition)
+
+            if filter_ == "mappings":
+                eq_(
+                    check_result,
+                    [
+                        [
+                            {"user_id": i, "user_name": "name%d" % i}
+                            for i in range(a, b)
+                        ]
+                        for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)]
+                    ],
+                )
+            elif filter_ == "scalars":
+                eq_(
+                    check_result,
+                    [
+                        ["name%d" % i for i in range(a, b)]
+                        for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)]
+                    ],
+                )
+            else:
+                eq_(
+                    check_result,
+                    [
+                        [(i, "name%d" % i) for i in range(a, b)]
+                        for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)]
+                    ],
+                )
+
+    @testing.combinations(
+        (None,), ("scalars",), ("mappings",), argnames="filter_"
+    )
+    @async_test
+    async def test_one_success(self, async_engine, filter_):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(
+                select(users).limit(1).order_by(users.c.user_name)
+            )
+
+            if filter_ == "mappings":
+                result = result.mappings()
+            elif filter_ == "scalars":
+                result = result.scalars()
+            u1 = await result.one()
+
+            if filter_ == "mappings":
+                eq_(u1, {"user_id": 1, "user_name": "name%d" % 1})
+            elif filter_ == "scalars":
+                eq_(u1, 1)
+            else:
+                eq_(u1, (1, "name%d" % 1))
+
+    @async_test
+    async def test_one_no_result(self, async_engine):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(
+                select(users).where(users.c.user_name == "nonexistent")
+            )
+
+            async def go():
+                await result.one()
+
+            await assert_raises_message_async(
+                exc.NoResultFound,
+                "No row was found when one was required",
+                go(),
+            )
+
+    @async_test
+    async def test_one_multi_result(self, async_engine):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            result = await conn.stream(
+                select(users).where(users.c.user_name.in_(["name3", "name5"]))
+            )
+
+            async def go():
+                await result.one()
+
+            await assert_raises_message_async(
+                exc.MultipleResultsFound,
+                "Multiple rows were found when exactly one was required",
+                go(),
+            )
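
A minimal usage sketch of the streaming API exercised by the engine tests
above; the table definition and connection URL are illustrative placeholders,
not part of this patch, and the users table is assumed to already exist:

    import asyncio

    from sqlalchemy import Column, Integer, MetaData, String, Table, select
    from sqlalchemy.ext.asyncio import create_async_engine

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("user_id", Integer, primary_key=True),
        Column("user_name", String(30)),
    )

    async def main():
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )
        async with engine.connect() as conn:
            # stream() returns an AsyncResult that buffers rows from a
            # server-side cursor instead of fetching them all at once
            result = await conn.stream(
                select(users).order_by(users.c.user_id)
            )
            # iterate in batches of five, as in test_partitions above
            async for partition in result.partitions(5):
                print(partition)

    asyncio.run(main())
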
diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py
new file mode 100644 (file)
index 0000000..e8caaca
--- /dev/null
@@ -0,0 +1,200 @@
+from sqlalchemy import exc
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import selectinload
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import is_
+from ...orm import _fixtures
+
+
+class AsyncFixture(_fixtures.FixtureTest):
+    __requires__ = ("async_dialect",)
+
+    @classmethod
+    def setup_mappers(cls):
+        cls._setup_stock_mapping()
+
+    @testing.fixture
+    def async_engine(self):
+        return create_async_engine(testing.db.url)
+
+    @testing.fixture
+    def async_session(self, async_engine):
+        return AsyncSession(async_engine)
+
+
+class AsyncSessionTest(AsyncFixture):
+    def test_requires_async_engine(self, async_engine):
+        testing.assert_raises_message(
+            exc.ArgumentError,
+            "AsyncEngine expected, got Engine",
+            AsyncSession,
+            bind=async_engine.sync_engine,
+        )
+
+
+class AsyncSessionQueryTest(AsyncFixture):
+    @async_test
+    async def test_execute(self, async_session):
+        User = self.classes.User
+
+        stmt = (
+            select(User)
+            .options(selectinload(User.addresses))
+            .order_by(User.id)
+        )
+
+        result = await async_session.execute(stmt)
+        eq_(result.scalars().all(), self.static.user_address_result)
+
+    @async_test
+    async def test_stream_partitions(self, async_session):
+        User = self.classes.User
+
+        stmt = (
+            select(User)
+            .options(selectinload(User.addresses))
+            .order_by(User.id)
+        )
+
+        result = await async_session.stream(stmt)
+
+        assert_result = []
+        async for partition in result.scalars().partitions(3):
+            assert_result.append(partition)
+
+        eq_(
+            assert_result,
+            [
+                self.static.user_address_result[0:3],
+                self.static.user_address_result[3:],
+            ],
+        )
+
+
+class AsyncSessionTransactionTest(AsyncFixture):
+    run_inserts = None
+
+    @async_test
+    async def test_trans(self, async_session, async_engine):
+        async with async_engine.connect() as outer_conn:
+
+            User = self.classes.User
+
+            async with async_session.begin():
+
+                eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
+
+                u1 = User(name="u1")
+
+                async_session.add(u1)
+
+                result = await async_session.execute(select(User))
+                eq_(result.scalar(), u1)
+
+            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
+
+    @async_test
+    async def test_commit_as_you_go(self, async_session, async_engine):
+        async with async_engine.connect() as outer_conn:
+
+            User = self.classes.User
+
+            eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
+
+            u1 = User(name="u1")
+
+            async_session.add(u1)
+
+            result = await async_session.execute(select(User))
+            eq_(result.scalar(), u1)
+
+            await async_session.commit()
+
+            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
+
+    @async_test
+    async def test_trans_noctx(self, async_session, async_engine):
+        async with async_engine.connect() as outer_conn:
+
+            User = self.classes.User
+
+            trans = await async_session.begin()
+            try:
+                eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
+
+                u1 = User(name="u1")
+
+                async_session.add(u1)
+
+                result = await async_session.execute(select(User))
+                eq_(result.scalar(), u1)
+            finally:
+                await trans.commit()
+
+            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
+
+    @async_test
+    async def test_flush(self, async_session):
+        User = self.classes.User
+
+        async with async_session.begin():
+            u1 = User(name="u1")
+
+            async_session.add(u1)
+
+            conn = await async_session.connection()
+
+            eq_(await conn.scalar(select(func.count(User.id))), 0)
+
+            await async_session.flush()
+
+            eq_(await conn.scalar(select(func.count(User.id))), 1)
+
+    @async_test
+    async def test_refresh(self, async_session):
+        User = self.classes.User
+
+        async with async_session.begin():
+            u1 = User(name="u1")
+
+            async_session.add(u1)
+            await async_session.flush()
+
+            conn = await async_session.connection()
+
+            await conn.execute(
+                update(User)
+                .values(name="u2")
+                .execution_options(synchronize_session=None)
+            )
+
+            eq_(u1.name, "u1")
+
+            await async_session.refresh(u1)
+
+            eq_(u1.name, "u2")
+
+            eq_(await conn.scalar(select(func.count(User.id))), 1)
+
+    @async_test
+    async def test_merge(self, async_session):
+        User = self.classes.User
+
+        async with async_session.begin():
+            u1 = User(id=1, name="u1")
+
+            async_session.add(u1)
+
+        async with async_session.begin():
+            new_u = User(id=1, name="new u1")
+
+            new_u_merged = await async_session.merge(new_u)
+
+            is_(new_u_merged, u1)
+            eq_(u1.name, "new u1")
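
A corresponding sketch of the AsyncSession patterns covered above; the mapping
and connection URL are illustrative placeholders, not part of this patch, and
the backing table is assumed to already exist:

    import asyncio

    from sqlalchemy import Column, Integer, String, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(30))

    async def main():
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )
        session = AsyncSession(engine)

        # "begin once" block: the session flushes and commits on exit,
        # as in test_trans above
        async with session.begin():
            session.add(User(name="u1"))

        # commit-as-you-go, as in test_commit_as_you_go above
        result = await session.execute(select(User))
        print(result.scalars().all())
        await session.commit()

    asyncio.run(main())
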
index 75dca1c99075842db3c39e805cabc3817e1c2a69..047ef25aee5b8379d9aa64f34032ba49166e7231 100644 (file)
@@ -1421,6 +1421,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
         # this would work with Firebird if you do literal_column('1')
         # instead
         case_stmt = case([(Document.title.in_(subq), True)], else_=False)
+
         s.query(Document).update(
             {"flag": case_stmt}, synchronize_session=False
         )
index 28f955fa5fd6dd20a1cc1fc6de43aba42c014124..fdb7c2ff338a03d0ccdc92af9e63f32be9fb9445 100644 (file)
@@ -198,7 +198,7 @@ class DefaultRequirements(SuiteRequirements):
                 "mysql+pymysql",
                 "mysql+cymysql",
                 "mysql+mysqlconnector",
-                "postgresql",
+                "postgresql+pg8000",
             ]
         )
 
@@ -1162,20 +1162,6 @@ class DefaultRequirements(SuiteRequirements):
                     "Firebird still has FP inaccuracy even "
                     "with only four decimal places",
                 ),
-                (
-                    "mssql+pyodbc",
-                    None,
-                    None,
-                    "mssql+pyodbc has FP inaccuracy even with "
-                    "only four decimal places ",
-                ),
-                (
-                    "mssql+pymssql",
-                    None,
-                    None,
-                    "mssql+pymssql has FP inaccuracy even with "
-                    "only four decimal places ",
-                ),
                 (
                     "postgresql+pg8000",
                     None,
@@ -1280,6 +1266,12 @@ class DefaultRequirements(SuiteRequirements):
 
         return only_if(check_range_types)
 
+    @property
+    def async_dialect(self):
+        """dialect makes use of await_() to invoke operations on the DBAPI."""
+
+        return only_on(["postgresql+asyncpg"])
+
     @property
     def oracle_test_dblink(self):
         return skip_if(
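
The new async_dialect requirement gates tests to the asyncpg dialect; a test
class opts in the same way the async fixtures above do (the class name here is
hypothetical):

    from sqlalchemy.testing import fixtures

    class RoundTripTest(fixtures.TestBase):
        # skipped unless the configured database is postgresql+asyncpg
        __requires__ = ("async_dialect",)
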
index 676c46db65fb0ad98f243114c3b68b20fa215794..aa1c0d48d722cf78f2a131ecccf2d80ffd9aa9cf 100644 (file)
@@ -948,7 +948,7 @@ class PKDefaultTest(fixtures.TablesTest):
             metadata,
             Column(
                 "date_id",
-                DateTime,
+                DateTime(timezone=True),
                 default=text("current_timestamp"),
                 primary_key=True,
             ),
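
The widened column type agrees with what the server default produces: on
PostgreSQL, current_timestamp returns timestamp with time zone, and asyncpg
handles the aware/naive datetime distinction more strictly than the other
PostgreSQL drivers, so the test column is declared tz-aware to match:

    from sqlalchemy import Column, DateTime, text

    # the tz-aware type matches the timestamptz value that
    # current_timestamp produces on PostgreSQL
    Column(
        "date_id",
        DateTime(timezone=True),
        default=text("current_timestamp"),
        primary_key=True,
    )
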
diff --git a/tox.ini b/tox.ini
index 0ce79d7a1a1edf11134344da7abb4ab764fb96f7..e3539ce6112b4703348aa1cc9e8a4696f234f411 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -17,9 +17,11 @@ usedevelop=
 
 deps=pytest!=3.9.1,!=3.9.2
      pytest-xdist
+     greenlet
      mock; python_version < '3.3'
      importlib_metadata; python_version < '3.8'
      postgresql: .[postgresql]
+     postgresql: .[postgresql_asyncpg]
      mysql: .[mysql]
      mysql: .[pymysql]
      oracle: .[oracle]
@@ -56,7 +58,7 @@ setenv=
     cov: COVERAGE={[testenv]cov_args}
     sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
-    postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
+    postgresql: POSTGRESQL={env:TOX_POSTGRESQL_W_ASYNCPG:--db postgresql}
     mysql: MYSQL={env:TOX_MYSQL:--db mysql --db pymysql}
     oracle: ORACLE={env:TOX_ORACLE:--db oracle}
     mssql: MSSQL={env:TOX_MSSQL:--db mssql}
@@ -68,7 +70,7 @@ setenv=
 # tox as of 2.0 blocks all environment variables from the
 # outside, unless they are here (or in TOX_TESTENV_PASSENV,
 # wildcards OK).  Need at least these
-passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS
+passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL_W_ASYNCPG TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS
 
 # for nocext, we rm *.so in lib in case we are doing usedevelop=True
 commands=
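
With these changes the postgresql tox factor also installs the asyncpg extra,
and the renamed TOX_POSTGRESQL_W_ASYNCPG variable can point the run at an
asyncpg URL; a hypothetical local invocation (env name and credentials are
placeholders) would look like
TOX_POSTGRESQL_W_ASYNCPG="--dburi postgresql+asyncpg://scott:tiger@localhost/test" tox -e py37-postgresql.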