]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Detect non compatible execution in async mode
authorFederico Caselli <cfederico87@gmail.com>
Thu, 3 Dec 2020 22:53:47 +0000 (23:53 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 8 Dec 2020 14:41:23 +0000 (15:41 +0100)
The SQLAlchemy async mode now detects and raises an informative
error when an non asyncio compatible :term:`DBAPI` is used.
Using a standard ``DBAPI`` with async SQLAlchemy will cause
it to block like any sync call, interrupting the executing asyncio
loop.

Change-Id: I9aed87dc1b0df53e8cb2109495237038aa2cb2d4

doc/build/changelog/unreleased_14/async_dbapi_detection.rst [new file with mode: 0644]
doc/build/errors.rst
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/testing/asyncio.py [deleted file]
lib/sqlalchemy/util/_concurrency_py3k.py
test/base/test_concurrency_py3k.py
test/ext/asyncio/test_engine_py3k.py

diff --git a/doc/build/changelog/unreleased_14/async_dbapi_detection.rst b/doc/build/changelog/unreleased_14/async_dbapi_detection.rst
new file mode 100644 (file)
index 0000000..4764f7a
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: asyncio
+
+    The SQLAlchemy async mode now detects and raises an informative
+    error when an non asyncio compatible :term:`DBAPI` is used.
+    Using a standard ``DBAPI`` with async SQLAlchemy will cause
+    it to block like any sync call, interrupting the executing asyncio
+    loop.
index 42c0db977655d673d6f0fcae8293d6780800f1cd..a52444766db6fb9e2ebe9d62499fc431713737d8 100644 (file)
@@ -1114,6 +1114,24 @@ message for details.
 
     :ref:`error_bbf0`
 
+
+AsyncIO Exceptions
+==================
+
+.. _error_xd1r:
+
+AwaitRequired
+-------------
+
+The SQLAlchemy async mode requires an async driver to be used to connect to the db.
+This error is usually raised when trying to use the async version of SQLAlchemy
+with a non compatible :term:`DBAPI`.
+
+.. seealso::
+
+    :ref:`asyncio extension <asyncio_toplevel>`
+
+
 Core Exception Classes
 ======================
 
index 7c1bbb18ef26b1f194919ff0671db426e5f26b73..c56c1f020338dd6446d271450f7f0e03f850ef53 100644 (file)
@@ -419,8 +419,8 @@ From version 3.24.0 onwards, SQLite supports "upserts" (update or insert)
 of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT``
 statement. A candidate row will only be inserted if that row does not violate
 any unique or primary key constraints. In the case of a unique constraint violation, a
-secondary action can occur which can be either “DO UPDATE”, indicating that
-the data in the target row should be updated, or “DO NOTHING”, which indicates
+secondary action can occur which can be either "DO UPDATE", indicating that
+the data in the target row should be updated, or "DO NOTHING", which indicates
 to silently skip this row.
 
 Conflicts are determined using columns that are part of existing unique
@@ -469,7 +469,7 @@ and :meth:`_sqlite.Insert.on_conflict_do_nothing`:
 Specifying the Target
 ^^^^^^^^^^^^^^^^^^^^^
 
-Both methods supply the “target” of the conflict using column inference:
+Both methods supply the "target" of the conflict using column inference:
 
 * The :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` argument
   specifies a sequence containing string column names, :class:`_schema.Column`
index 7ba2e369b603c93c6d48a7730f4f7b69322560d5..63c56c34d8bdab3742bec57a1b38a765e1e7b362 100644 (file)
@@ -285,6 +285,15 @@ class NoReferenceError(InvalidRequestError):
     """Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
 
 
+class AwaitRequired(InvalidRequestError):
+    """Error raised by the async greenlet spawn if no async operation
+    was awaited when it required one
+
+    """
+
+    code = "xd1r"
+
+
 class NoReferencedTableError(NoReferenceError):
     """Raised by ``ForeignKey`` when the referred ``Table`` cannot be
     located.
@@ -355,10 +364,6 @@ class DontWrapMixin(object):
     """
 
 
-# Moved to orm.exc; compatibility definition installed by orm import until 0.6
-UnmappedColumnError = None
-
-
 class StatementError(SQLAlchemyError):
     """An error occurred during execution of a SQL statement.
 
index 9e4851dfcda38e7b78a546d774dedca87ed35e44..16edcc2b2a6c4e50d7c833b96dcb8f35087c9daa 100644 (file)
@@ -243,6 +243,7 @@ class AsyncConnection(StartableContext, AsyncConnectable):
             statement,
             parameters,
             execution_options,
+            _require_await=True,
         )
         if result.context._is_server_side:
             raise async_exc.AsyncMethodRequired(
@@ -272,6 +273,7 @@ class AsyncConnection(StartableContext, AsyncConnectable):
             util.EMPTY_DICT.merge_with(
                 execution_options, {"stream_results": True}
             ),
+            _require_await=True,
         )
         if not result.context._is_server_side:
             # TODO: real exception here
@@ -322,6 +324,7 @@ class AsyncConnection(StartableContext, AsyncConnectable):
             statement,
             parameters,
             execution_options,
+            _require_await=True,
         )
         if result.context._is_server_side:
             raise async_exc.AsyncMethodRequired(
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
deleted file mode 100644 (file)
index 2e274de..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-from .assertions import assert_raises as _assert_raises
-from .assertions import assert_raises_message as _assert_raises_message
-from ..util import await_fallback as await_
-from ..util import greenlet_spawn
-
-
-async def assert_raises_async(except_cls, msg, coroutine):
-    await greenlet_spawn(_assert_raises, except_cls, await_, coroutine)
-
-
-async def assert_raises_message_async(except_cls, msg, coroutine):
-    await greenlet_spawn(
-        _assert_raises_message, except_cls, msg, await_, coroutine
-    )
index dcee057134e9cacca68b717b4f08694c46395e7c..8ad3be5439b78c16b5c3838357775e54226c4c30 100644 (file)
@@ -79,7 +79,9 @@ def await_fallback(awaitable: Coroutine) -> Any:
     return current.driver.switch(awaitable)
 
 
-async def greenlet_spawn(fn: Callable, *args, **kwargs) -> Any:
+async def greenlet_spawn(
+    fn: Callable, *args, _require_await=False, **kwargs
+) -> Any:
     """Runs a sync function ``fn`` in a new greenlet.
 
     The sync function can then use :func:`await_` to wait for async
@@ -95,9 +97,11 @@ async def greenlet_spawn(fn: Callable, *args, **kwargs) -> Any:
     # is interrupted by await_, context is not dead and result is a
     # coroutine to wait. If the context is dead the function has
     # returned, and its result can be returned.
+    switch_occurred = False
     try:
         result = context.switch(*args, **kwargs)
         while not context.dead:
+            switch_occurred = True
             try:
                 # wait for a coroutine from await_ and then return its
                 # result back to it.
@@ -112,6 +116,12 @@ async def greenlet_spawn(fn: Callable, *args, **kwargs) -> Any:
     finally:
         # clean up to avoid cycle resolution by gc
         del context.driver
+    if _require_await and not switch_occurred:
+        raise exc.AwaitRequired(
+            "The current operation required an async execution but none was "
+            "detected. This will usually happen when using a non compatible "
+            "DBAPI driver. Please ensure that an async DBAPI is used."
+        )
     return result
 
 
index ba53ea63522fa9c7819a1b92f0416458731b8093..cf1067667d6a680afb5ec9a60947b0873bf1795d 100644 (file)
@@ -138,3 +138,16 @@ class TestAsyncioCompat(fixtures.TestBase):
             )
         }
         eq_(values, set(range(concurrency)))
+
+    @async_test
+    async def test_require_await(self):
+        def run():
+            return 1 + 1
+
+        assert (await greenlet_spawn(run)) == 2
+
+        with expect_raises_message(
+            exc.AwaitRequired,
+            "The current operation required an async execution but none was",
+        ):
+            await greenlet_spawn(run, _require_await=True)
index 83987b06f1173e269397d1083b54acd30b5fe29e..a361ff835a9f93a1b0941410a60a298b7847a766 100644 (file)
@@ -16,12 +16,14 @@ from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.asyncio import engine as _async_engine
 from sqlalchemy.ext.asyncio import exc as asyncio_exc
 from sqlalchemy.testing import async_test
+from sqlalchemy.testing import combinations
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import mock
-from sqlalchemy.testing.asyncio import assert_raises_message_async
 from sqlalchemy.util.concurrency import greenlet_spawn
 
 
@@ -254,12 +256,12 @@ class AsyncEngineTest(EngineFixture):
 
         async with async_engine.connect() as conn:
             trans = conn.begin()
-            await assert_raises_message_async(
+            with expect_raises_message(
                 asyncio_exc.AsyncContextNotStarted,
                 "AsyncTransaction context has not been started "
                 "and object has not been awaited.",
-                trans.rollback(),
-            )
+            ):
+                await trans.rollback(),
 
     @async_test
     async def test_pool_exhausted(self, async_engine):
@@ -270,11 +272,8 @@ class AsyncEngineTest(EngineFixture):
             pool_timeout=0.1,
         )
         async with engine.connect():
-            await assert_raises_message_async(
-                asyncio.TimeoutError,
-                "",
-                engine.connect(),
-            )
+            with expect_raises(asyncio.TimeoutError):
+                await engine.connect()
 
     @async_test
     async def test_create_async_engine_server_side_cursor(self, async_engine):
@@ -530,15 +529,11 @@ class AsyncResultTest(EngineFixture):
                 select(users).where(users.c.user_name == "nonexistent")
             )
 
-            async def go():
+            with expect_raises_message(
+                exc.NoResultFound, "No row was found when one was required"
+            ):
                 await result.one()
 
-            await assert_raises_message_async(
-                exc.NoResultFound,
-                "No row was found when one was required",
-                go(),
-            )
-
     @async_test
     async def test_one_multi_result(self, async_engine):
         users = self.tables.users
@@ -547,11 +542,38 @@ class AsyncResultTest(EngineFixture):
                 select(users).where(users.c.user_name.in_(["name3", "name5"]))
             )
 
-            async def go():
-                await result.one()
-
-            await assert_raises_message_async(
+            with expect_raises_message(
                 exc.MultipleResultsFound,
                 "Multiple rows were found when exactly one was required",
-                go(),
+            ):
+                await result.one()
+
+
+class TextSyncDBAPI(fixtures.TestBase):
+    @testing.fixture
+    def async_engine(self):
+        return create_async_engine("sqlite:///:memory:")
+
+    @async_test
+    @combinations(
+        lambda conn: conn.exec_driver_sql("select 1"),
+        lambda conn: conn.stream(text("select 1")),
+        lambda conn: conn.execute(text("select 1")),
+        argnames="case",
+    )
+    async def test_sync_driver_execution(self, async_engine, case):
+        with expect_raises_message(
+            exc.AwaitRequired,
+            "The current operation required an async execution but none was",
+        ):
+            async with async_engine.connect() as conn:
+                await case(conn)
+
+    @async_test
+    async def test_sync_driver_run_sync(self, async_engine):
+        async with async_engine.connect() as conn:
+            res = await conn.run_sync(
+                lambda conn: conn.scalar(text("select 1"))
             )
+            assert res == 1
+            assert await conn.run_sync(lambda _: 2) == 2