]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Remove async_fallback mode
authorFederico Caselli <cfederico87@gmail.com>
Tue, 5 Dec 2023 21:29:19 +0000 (22:29 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 20 Dec 2023 21:54:28 +0000 (22:54 +0100)
Removed the async_fallback mode and await_fallback function.
Replace get_event_loop with Runner.
Removed the internal function ``await_fallback()``.
Renamed the internal function ``await_only()`` to ``await_()``.

Change-Id: Ib43829be6ebdb59b6c4447f5a15b5d2b81403fa9

31 files changed:
README.unittests.rst
doc/build/changelog/unreleased_21/async_fallback.rst [new file with mode: 0644]
lib/sqlalchemy/__init__.py
lib/sqlalchemy/connectors/aioodbc.py
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/ext/asyncio/base.py
lib/sqlalchemy/pool/__init__.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/testing/asyncio.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/engines.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/testing/provision.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/concurrency.py
lib/sqlalchemy/util/queue.py
setup.cfg
test/base/_concurrency_fixtures.py
test/base/test_concurrency.py
test/engine/test_pool.py
test/ext/asyncio/test_engine_py3k.py
test/requirements.py

index d7155c1ac2b4f1a3349c42ed969ba6d7c44193fa..046a30f6a92c8e9aaaccb4a93a9e4acb0d6a9c0d 100644 (file)
@@ -83,13 +83,10 @@ a pre-set URL.  These can be seen using --dbs::
     $ pytest --dbs
     Available --db options (use --dburi to override)
                 aiomysql    mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-       aiomysql_fallback    mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
                aiosqlite    sqlite+aiosqlite:///:memory:
           aiosqlite_file    sqlite+aiosqlite:///async_querytest.db
                  asyncmy    mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-        asyncmy_fallback    mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
                  asyncpg    postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test
-        asyncpg_fallback    postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
                  default    sqlite:///:memory:
             docker_mssql    mssql+pymssql://scott:tiger^5HHH@127.0.0.1:1433/test
                  mariadb    mariadb+mysqldb://scott:tiger@127.0.0.1:3306/test
@@ -105,7 +102,6 @@ a pre-set URL.  These can be seen using --dbs::
                  psycopg    postgresql+psycopg://scott:tiger@127.0.0.1:5432/test
                 psycopg2    postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
            psycopg_async    postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test
-    psycopg_async_fallback  postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test?async_fallback=true
                  pymysql    mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
         pysqlcipher_file    sqlite+pysqlcipher://:test@/querytest.db.enc
                   sqlite    sqlite:///:memory:
diff --git a/doc/build/changelog/unreleased_21/async_fallback.rst b/doc/build/changelog/unreleased_21/async_fallback.rst
new file mode 100644 (file)
index 0000000..44b91d2
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: change, asyncio
+
+    Removed the compatibility ``async_fallback`` mode for async dialects,
+    since it's no longer used by SQLAlchemy tests.
+    Also removed the internal function ``await_fallback()`` and renamed
+    the internal function ``await_only()`` to ``await_()``.
+    No change is expected to user code.
index 2300c2d409a47b1b7c0b7280ba178e4c67960c64..af030614a528f0339a2a11fc03c4d2b7322eb61e 100644 (file)
@@ -47,9 +47,6 @@ from .engine import URL as URL
 from .inspection import inspect as inspect
 from .pool import AssertionPool as AssertionPool
 from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool
-from .pool import (
-    FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool,
-)
 from .pool import NullPool as NullPool
 from .pool import Pool as Pool
 from .pool import PoolProxiedConnection as PoolProxiedConnection
index e0f5f55474fe460397295ecc834fb2888d14145a..927330b286e405337eee14b6e0c3e016f5492581 100644 (file)
@@ -13,12 +13,8 @@ from typing import TYPE_CHECKING
 from .asyncio import AsyncAdapt_dbapi_connection
 from .asyncio import AsyncAdapt_dbapi_cursor
 from .asyncio import AsyncAdapt_dbapi_ss_cursor
-from .asyncio import AsyncAdaptFallback_dbapi_connection
 from .pyodbc import PyODBCConnector
-from .. import pool
-from .. import util
-from ..util.concurrency import await_fallback
-from ..util.concurrency import await_only
+from ..util.concurrency import await_
 
 if TYPE_CHECKING:
     from ..engine.interfaces import ConnectArgsType
@@ -33,7 +29,7 @@ class AsyncAdapt_aioodbc_cursor(AsyncAdapt_dbapi_cursor):
         return self._cursor._impl.setinputsizes(*inputsizes)
 
         # how it's supposed to work
-        # return self.await_(self._cursor.setinputsizes(*inputsizes))
+        # return await_(self._cursor.setinputsizes(*inputsizes))
 
 
 class AsyncAdapt_aioodbc_ss_cursor(
@@ -59,7 +55,7 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection):
         self._connection._conn.autocommit = value
 
     def ping(self, reconnect):
-        return self.await_(self._connection.ping(reconnect))
+        return await_(self._connection.ping(reconnect))
 
     def add_output_converter(self, *arg, **kw):
         self._connection.add_output_converter(*arg, **kw)
@@ -96,12 +92,6 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection):
             super().close()
 
 
-class AsyncAdaptFallback_aioodbc_connection(
-    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aioodbc_connection
-):
-    __slots__ = ()
-
-
 class AsyncAdapt_aioodbc_dbapi:
     def __init__(self, aioodbc, pyodbc):
         self.aioodbc = aioodbc
@@ -136,19 +126,12 @@ class AsyncAdapt_aioodbc_dbapi:
             setattr(self, name, getattr(self.pyodbc, name))
 
     def connect(self, *arg, **kw):
-        async_fallback = kw.pop("async_fallback", False)
         creator_fn = kw.pop("async_creator_fn", self.aioodbc.connect)
 
-        if util.asbool(async_fallback):
-            return AsyncAdaptFallback_aioodbc_connection(
-                self,
-                await_fallback(creator_fn(*arg, **kw)),
-            )
-        else:
-            return AsyncAdapt_aioodbc_connection(
-                self,
-                await_only(creator_fn(*arg, **kw)),
-            )
+        return AsyncAdapt_aioodbc_connection(
+            self,
+            await_(creator_fn(*arg, **kw)),
+        )
 
 
 class aiodbcConnector(PyODBCConnector):
@@ -170,15 +153,6 @@ class aiodbcConnector(PyODBCConnector):
 
         return (), kw
 
-    @classmethod
-    def get_pool_class(cls, url):
-        async_fallback = url.query.get("async_fallback", False)
-
-        if util.asbool(async_fallback):
-            return pool.FallbackAsyncAdaptedQueuePool
-        else:
-            return pool.AsyncAdaptedQueuePool
-
     def _do_isolation_level(self, connection, autocommit, isolation_level):
         connection.set_autocommit(autocommit)
         connection.set_isolation_level(isolation_level)
index 9358457ceb26caf3e0d8541ca2a2f9f575317d46..f17831068cf5071737d399e0040c58fb23c79d85 100644 (file)
@@ -25,8 +25,7 @@ from ..engine import AdaptedConnection
 from ..engine.interfaces import _DBAPICursorDescription
 from ..engine.interfaces import _DBAPIMultiExecuteParams
 from ..engine.interfaces import _DBAPISingleExecuteParams
-from ..util.concurrency import await_fallback
-from ..util.concurrency import await_only
+from ..util.concurrency import await_
 from ..util.typing import Self
 
 
@@ -121,7 +120,6 @@ class AsyncAdapt_dbapi_cursor:
     __slots__ = (
         "_adapt_connection",
         "_connection",
-        "await_",
         "_cursor",
         "_rows",
     )
@@ -134,12 +132,11 @@ class AsyncAdapt_dbapi_cursor:
     def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
         self._adapt_connection = adapt_connection
         self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
 
         cursor = self._make_new_cursor(self._connection)
 
         try:
-            self._cursor = self.await_(cursor.__aenter__())
+            self._cursor = await_(cursor.__aenter__())
         except Exception as error:
             self._adapt_connection._handle_exception(error)
 
@@ -181,7 +178,7 @@ class AsyncAdapt_dbapi_cursor:
         parameters: Optional[_DBAPISingleExecuteParams] = None,
     ) -> Any:
         try:
-            return self.await_(self._execute_async(operation, parameters))
+            return await_(self._execute_async(operation, parameters))
         except Exception as error:
             self._adapt_connection._handle_exception(error)
 
@@ -191,7 +188,7 @@ class AsyncAdapt_dbapi_cursor:
         seq_of_parameters: _DBAPIMultiExecuteParams,
     ) -> Any:
         try:
-            return self.await_(
+            return await_(
                 self._executemany_async(operation, seq_of_parameters)
             )
         except Exception as error:
@@ -223,18 +220,16 @@ class AsyncAdapt_dbapi_cursor:
             return await self._cursor.executemany(operation, seq_of_parameters)
 
     def nextset(self) -> None:
-        self.await_(self._cursor.nextset())
+        await_(self._cursor.nextset())
         if self._cursor.description and not self.server_side:
-            self._rows = collections.deque(
-                self.await_(self._cursor.fetchall())
-            )
+            self._rows = collections.deque(await_(self._cursor.fetchall()))
 
     def setinputsizes(self, *inputsizes: Any) -> None:
         # NOTE: this is overrridden in aioodbc due to
         # see https://github.com/aio-libs/aioodbc/issues/451
         # right now
 
-        return self.await_(self._cursor.setinputsizes(*inputsizes))
+        return await_(self._cursor.setinputsizes(*inputsizes))
 
     def __enter__(self) -> Self:
         return self
@@ -273,24 +268,23 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
 
     def close(self) -> None:
         if self._cursor is not None:
-            self.await_(self._cursor.close())
+            await_(self._cursor.close())
             self._cursor = None  # type: ignore
 
     def fetchone(self) -> Optional[Any]:
-        return self.await_(self._cursor.fetchone())
+        return await_(self._cursor.fetchone())
 
     def fetchmany(self, size: Optional[int] = None) -> Any:
-        return self.await_(self._cursor.fetchmany(size=size))
+        return await_(self._cursor.fetchmany(size=size))
 
     def fetchall(self) -> Sequence[Any]:
-        return self.await_(self._cursor.fetchall())
+        return await_(self._cursor.fetchall())
 
 
 class AsyncAdapt_dbapi_connection(AdaptedConnection):
     _cursor_cls = AsyncAdapt_dbapi_cursor
     _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor
 
-    await_ = staticmethod(await_only)
     __slots__ = ("dbapi", "_execute_mutex")
 
     _connection: AsyncIODBAPIConnection
@@ -323,21 +317,15 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection):
 
     def rollback(self) -> None:
         try:
-            self.await_(self._connection.rollback())
+            await_(self._connection.rollback())
         except Exception as error:
             self._handle_exception(error)
 
     def commit(self) -> None:
         try:
-            self.await_(self._connection.commit())
+            await_(self._connection.commit())
         except Exception as error:
             self._handle_exception(error)
 
     def close(self) -> None:
-        self.await_(self._connection.close())
-
-
-class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection):
-    __slots__ = ()
-
-    await_ = staticmethod(await_fallback)
+        await_(self._connection.close())
index 978950b8780c79139554d72aacf37837411ac088..f92b1bfaa6c03ec12dceeb93ece129cd815e7eba 100644 (file)
@@ -28,14 +28,10 @@ This dialect should normally be used only with the
 
 """  # noqa
 from .pymysql import MySQLDialect_pymysql
-from ... import pool
-from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
-from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
-from ...util.concurrency import await_fallback
-from ...util.concurrency import await_only
+from ...util.concurrency import await_
 
 
 class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
@@ -64,25 +60,19 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
 
     def ping(self, reconnect):
         assert not reconnect
-        return self.await_(self._connection.ping(reconnect))
+        return await_(self._connection.ping(reconnect))
 
     def character_set_name(self):
         return self._connection.character_set_name()
 
     def autocommit(self, value):
-        self.await_(self._connection.autocommit(value))
+        await_(self._connection.autocommit(value))
 
     def close(self):
         # it's not awaitable.
         self._connection.close()
 
 
-class AsyncAdaptFallback_aiomysql_connection(
-    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiomysql_connection
-):
-    __slots__ = ()
-
-
 class AsyncAdapt_aiomysql_dbapi:
     def __init__(self, aiomysql, pymysql):
         self.aiomysql = aiomysql
@@ -118,19 +108,12 @@ class AsyncAdapt_aiomysql_dbapi:
             setattr(self, name, getattr(self.pymysql, name))
 
     def connect(self, *arg, **kw):
-        async_fallback = kw.pop("async_fallback", False)
         creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
 
-        if util.asbool(async_fallback):
-            return AsyncAdaptFallback_aiomysql_connection(
-                self,
-                await_fallback(creator_fn(*arg, **kw)),
-            )
-        else:
-            return AsyncAdapt_aiomysql_connection(
-                self,
-                await_only(creator_fn(*arg, **kw)),
-            )
+        return AsyncAdapt_aiomysql_connection(
+            self,
+            await_(creator_fn(*arg, **kw)),
+        )
 
     def _init_cursors_subclasses(self):
         # suppress unconditional warning emitted by aiomysql
@@ -160,15 +143,6 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
             __import__("aiomysql"), __import__("pymysql")
         )
 
-    @classmethod
-    def get_pool_class(cls, url):
-        async_fallback = url.query.get("async_fallback", False)
-
-        if util.asbool(async_fallback):
-            return pool.FallbackAsyncAdaptedQueuePool
-        else:
-            return pool.AsyncAdaptedQueuePool
-
     def create_connect_args(self, url):
         return super().create_connect_args(
             url, _translate_args=dict(username="user", database="db")
index 3029626fd5f17b141094141525e74272d74394a9..7f2a9979e6b43f4116d05e2995de85cbbf8e9a74 100644 (file)
@@ -28,14 +28,11 @@ This dialect should normally be used only with the
 from __future__ import annotations
 
 from .pymysql import MySQLDialect_pymysql
-from ... import pool
 from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
-from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
-from ...util.concurrency import await_fallback
-from ...util.concurrency import await_only
+from ...util.concurrency import await_
 
 
 class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
@@ -69,7 +66,7 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
 
     def ping(self, reconnect):
         assert not reconnect
-        return self.await_(self._do_ping())
+        return await_(self._do_ping())
 
     async def _do_ping(self):
         try:
@@ -82,19 +79,13 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
         return self._connection.character_set_name()
 
     def autocommit(self, value):
-        self.await_(self._connection.autocommit(value))
+        await_(self._connection.autocommit(value))
 
     def close(self):
         # it's not awaitable.
         self._connection.close()
 
 
-class AsyncAdaptFallback_asyncmy_connection(
-    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_asyncmy_connection
-):
-    __slots__ = ()
-
-
 def _Binary(x):
     """Return x as a binary type."""
     return bytes(x)
@@ -130,19 +121,12 @@ class AsyncAdapt_asyncmy_dbapi:
     Binary = staticmethod(_Binary)
 
     def connect(self, *arg, **kw):
-        async_fallback = kw.pop("async_fallback", False)
         creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
 
-        if util.asbool(async_fallback):
-            return AsyncAdaptFallback_asyncmy_connection(
-                self,
-                await_fallback(creator_fn(*arg, **kw)),
-            )
-        else:
-            return AsyncAdapt_asyncmy_connection(
-                self,
-                await_only(creator_fn(*arg, **kw)),
-            )
+        return AsyncAdapt_asyncmy_connection(
+            self,
+            await_(creator_fn(*arg, **kw)),
+        )
 
 
 class MySQLDialect_asyncmy(MySQLDialect_pymysql):
@@ -158,15 +142,6 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
     def import_dbapi(cls):
         return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
 
-    @classmethod
-    def get_pool_class(cls, url):
-        async_fallback = url.query.get("async_fallback", False)
-
-        if util.asbool(async_fallback):
-            return pool.FallbackAsyncAdaptedQueuePool
-        else:
-            return pool.AsyncAdaptedQueuePool
-
     def create_connect_args(self, url):
         return super().create_connect_args(
             url, _translate_args=dict(username="user", database="db")
index 2ce68acce6e1214ca3ee470e6010cf1187e333b5..d138c1819a1c013b3087b0536252b7ec41cc1f76 100644 (file)
@@ -25,17 +25,6 @@ This dialect should normally be used only with the
     from sqlalchemy.ext.asyncio import create_async_engine
     engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")
 
-The dialect can also be run as a "synchronous" dialect within the
-:func:`_sa.create_engine` function, which will pass "await" calls into
-an ad-hoc event loop.  This mode of operation is of **limited use**
-and is for special testing scenarios only.  The mode can be enabled by
-adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
-in conjunction with :func:`_sa.create_engine`::
-
-    # for testing purposes only; do not use in production!
-    engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")
-
-
 .. versionadded:: 1.4
 
 .. note::
@@ -217,15 +206,13 @@ from .types import BIT
 from .types import BYTEA
 from .types import CITEXT
 from ... import exc
-from ... import pool
 from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
 from ...engine import processors
 from ...sql import sqltypes
-from ...util.concurrency import await_fallback
-from ...util.concurrency import await_only
+from ...util.concurrency import await_
 
 if TYPE_CHECKING:
     from ...engine.interfaces import _DBAPICursorDescription
@@ -556,7 +543,6 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
     def __init__(self, adapt_connection: AsyncAdapt_asyncpg_connection):
         self._adapt_connection = adapt_connection
         self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
         self._cursor = None
         self._rows = collections.deque()
         self._description = None
@@ -654,14 +640,10 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
                 self._handle_exception(error)
 
     def execute(self, operation, parameters=None):
-        self._adapt_connection.await_(
-            self._prepare_and_execute(operation, parameters)
-        )
+        await_(self._prepare_and_execute(operation, parameters))
 
     def executemany(self, operation, seq_of_parameters):
-        return self._adapt_connection.await_(
-            self._executemany(operation, seq_of_parameters)
-        )
+        return await_(self._executemany(operation, seq_of_parameters))
 
     def setinputsizes(self, *inputsizes):
         raise NotImplementedError()
@@ -683,7 +665,7 @@ class AsyncAdapt_asyncpg_ss_cursor(
 
     def _buffer_rows(self):
         assert self._cursor is not None
-        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
+        new_rows = await_(self._cursor.fetch(50))
         self._rowbuffer = collections.deque(new_rows)
 
     def __aiter__(self):
@@ -721,9 +703,7 @@ class AsyncAdapt_asyncpg_ss_cursor(
         buf = list(self._rowbuffer)
         lb = len(buf)
         if size > lb:
-            buf.extend(
-                self._adapt_connection.await_(self._cursor.fetch(size - lb))
-            )
+            buf.extend(await_(self._cursor.fetch(size - lb)))
 
         result = buf[0:size]
         self._rowbuffer = collections.deque(buf[size:])
@@ -732,9 +712,7 @@ class AsyncAdapt_asyncpg_ss_cursor(
     def fetchall(self):
         assert self._rowbuffer is not None
 
-        ret = list(self._rowbuffer) + list(
-            self._adapt_connection.await_(self._all())
-        )
+        ret = list(self._rowbuffer) + list(await_(self._all()))
         self._rowbuffer.clear()
         return ret
 
@@ -876,7 +854,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
 
     def ping(self):
         try:
-            _ = self.await_(self._async_ping())
+            _ = await_(self._async_ping())
         except Exception as error:
             self._handle_exception(error)
 
@@ -918,7 +896,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
         if self._started:
             assert self._transaction is not None
             try:
-                self.await_(self._transaction.rollback())
+                await_(self._transaction.rollback())
             except Exception as error:
                 self._handle_exception(error)
             finally:
@@ -929,7 +907,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
         if self._started:
             assert self._transaction is not None
             try:
-                self.await_(self._transaction.commit())
+                await_(self._transaction.commit())
             except Exception as error:
                 self._handle_exception(error)
             finally:
@@ -939,7 +917,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
     def close(self):
         self.rollback()
 
-        self.await_(self._connection.close())
+        await_(self._connection.close())
 
     def terminate(self):
         if util.concurrency.in_greenlet():
@@ -948,7 +926,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
             try:
                 # try to gracefully close; see #10717
                 # timeout added in asyncpg 0.14.0 December 2017
-                self.await_(self._connection.close(timeout=2))
+                await_(self._connection.close(timeout=2))
             except asyncio.TimeoutError:
                 # in the case where we are recycling an old connection
                 # that may have already been disconnected, close() will
@@ -966,19 +944,12 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
         return None
 
 
-class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
-    __slots__ = ()
-
-    await_ = staticmethod(await_fallback)
-
-
 class AsyncAdapt_asyncpg_dbapi:
     def __init__(self, asyncpg):
         self.asyncpg = asyncpg
         self.paramstyle = "numeric_dollar"
 
     def connect(self, *arg, **kw):
-        async_fallback = kw.pop("async_fallback", False)
         creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect)
         prepared_statement_cache_size = kw.pop(
             "prepared_statement_cache_size", 100
@@ -987,20 +958,12 @@ class AsyncAdapt_asyncpg_dbapi:
             "prepared_statement_name_func", None
         )
 
-        if util.asbool(async_fallback):
-            return AsyncAdaptFallback_asyncpg_connection(
-                self,
-                await_fallback(creator_fn(*arg, **kw)),
-                prepared_statement_cache_size=prepared_statement_cache_size,
-                prepared_statement_name_func=prepared_statement_name_func,
-            )
-        else:
-            return AsyncAdapt_asyncpg_connection(
-                self,
-                await_only(creator_fn(*arg, **kw)),
-                prepared_statement_cache_size=prepared_statement_cache_size,
-                prepared_statement_name_func=prepared_statement_name_func,
-            )
+        return AsyncAdapt_asyncpg_connection(
+            self,
+            await_(creator_fn(*arg, **kw)),
+            prepared_statement_cache_size=prepared_statement_cache_size,
+            prepared_statement_name_func=prepared_statement_name_func,
+        )
 
     class Error(Exception):
         pass
@@ -1201,15 +1164,6 @@ class PGDialect_asyncpg(PGDialect):
         dbapi_connection.ping()
         return True
 
-    @classmethod
-    def get_pool_class(cls, url):
-        async_fallback = url.query.get("async_fallback", False)
-
-        if util.asbool(async_fallback):
-            return pool.FallbackAsyncAdaptedQueuePool
-        else:
-            return pool.AsyncAdaptedQueuePool
-
     def is_disconnect(self, e, connection, cursor):
         if connection:
             return connection._connection.is_closed()
@@ -1308,11 +1262,11 @@ class PGDialect_asyncpg(PGDialect):
         super_connect = super().on_connect()
 
         def connect(conn):
-            conn.await_(self.setup_asyncpg_json_codec(conn))
-            conn.await_(self.setup_asyncpg_jsonb_codec(conn))
+            await_(self.setup_asyncpg_json_codec(conn))
+            await_(self.setup_asyncpg_jsonb_codec(conn))
 
             if self._native_inet_types is False:
-                conn.await_(self._disable_asyncpg_inet_codecs(conn))
+                await_(self._disable_asyncpg_inet_codecs(conn))
             if super_connect is not None:
                 super_connect(conn)
 
index 743d0388809befd32e8a3675a8bb2a7655cf2840..690cadb6b3aed93374491484e262e7de5de329fb 100644 (file)
@@ -70,15 +70,12 @@ from .json import JSON
 from .json import JSONB
 from .json import JSONPathType
 from .types import CITEXT
-from ... import pool
 from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
-from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
 from ...sql import sqltypes
-from ...util.concurrency import await_fallback
-from ...util.concurrency import await_only
+from ...util.concurrency import await_
 
 if TYPE_CHECKING:
     from typing import Iterable
@@ -585,7 +582,7 @@ class AsyncAdapt_psycopg_ss_cursor(
         iterator = self._cursor.__aiter__()
         while True:
             try:
-                yield self.await_(iterator.__anext__())
+                yield await_(iterator.__anext__())
             except StopAsyncIteration:
                 break
 
@@ -632,16 +629,16 @@ class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
         self.set_autocommit(value)
 
     def set_autocommit(self, value):
-        self.await_(self._connection.set_autocommit(value))
+        await_(self._connection.set_autocommit(value))
 
     def set_isolation_level(self, value):
-        self.await_(self._connection.set_isolation_level(value))
+        await_(self._connection.set_isolation_level(value))
 
     def set_read_only(self, value):
-        self.await_(self._connection.set_read_only(value))
+        await_(self._connection.set_read_only(value))
 
     def set_deferrable(self, value):
-        self.await_(self._connection.set_deferrable(value))
+        await_(self._connection.set_deferrable(value))
 
     def cursor(self, name=None, /):
         if name:
@@ -650,12 +647,6 @@ class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
             return AsyncAdapt_psycopg_cursor(self)
 
 
-class AsyncAdaptFallback_psycopg_connection(
-    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_psycopg_connection
-):
-    __slots__ = ()
-
-
 class PsycopgAdaptDBAPI:
     def __init__(self, psycopg, ExecStatus) -> None:
         self.psycopg = psycopg
@@ -666,18 +657,12 @@ class PsycopgAdaptDBAPI:
                 self.__dict__[k] = v
 
     def connect(self, *arg, **kw):
-        async_fallback = kw.pop("async_fallback", False)
         creator_fn = kw.pop(
             "async_creator_fn", self.psycopg.AsyncConnection.connect
         )
-        if util.asbool(async_fallback):
-            return AsyncAdaptFallback_psycopg_connection(
-                self, await_fallback(creator_fn(*arg, **kw))
-            )
-        else:
-            return AsyncAdapt_psycopg_connection(
-                self, await_only(creator_fn(*arg, **kw))
-            )
+        return AsyncAdapt_psycopg_connection(
+            self, await_(creator_fn(*arg, **kw))
+        )
 
 
 class PGDialectAsync_psycopg(PGDialect_psycopg):
@@ -691,20 +676,11 @@ class PGDialectAsync_psycopg(PGDialect_psycopg):
 
         return PsycopgAdaptDBAPI(psycopg, ExecStatus)
 
-    @classmethod
-    def get_pool_class(cls, url):
-        async_fallback = url.query.get("async_fallback", False)
-
-        if util.asbool(async_fallback):
-            return pool.FallbackAsyncAdaptedQueuePool
-        else:
-            return pool.AsyncAdaptedQueuePool
-
     def _type_info_fetch(self, connection, name):
         from psycopg.types import TypeInfo
 
         adapted = connection.connection
-        return adapted.await_(TypeInfo.fetch(adapted.driver_connection, name))
+        return await_(TypeInfo.fetch(adapted.driver_connection, name))
 
     def _do_isolation_level(self, connection, autocommit, isolation_level):
         connection.set_autocommit(autocommit)
index 7eccf5fb174e5bc0b779a561afd1f0c1976b0e7b..05e64ee85d910190a2a4bfd0ba4d959a6a471b18 100644 (file)
@@ -83,13 +83,10 @@ from functools import partial
 from .base import SQLiteExecutionContext
 from .pysqlite import SQLiteDialect_pysqlite
 from ... import pool
-from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
-from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
-from ...util.concurrency import await_fallback
-from ...util.concurrency import await_only
+from ...util.concurrency import await_
 
 
 class AsyncAdapt_aiosqlite_cursor(AsyncAdapt_dbapi_cursor):
@@ -126,13 +123,13 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
         self._connection._tx.put_nowait((future, function))
 
         try:
-            return self.await_(future)
+            return await_(future)
         except Exception as error:
             self._handle_exception(error)
 
     def create_function(self, *args, **kw):
         try:
-            self.await_(self._connection.create_function(*args, **kw))
+            await_(self._connection.create_function(*args, **kw))
         except Exception as error:
             self._handle_exception(error)
 
@@ -146,7 +143,7 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
 
     def close(self):
         try:
-            self.await_(self._connection.close())
+            await_(self._connection.close())
         except ValueError:
             # this is undocumented for aiosqlite, that ValueError
             # was raised if .close() was called more than once, which is
@@ -170,12 +167,6 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
             super()._handle_exception(error)
 
 
-class AsyncAdaptFallback_aiosqlite_connection(
-    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiosqlite_connection
-):
-    __slots__ = ()
-
-
 class AsyncAdapt_aiosqlite_dbapi:
     def __init__(self, aiosqlite, sqlite):
         self.aiosqlite = aiosqlite
@@ -203,8 +194,6 @@ class AsyncAdapt_aiosqlite_dbapi:
             setattr(self, name, getattr(self.sqlite, name))
 
     def connect(self, *arg, **kw):
-        async_fallback = kw.pop("async_fallback", False)
-
         creator_fn = kw.pop("async_creator_fn", None)
         if creator_fn:
             connection = creator_fn(*arg, **kw)
@@ -213,16 +202,10 @@ class AsyncAdapt_aiosqlite_dbapi:
             # it's a Thread.   you'll thank us later
             connection.daemon = True
 
-        if util.asbool(async_fallback):
-            return AsyncAdaptFallback_aiosqlite_connection(
-                self,
-                await_fallback(connection),
-            )
-        else:
-            return AsyncAdapt_aiosqlite_connection(
-                self,
-                await_only(connection),
-            )
+        return AsyncAdapt_aiosqlite_connection(
+            self,
+            await_(connection),
+        )
 
 
 class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
index 33e05120e24e14f7cf9075b77ca82e645f51ed94..6ad1de735ad2ae0fc1acb81db174c509feb0580b 100644 (file)
@@ -484,7 +484,13 @@ class DefaultDialect(Dialect):
 
     @classmethod
     def get_pool_class(cls, url: URL) -> Type[Pool]:
-        return getattr(cls, "poolclass", pool.QueuePool)
+        default: Type[pool.Pool]
+        if cls.is_async:
+            default = pool.AsyncAdaptedQueuePool
+        else:
+            default = pool.QueuePool
+
+        return getattr(cls, "poolclass", default)
 
     def get_dialect_pool_class(self, url: URL) -> Type[Pool]:
         return self.get_pool_class(url)
index faea997deacba6d7bbdf46899081f9d6c641e685..ddd4ceeabbd5ac7072220684423e7812ce9a4b8a 100644 (file)
@@ -42,7 +42,7 @@ from ..sql.compiler import Compiled  # noqa
 from ..sql.compiler import TypeCompiler as TypeCompiler
 from ..sql.compiler import TypeCompiler  # noqa
 from ..util import immutabledict
-from ..util.concurrency import await_only
+from ..util.concurrency import await_
 from ..util.typing import Literal
 from ..util.typing import NotRequired
 
@@ -3400,7 +3400,7 @@ class AdaptedConnection:
             :ref:`asyncio_events_run_async`
 
         """
-        return await_only(fn(self._connection))
+        return await_(fn(self._connection))
 
     def __repr__(self) -> str:
         return "<AdaptedConnection %s>" % self._connection
index 251f52125424252b07ff6fb2bcc20eec714db3b5..69d9cce55c85539955552efe28ad50afd07715ed 100644 (file)
@@ -182,7 +182,7 @@ class GeneratorStartableContext(StartableContext[_T_co]):
                 # tell if we get the same exception back
                 value = typ()
             try:
-                await util.athrow(self.gen, typ, value, traceback)
+                await self.gen.athrow(value)
             except StopAsyncIteration as exc:
                 # Suppress StopIteration *unless* it's the same exception that
                 # was passed to throw().  This prevents a StopIteration
index c25a8f85d87461aa2c08639dfb07b6b927144963..243862cdc530c1d52f37dda018367dd2a441ac8e 100644 (file)
@@ -35,9 +35,6 @@ from .base import reset_none as reset_none
 from .base import reset_rollback as reset_rollback
 from .impl import AssertionPool as AssertionPool
 from .impl import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool
-from .impl import (
-    FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool,
-)
 from .impl import NullPool as NullPool
 from .impl import QueuePool as QueuePool
 from .impl import SingletonThreadPool as SingletonThreadPool
index ced015088cb87d52de74c884d87aaa544a46fbea..9616ad29982ffc7c84268675e54c1b950198fe04 100644 (file)
@@ -257,10 +257,6 @@ class AsyncAdaptedQueuePool(QueuePool):
     _dialect = _AsyncConnDialect()
 
 
-class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool):
-    _queue_class = sqla_queue.FallbackAsyncAdaptedQueue
-
-
 class NullPool(Pool):
 
     """A Pool which does not pool connections.
index 4236dcf92e2ea061715f2e05647d4999ac41532e..1f2bc559125023bee0e339cffaf07ef0168d5c4b 100644 (file)
 # setup/teardown in an asyncio event loop, conditionally based on the
 # current DB driver being used for a test.
 
-# note that SQLAlchemy's asyncio integration also supports a method
-# of running individual asyncio functions inside of separate event loops
-# using "async_fallback" mode; however running whole functions in the event
-# loop is a more accurate test for how SQLAlchemy's asyncio features
-# would run in the real world.
-
-
 from __future__ import annotations
 
 from functools import wraps
 import inspect
 
 from . import config
-from ..util.concurrency import _util_async_run
-from ..util.concurrency import _util_async_run_coroutine_function
+from ..util.concurrency import _AsyncUtil
 
 # may be set to False if the
 # --disable-asyncio flag is passed to the test runner.
 ENABLE_ASYNCIO = True
+_async_util = _AsyncUtil()  # it has lazy init so just always create one
+
+
+def _shutdown():
+    """called when the test finishes"""
+    _async_util.close()
 
 
 def _run_coroutine_function(fn, *args, **kwargs):
-    return _util_async_run_coroutine_function(fn, *args, **kwargs)
+    return _async_util.run(fn, *args, **kwargs)
 
 
 def _assume_async(fn, *args, **kwargs):
@@ -50,7 +48,7 @@ def _assume_async(fn, *args, **kwargs):
     if not ENABLE_ASYNCIO:
         return fn(*args, **kwargs)
 
-    return _util_async_run(fn, *args, **kwargs)
+    return _async_util.run_in_greenlet(fn, *args, **kwargs)
 
 
 def _maybe_async_provisioning(fn, *args, **kwargs):
@@ -69,7 +67,7 @@ def _maybe_async_provisioning(fn, *args, **kwargs):
         return fn(*args, **kwargs)
 
     if config.any_async:
-        return _util_async_run(fn, *args, **kwargs)
+        return _async_util.run_in_greenlet(fn, *args, **kwargs)
     else:
         return fn(*args, **kwargs)
 
@@ -89,7 +87,7 @@ def _maybe_async(fn, *args, **kwargs):
     is_async = config._current.is_async
 
     if is_async:
-        return _util_async_run(fn, *args, **kwargs)
+        return _async_util.run_in_greenlet(fn, *args, **kwargs)
     else:
         return fn(*args, **kwargs)
 
index 8430203dee28824464877e98d5750b61a497dd3f..be22ff599130fe45aeacc15a29478228258e78c2 100644 (file)
@@ -25,7 +25,6 @@ from typing import Union
 from . import mock
 from . import requirements as _requirements
 from .util import fail
-from .. import util
 
 # default requirements; this is replaced by plugin_base when pytest
 # is run
@@ -330,9 +329,7 @@ class Config:
         self.test_schema = "test_schema"
         self.test_schema_2 = "test_schema_2"
 
-        self.is_async = db.dialect.is_async and not util.asbool(
-            db.url.query.get("async_fallback", False)
-        )
+        self.is_async = db.dialect.is_async
 
     _stack = collections.deque()
     _configs = set()
index 749f9c160e846eb36d34abdddb4b665088fdee3c..2bca37b2b8f2aa6d83d65c01f4326ef56f244b53 100644 (file)
@@ -23,7 +23,7 @@ from .util import decorator
 from .util import gc_collect
 from .. import event
 from .. import pool
-from ..util import await_only
+from ..util import await_
 from ..util.typing import Literal
 
 
@@ -112,7 +112,7 @@ class ConnectionKiller:
                         self._safe(proxy_ref._checkin)
 
             if hasattr(rec, "sync_engine"):
-                await_only(rec.dispose())
+                await_(rec.dispose())
             else:
                 rec.dispose()
         eng.clear()
index 47644e3d28b4ea549790b0ed42c5cb4841018d71..290e2cb5a4fb80b6060d8a059a42392d42f9ca14 100644 (file)
@@ -182,6 +182,12 @@ def pytest_sessionfinish(session):
         collect_types.dump_stats(session.config.option.dump_pyannotate)
 
 
+def pytest_unconfigure(config):
+    from sqlalchemy.testing import asyncio
+
+    asyncio._shutdown()
+
+
 def pytest_collection_finish(session):
     if session.config.option.dump_pyannotate:
         from pyannotate_runtime import collect_types
index 884d558138af079b013dd784cd0a12030df0dd3c..56b8c2972b8d1bde54d593b6e764617ef1954403 100644 (file)
@@ -113,7 +113,7 @@ def generate_db_urls(db_urls, extra_drivers):
         --dburi postgresql://db1  \
         --dburi postgresql://db2  \
         --dburi postgresql://db2  \
-        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
+        --dbdriver=psycopg2 --dbdriver=asyncpg
 
     Noting that the default postgresql driver is psycopg2,  the output
     would be::
@@ -130,11 +130,10 @@ def generate_db_urls(db_urls, extra_drivers):
     we want to keep it in that dburi.
 
     Driver specific query options can be specified by added them to the
-    driver name. For example, to enable the async fallback option for
-    asyncpg::
+    driver name. For example, to a sample option the asyncpg::
 
         --dburi postgresql://db1  \
-        --dbdriver=asyncpg?async_fallback=true
+        --dbdriver=asyncpg?some_option=a_value
 
     """
     urls = set()
index eaba84ecd27eb99be4cb78df7d66653fe7eb1d15..b288cbbaf49a71553b1eadf6849f12d93dc98ef9 100644 (file)
@@ -1595,7 +1595,8 @@ class SuiteRequirements(Requirements):
 
     @property
     def async_dialect(self):
-        """dialect makes use of await_() to invoke operations on the DBAPI."""
+        """dialect makes use of await_() to invoke operations on the
+        DBAPI."""
 
         return exclusions.closed()
 
index caaa657f935fa1faf9e31b97f8d37ca095d13878..eb17d005e27592f2236d3f8dff14668bf4c5dca6 100644 (file)
@@ -49,7 +49,6 @@ from ._collections import WeakPopulateDict as WeakPopulateDict
 from ._collections import WeakSequence as WeakSequence
 from .compat import anext_ as anext_
 from .compat import arm as arm
-from .compat import athrow as athrow
 from .compat import b as b
 from .compat import b64decode as b64decode
 from .compat import b64encode as b64encode
@@ -69,7 +68,7 @@ from .compat import py312 as py312
 from .compat import py39 as py39
 from .compat import pypy as pypy
 from .compat import win32 as win32
-from .concurrency import await_fallback as await_fallback
+from .concurrency import await_ as await_
 from .concurrency import await_only as await_only
 from .concurrency import greenlet_spawn as greenlet_spawn
 from .concurrency import is_exit_exception as is_exit_exception
index 1bc89970313e64c4f6d93ecbff322e6296c2ffc0..cd071c376232f683cbaae00867d8ac22848e9341 100644 (file)
@@ -20,8 +20,6 @@ import platform
 import sys
 import typing
 from typing import Any
-from typing import AsyncGenerator
-from typing import Awaitable
 from typing import Callable
 from typing import Dict
 from typing import Iterable
@@ -32,7 +30,6 @@ from typing import Sequence
 from typing import Set
 from typing import Tuple
 from typing import Type
-from typing import TypeVar
 
 py312 = sys.version_info >= (3, 12)
 py311 = sys.version_info >= (3, 11)
@@ -50,8 +47,6 @@ has_refcount_gc = bool(cpython)
 
 dottedgetter = operator.attrgetter
 
-_T_co = TypeVar("_T_co", covariant=True)
-
 
 class FullArgSpec(typing.NamedTuple):
     args: List[str]
@@ -101,24 +96,6 @@ def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec:
     )
 
 
-if py312:
-    # we are 95% certain this form of athrow works in former Python
-    # versions, however we are unable to get confirmation;
-    # see https://github.com/python/cpython/issues/105269 where have
-    # been unable to get a straight answer so far
-    def athrow(  # noqa
-        gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any
-    ) -> Awaitable[_T_co]:
-        return gen.athrow(value)
-
-else:
-
-    def athrow(  # noqa
-        gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any
-    ) -> Awaitable[_T_co]:
-        return gen.athrow(typ, value, traceback)
-
-
 if py39:
     # python stubs don't have a public type for this. not worth
     # making a protocol
index 9e4c6c85da790ca65d13828552eab2d8a18476be..bcdb928c29619ae14036a54b2f8564a856fb32fa 100644 (file)
@@ -9,21 +9,22 @@
 from __future__ import annotations
 
 import asyncio
-from contextvars import Context
 import sys
 from typing import Any
 from typing import Awaitable
 from typing import Callable
 from typing import Coroutine
 from typing import NoReturn
-from typing import Optional
-from typing import Protocol
 from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
+from .compat import py311
 from .langhelpers import memoized_property
+from .typing import Literal
+from .typing import Self
+from .typing import TypeGuard
 from .. import exc
-from ..util.typing import TypeGuard
 
 _T = TypeVar("_T")
 
@@ -44,25 +45,6 @@ _ERROR_MESSAGE = (
 )
 
 
-if TYPE_CHECKING:
-
-    class greenlet(Protocol):
-        dead: bool
-        gr_context: Optional[Context]
-
-        def __init__(self, fn: Callable[..., Any], driver: greenlet):
-            ...
-
-        def throw(self, *arg: Any) -> Any:
-            return None
-
-        def switch(self, value: Any) -> Any:
-            return None
-
-    def getcurrent() -> greenlet:
-        ...
-
-
 def _not_implemented(*arg: Any, **kw: Any) -> NoReturn:
     raise ImportError(_ERROR_MESSAGE)
 
@@ -71,10 +53,10 @@ class _concurrency_shim_cls:
     """Late import shim for greenlet"""
 
     __slots__ = (
+        "_has_greenlet",
         "greenlet",
         "_AsyncIoGreenlet",
         "getcurrent",
-        "_util_async_run",
     )
 
     def _initialize(self, *, raise_: bool = True) -> None:
@@ -84,7 +66,7 @@ class _concurrency_shim_cls:
 
         if not TYPE_CHECKING:
             global getcurrent, greenlet, _AsyncIoGreenlet
-            global _has_gr_context, _greenlet_error
+            global _has_gr_context
 
         try:
             from greenlet import getcurrent
@@ -93,73 +75,46 @@ class _concurrency_shim_cls:
             if not TYPE_CHECKING:
                 # set greenlet in the global scope to prevent re-init
                 greenlet = None
-
+            self._has_greenlet = False
             self._initialize_no_greenlet()
             if raise_:
                 raise ImportError(_ERROR_MESSAGE) from e
         else:
-            self._initialize_greenlet()
-
-    def _initialize_greenlet(self) -> None:
-        # If greenlet.gr_context is present in current version of greenlet,
-        # it will be set with the current context on creation.
-        # Refs: https://github.com/python-greenlet/greenlet/pull/198
-        _has_gr_context = hasattr(getcurrent(), "gr_context")
+            self._has_greenlet = True
+            # If greenlet.gr_context is present in current version of greenlet,
+            # it will be set with the current context on creation.
+            # Refs: https://github.com/python-greenlet/greenlet/pull/198
+            _has_gr_context = hasattr(getcurrent(), "gr_context")
 
-        # implementation based on snaury gist at
-        # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
-        # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 # noqa: E501
+            # implementation based on snaury gist at
+            # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
+            # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 # noqa: E501
 
-        class _AsyncIoGreenlet(greenlet):
-            dead: bool
+            class _AsyncIoGreenlet(greenlet):
+                dead: bool
 
-            def __init__(self, fn: Callable[..., Any], driver: greenlet):
-                greenlet.__init__(self, fn, driver)
-                self.driver = driver
-                if _has_gr_context:
-                    self.gr_context = driver.gr_context
+                def __init__(self, fn: Callable[..., Any], driver: greenlet):
+                    greenlet.__init__(self, fn, driver)
+                    self.driver = driver
+                    if _has_gr_context:
+                        self.gr_context = driver.gr_context
 
-        self.greenlet = greenlet
-        self.getcurrent = getcurrent
-        self._AsyncIoGreenlet = _AsyncIoGreenlet
-        self._util_async_run = self._greenlet_util_async_run
+            self.greenlet = greenlet
+            self.getcurrent = getcurrent
+            self._AsyncIoGreenlet = _AsyncIoGreenlet
 
     def _initialize_no_greenlet(self):
-        self._util_async_run = self._no_greenlet_util_async_run
         self.getcurrent = _not_implemented
-        self.greenlet = _not_implemented  # type: ignore
-        self._AsyncIoGreenlet = _not_implemented  # type: ignore
+        self.greenlet = _not_implemented  # type: ignore[assignment]
+        self._AsyncIoGreenlet = _not_implemented  # type: ignore[assignment]
 
     def __getattr__(self, key: str) -> Any:
         if key in self.__slots__:
-            self._initialize(raise_=not key.startswith("_util"))
+            self._initialize()
             return getattr(self, key)
         else:
             raise AttributeError(key)
 
-    def _greenlet_util_async_run(
-        self, fn: Callable[..., Any], *args: Any, **kwargs: Any
-    ) -> Any:
-        """for test suite/ util only"""
-
-        loop = get_event_loop()
-        if not loop.is_running():
-            return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
-        else:
-            # allow for a wrapped test function to call another
-            assert isinstance(
-                _concurrency_shim.getcurrent(),
-                _concurrency_shim._AsyncIoGreenlet,
-            )
-            return fn(*args, **kwargs)
-
-    def _no_greenlet_util_async_run(
-        self, fn: Callable[..., Any], *args: Any, **kwargs: Any
-    ) -> Any:
-        """for test suite/ util only"""
-
-        return fn(*args, **kwargs)
-
 
 _concurrency_shim = _concurrency_shim_cls()
 
@@ -187,11 +142,11 @@ def in_greenlet() -> bool:
     return isinstance(current, _concurrency_shim._AsyncIoGreenlet)
 
 
-def await_only(awaitable: Awaitable[_T]) -> _T:
+def await_(awaitable: Awaitable[_T]) -> _T:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_only` calls cannot be nested.
+    :func:`await_` calls cannot be nested.
 
     :param awaitable: The coroutine to call.
 
@@ -202,7 +157,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
         _safe_cancel_awaitable(awaitable)
 
         raise exc.MissingGreenlet(
-            "greenlet_spawn has not been called; can't call await_only() "
+            "greenlet_spawn has not been called; can't call await_() "
             "here. Was IO attempted in an unexpected place?"
         )
 
@@ -213,31 +168,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
     return current.driver.switch(awaitable)  # type: ignore[no-any-return]
 
 
-def await_fallback(awaitable: Awaitable[_T]) -> _T:
-    """Awaits an async function in a sync method.
-
-    The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_fallback` calls cannot be nested.
-
-    :param awaitable: The coroutine to call.
-
-    """
-
-    # this is called in the context greenlet while running fn
-    current = _concurrency_shim.getcurrent()
-    if not isinstance(current, _concurrency_shim._AsyncIoGreenlet):
-        loop = get_event_loop()
-        if loop.is_running():
-            _safe_cancel_awaitable(awaitable)
-
-            raise exc.MissingGreenlet(
-                "greenlet_spawn has not been called and asyncio event "
-                "loop is already running; can't call await_fallback() here. "
-                "Was IO attempted in an unexpected place?"
-            )
-        return loop.run_until_complete(awaitable)
-
-    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
+await_only = await_  # old name. deprecated on 2.2
 
 
 async def greenlet_spawn(
@@ -248,7 +179,7 @@ async def greenlet_spawn(
 ) -> _T:
     """Runs a sync function ``fn`` in a new greenlet.
 
-    The sync function can then use :func:`await_only` to wait for async
+    The sync function can then use :func:`await_` to wait for async
     functions.
 
     :param fn: The sync callable to call.
@@ -261,7 +192,7 @@ async def greenlet_spawn(
         fn, _concurrency_shim.getcurrent()
     )
     # runs the function synchronously in gl greenlet. If the execution
-    # is interrupted by await_only, context is not dead and result is a
+    # is interrupted by await_, context is not dead and result is a
     # coroutine to wait. If the context is dead the function has
     # returned, and its result can be returned.
     switch_occurred = False
@@ -270,7 +201,7 @@ async def greenlet_spawn(
         while not context.dead:
             switch_occurred = True
             try:
-                # wait for a coroutine from await_only and then return its
+                # wait for a coroutine from await_ and then return its
                 # result back to it.
                 value = await result
             except BaseException:
@@ -302,42 +233,92 @@ class AsyncAdaptedLock:
     def __enter__(self) -> bool:
         # await is used to acquire the lock only after the first calling
         # coroutine has created the mutex.
-        return await_fallback(self.mutex.acquire())
+        return await_(self.mutex.acquire())
 
     def __exit__(self, *arg: Any, **kw: Any) -> None:
         self.mutex.release()
 
 
-def _util_async_run_coroutine_function(
-    fn: Callable[..., Any], *args: Any, **kwargs: Any
-) -> Any:
-    """for test suite/ util only"""
-
-    loop = get_event_loop()
-    if loop.is_running():
-        raise Exception(
-            "for async run coroutine we expect that no greenlet or event "
-            "loop is running when we start out"
-        )
-    return loop.run_until_complete(fn(*args, **kwargs))
-
-
-def _util_async_run(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
-    """for test suite/ util only"""
-
-    _util_async_run = _concurrency_shim._util_async_run
-    return _util_async_run(fn, *args, **kwargs)
-
-
-def get_event_loop() -> asyncio.AbstractEventLoop:
-    """vendor asyncio.get_event_loop() for python 3.7 and above.
+if py311:
+    _Runner = asyncio.Runner
+else:
 
-    Python 3.10 deprecates get_event_loop() as a standalone.
+    class _Runner:  # type: ignore[no-redef]
+        """Runner implementation for test only"""
+
+        _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]
+
+        def __init__(self) -> None:
+            self._loop = None
+
+        def __enter__(self) -> Self:
+            self._lazy_init()
+            return self
+
+        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+            self.close()
+
+        def close(self) -> None:
+            if self._loop:
+                try:
+                    self._loop.run_until_complete(
+                        self._loop.shutdown_asyncgens()
+                    )
+                finally:
+                    self._loop.close()
+                    self._loop = False
+
+        def get_loop(self) -> asyncio.AbstractEventLoop:
+            """Return embedded event loop."""
+            self._lazy_init()
+            assert self._loop
+            return self._loop
+
+        def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
+            self._lazy_init()
+            assert self._loop
+            return self._loop.run_until_complete(coro)
+
+        def _lazy_init(self) -> None:
+            if self._loop is False:
+                raise RuntimeError("Runner is closed")
+            if self._loop is None:
+                self._loop = asyncio.new_event_loop()
+
+
+class _AsyncUtil:
+    """Asyncio util for test suite/ util only"""
+
+    def __init__(self) -> None:
+        self.runner = _Runner()  # runner it lazy so it can be created here
+
+    def run(
+        self,
+        fn: Callable[..., Coroutine[Any, Any, _T]],
+        *args: Any,
+        **kwargs: Any,
+    ) -> _T:
+        """Run coroutine on the loop"""
+        return self.runner.run(fn(*args, **kwargs))
+
+    def run_in_greenlet(
+        self, fn: Callable[..., _T], *args: Any, **kwargs: Any
+    ) -> _T:
+        """Run sync function in greenlet. Support nested calls"""
+        _concurrency_shim._initialize(raise_=False)
+
+        if _concurrency_shim._has_greenlet:
+            if self.runner.get_loop().is_running():
+                # allow for a wrapped test function to call another
+                assert isinstance(
+                    _concurrency_shim.getcurrent(),
+                    _concurrency_shim._AsyncIoGreenlet,
+                )
+                return fn(*args, **kwargs)
+            else:
+                return self.runner.run(greenlet_spawn(fn, *args, **kwargs))
+        else:
+            return fn(*args, **kwargs)
 
-    """
-    try:
-        return asyncio.get_running_loop()
-    except RuntimeError:
-        # avoid "During handling of the above exception, another exception..."
-        pass
-    return asyncio.get_event_loop_policy().get_event_loop()
+    def close(self) -> None:
+        self.runner.close()
index b641c910c7177bfd45c17948793f7a247aa66fe3..a631fa67ea0d224b96794ee03cbc2957d5f3da3a 100644 (file)
@@ -24,16 +24,13 @@ import asyncio
 from collections import deque
 import threading
 from time import time as _time
-import typing
 from typing import Any
-from typing import Awaitable
 from typing import Deque
 from typing import Generic
 from typing import Optional
 from typing import TypeVar
 
-from .concurrency import await_fallback
-from .concurrency import await_only
+from .concurrency import await_
 from .langhelpers import memoized_property
 
 
@@ -239,15 +236,6 @@ class Queue(QueueCommon[_T]):
 
 
 class AsyncAdaptedQueue(QueueCommon[_T]):
-    if typing.TYPE_CHECKING:
-
-        @staticmethod
-        def await_(coroutine: Awaitable[Any]) -> _T:
-            ...
-
-    else:
-        await_ = staticmethod(await_only)
-
     def __init__(self, maxsize: int = 0, use_lifo: bool = False):
         self.use_lifo = use_lifo
         self.maxsize = maxsize
@@ -292,9 +280,9 @@ class AsyncAdaptedQueue(QueueCommon[_T]):
 
         try:
             if timeout is not None:
-                self.await_(asyncio.wait_for(self._queue.put(item), timeout))
+                await_(asyncio.wait_for(self._queue.put(item), timeout))
             else:
-                self.await_(self._queue.put(item))
+                await_(self._queue.put(item))
         except (asyncio.QueueFull, asyncio.TimeoutError) as err:
             raise Full() from err
 
@@ -310,15 +298,8 @@ class AsyncAdaptedQueue(QueueCommon[_T]):
 
         try:
             if timeout is not None:
-                return self.await_(
-                    asyncio.wait_for(self._queue.get(), timeout)
-                )
+                return await_(asyncio.wait_for(self._queue.get(), timeout))
             else:
-                return self.await_(self._queue.get())
+                return await_(self._queue.get())
         except (asyncio.QueueEmpty, asyncio.TimeoutError) as err:
             raise Empty() from err
-
-
-class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue[_T]):
-    if not typing.TYPE_CHECKING:
-        await_ = staticmethod(await_fallback)
index 890aea977a40d2c3aa3528f932b94e49c2a5de5f..129a5aa82d9f28012a446ab9547409d5fe416fa5 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -64,8 +64,8 @@ postgresql_asyncpg =
     asyncpg
 postgresql_psycopg2binary = psycopg2-binary
 postgresql_psycopg2cffi = psycopg2cffi
-postgresql_psycopg = psycopg>=3.0.7
-postgresql_psycopgbinary = psycopg[binary]>=3.0.7
+postgresql_psycopg = psycopg>=3.0.7,!=3.1.15
+postgresql_psycopgbinary = psycopg[binary]>=3.0.7,!=3.1.15
 pymysql =
     pymysql
 aiomysql =
@@ -162,17 +162,13 @@ postgresql = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
 psycopg2 = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
 psycopg = postgresql+psycopg://scott:tiger@127.0.0.1:5432/test
 psycopg_async = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test
-psycopg_async_fallback = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test?async_fallback=true
 asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test
-asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
 pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
 postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test
 mysql = mysql+mysqldb://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
 pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
 aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-aiomysql_fallback = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
 asyncmy = mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-asyncmy_fallback = mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
 mariadb = mariadb+mysqldb://scott:tiger@127.0.0.1:3306/test
 mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test
 mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes
index 587eb644d1e9622baf6d63bc694ff0b2bd3e2147..0a7f00c089676577131a071aed1b94d61548969b 100644 (file)
@@ -13,7 +13,7 @@ def greenlet_not_imported():
     import sqlalchemy
     import sqlalchemy.util.concurrency  # noqa: F401
     from sqlalchemy.util import greenlet_spawn  # noqa: F401
-    from sqlalchemy.util.concurrency import await_only  # noqa: F401
+    from sqlalchemy.util.concurrency import await_  # noqa: F401
 
     assert "greenlet" not in sys.modules
 
index 1ea61ba7cecf905551ac231f92ad136d2823ae1b..274bcfe7c1b55b08b5a77964e2de78a46d92eb5a 100644 (file)
@@ -13,8 +13,7 @@ from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing.config import combinations
-from sqlalchemy.util import await_fallback
-from sqlalchemy.util import await_only
+from sqlalchemy.util import await_
 from sqlalchemy.util import greenlet_spawn
 from sqlalchemy.util import queue
 from ._concurrency_fixtures import greenlet_not_imported
@@ -36,7 +35,7 @@ async def run2():
 
 
 def go(*fns):
-    return sum(await_only(fn()) for fn in fns)
+    return sum(await_(fn()) for fn in fns)
 
 
 class TestAsyncioCompat(fixtures.TestBase):
@@ -64,7 +63,7 @@ class TestAsyncioCompat(fixtures.TestBase):
 
         def sync_meth():
             try:
-                await_only(async_meth_raise())
+                await_(async_meth_raise())
             except:
                 cleanup.append(True)
                 raise
@@ -80,56 +79,29 @@ class TestAsyncioCompat(fixtures.TestBase):
     @async_test
     async def test_sync_error(self):
         def go():
-            await_only(run1())
+            await_(run1())
             raise ValueError("sync error")
 
         with expect_raises_message(ValueError, "sync error"):
             await greenlet_spawn(go)
 
-    def test_await_fallback_no_greenlet(self):
-        to_await = run1()
-        await_fallback(to_await)
-
     @async_test
     async def test_await_only_no_greenlet(self):
         to_await = run1()
         with expect_raises_message(
             exc.MissingGreenlet,
             "greenlet_spawn has not been called; "
-            r"can't call await_only\(\) here.",
+            r"can't call await_\(\) here.",
         ):
-            await_only(to_await)
+            await_(to_await)
 
         # existing awaitable is done
         with expect_raises(RuntimeError):
-            await greenlet_spawn(await_fallback, to_await)
+            await greenlet_spawn(await_, to_await)
 
         # no warning for a new one...
         to_await = run1()
-        await greenlet_spawn(await_fallback, to_await)
-
-    @async_test
-    async def test_await_fallback_error(self):
-        to_await = run1()
-
-        await to_await
-
-        async def inner_await():
-            nonlocal to_await
-            to_await = run1()
-            await_fallback(to_await)
-
-        def go():
-            await_fallback(inner_await())
-
-        with expect_raises_message(
-            exc.MissingGreenlet,
-            "greenlet_spawn has not been called and asyncio event loop",
-        ):
-            await greenlet_spawn(go)
-
-        with expect_raises(RuntimeError):
-            await to_await
+        await greenlet_spawn(await_, to_await)
 
     @async_test
     async def test_await_only_error(self):
@@ -140,15 +112,15 @@ class TestAsyncioCompat(fixtures.TestBase):
         async def inner_await():
             nonlocal to_await
             to_await = run1()
-            await_only(to_await)
+            await_(to_await)
 
         def go():
-            await_only(inner_await())
+            await_(inner_await())
 
         with expect_raises_message(
             exc.InvalidRequestError,
             "greenlet_spawn has not been called; "
-            r"can't call await_only\(\) here.",
+            r"can't call await_\(\) here.",
         ):
             await greenlet_spawn(go)
 
@@ -172,22 +144,22 @@ class TestAsyncioCompat(fixtures.TestBase):
             var.set(val)
 
         def inner(val):
-            retval = await_only(async_inner(val))
+            retval = await_(async_inner(val))
             eq_(val, var.get())
             eq_(retval, val)
 
             # set the value in a sync function
             newval = val + concurrency
             var.set(newval)
-            syncset = await_only(async_inner(newval))
+            syncset = await_(async_inner(newval))
             eq_(newval, var.get())
             eq_(syncset, newval)
 
             # set the value in an async function
             retval = val + 2 * concurrency
-            await_only(async_set(retval))
+            await_(async_set(retval))
             eq_(var.get(), retval)
-            eq_(await_only(async_inner(retval)), retval)
+            eq_(await_(async_inner(retval)), retval)
 
             return retval
 
@@ -304,4 +276,4 @@ class GracefulNoGreenletTest(fixtures.TestBase):
             "The SQLAlchemy asyncio module requires that the Python "
             "'greenlet' library is installed",
         ):
-            await_only(async_fn())
+            await_(async_fn())
index 49736df9b65ddd7638b6f376637bca96ff9bab8f..33abfed4d274cc0f353d191e970bc51c9bcdc03c 100644 (file)
@@ -286,7 +286,6 @@ class PoolTest(PoolTestBase):
     @testing.combinations(
         (pool.QueuePool, False),
         (pool.AsyncAdaptedQueuePool, True),
-        (pool.FallbackAsyncAdaptedQueuePool, True),
         (pool.NullPool, None),
         (pool.SingletonThreadPool, False),
         (pool.StaticPool, None),
@@ -307,7 +306,6 @@ class PoolTest(PoolTestBase):
     @testing.combinations(
         (pool.QueuePool, False),
         (pool.AsyncAdaptedQueuePool, True),
-        (pool.FallbackAsyncAdaptedQueuePool, True),
         (pool.NullPool, False),
         (pool.SingletonThreadPool, False),
         (pool.StaticPool, False),
index 7289d5494ebcaa172254fad8cc05915798c5ab0b..adb6b0b6c9d08b31cd9c44c82aaf9a05b9779396 100644 (file)
@@ -351,9 +351,9 @@ class AsyncEngineTest(EngineFixture):
             pool_connection = await conn.get_raw_connection()
             return pool_connection
 
-        from sqlalchemy.util.concurrency import await_only
+        from sqlalchemy.util.concurrency import await_
 
-        pool_connection = await_only(go())
+        pool_connection = await_(go())
 
         rec = pool_connection._connection_record
         ref = rec.fairy_ref
index 4a0b365c2b546bd5bf5d699a3ef9de278dd092a7..1626c825f24020dc9ba2d360f3ca6b41ec31f971 100644 (file)
@@ -1527,7 +1527,8 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def async_dialect(self):
-        """dialect makes use of await_() to invoke operations on the DBAPI."""
+        """dialect makes use of await_() to invoke operations on
+        the DBAPI."""
 
         return self.asyncio + only_on(
             LambdaPredicate(