From: Federico Caselli
Date: Tue, 5 Dec 2023 21:29:19 +0000 (+0100)
Subject: Remove async_fallback mode
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3976537274e8e3d798c8c88bf570c49e9fd7ef6d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

Remove async_fallback mode

Removed the ``async_fallback`` mode for async dialects and the internal
function ``await_fallback()``. Renamed the internal function
``await_only()`` to ``await_()``. Replaced ``get_event_loop()`` with
``asyncio.Runner``.

Change-Id: Ib43829be6ebdb59b6c4447f5a15b5d2b81403fa9
---
diff --git a/README.unittests.rst b/README.unittests.rst
index d7155c1ac2..046a30f6a9 100644
--- a/README.unittests.rst
+++ b/README.unittests.rst
@@ -83,13 +83,10 @@ a pre-set URL. These can be seen using --dbs::
     $ pytest --dbs
     Available --db options (use --dburi to override)
     aiomysql mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-    aiomysql_fallback mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
     aiosqlite sqlite+aiosqlite:///:memory:
     aiosqlite_file sqlite+aiosqlite:///async_querytest.db
     asyncmy mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-    asyncmy_fallback mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
     asyncpg postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test
-    asyncpg_fallback postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
     default sqlite:///:memory:
     docker_mssql mssql+pymssql://scott:tiger^5HHH@127.0.0.1:1433/test
     mariadb mariadb+mysqldb://scott:tiger@127.0.0.1:3306/test
@@ -105,7 +102,6 @@ a pre-set URL. These can be seen using --dbs::
     psycopg postgresql+psycopg://scott:tiger@127.0.0.1:5432/test
     psycopg2 postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
     psycopg_async postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test
-    psycopg_async_fallback postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test?async_fallback=true
     pymysql mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
     pysqlcipher_file sqlite+pysqlcipher://:test@/querytest.db.enc
     sqlite sqlite:///:memory:
diff --git a/doc/build/changelog/unreleased_21/async_fallback.rst b/doc/build/changelog/unreleased_21/async_fallback.rst
new file mode 100644
index 0000000000..44b91d2156
--- /dev/null
+++ b/doc/build/changelog/unreleased_21/async_fallback.rst
@@ -0,0 +1,8 @@
+.. change::
+    :tags: change, asyncio
+
+    Removed the compatibility ``async_fallback`` mode for async dialects,
+    since it's no longer used by SQLAlchemy tests.
+    Also removed the internal function ``await_fallback()`` and renamed
+    the internal function ``await_only()`` to ``await_()``.
+    No change is expected to user code.
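
The renamed ``await_()`` keeps the contract ``await_only()`` had: it may only be
called from inside a ``greenlet_spawn()`` context, and calls cannot be nested.
A minimal sketch of that pattern, using the two internal helpers touched by this
change (illustrative only; both remain internal APIs)::

    import asyncio

    from sqlalchemy.util.concurrency import await_, greenlet_spawn

    async def fetch_value() -> int:
        # stand-in for real awaitable driver I/O
        await asyncio.sleep(0)
        return 42

    def sync_facade() -> int:
        # await_() hands the awaitable back to the driving greenlet; calling it
        # outside a greenlet_spawn() context raises MissingGreenlet
        return await_(fetch_value())

    async def main() -> int:
        # greenlet_spawn() runs the sync function in a greenlet and awaits
        # whatever its await_() calls send back
        return await greenlet_spawn(sync_facade)

    print(asyncio.run(main()))  # 42
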
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 2300c2d409..af030614a5 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -47,9 +47,6 @@ from .engine import URL as URL from .inspection import inspect as inspect from .pool import AssertionPool as AssertionPool from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool -from .pool import ( - FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, -) from .pool import NullPool as NullPool from .pool import Pool as Pool from .pool import PoolProxiedConnection as PoolProxiedConnection diff --git a/lib/sqlalchemy/connectors/aioodbc.py b/lib/sqlalchemy/connectors/aioodbc.py index e0f5f55474..927330b286 100644 --- a/lib/sqlalchemy/connectors/aioodbc.py +++ b/lib/sqlalchemy/connectors/aioodbc.py @@ -13,12 +13,8 @@ from typing import TYPE_CHECKING from .asyncio import AsyncAdapt_dbapi_connection from .asyncio import AsyncAdapt_dbapi_cursor from .asyncio import AsyncAdapt_dbapi_ss_cursor -from .asyncio import AsyncAdaptFallback_dbapi_connection from .pyodbc import PyODBCConnector -from .. import pool -from .. import util -from ..util.concurrency import await_fallback -from ..util.concurrency import await_only +from ..util.concurrency import await_ if TYPE_CHECKING: from ..engine.interfaces import ConnectArgsType @@ -33,7 +29,7 @@ class AsyncAdapt_aioodbc_cursor(AsyncAdapt_dbapi_cursor): return self._cursor._impl.setinputsizes(*inputsizes) # how it's supposed to work - # return self.await_(self._cursor.setinputsizes(*inputsizes)) + # return await_(self._cursor.setinputsizes(*inputsizes)) class AsyncAdapt_aioodbc_ss_cursor( @@ -59,7 +55,7 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection): self._connection._conn.autocommit = value def ping(self, reconnect): - return self.await_(self._connection.ping(reconnect)) + return await_(self._connection.ping(reconnect)) def add_output_converter(self, *arg, **kw): self._connection.add_output_converter(*arg, **kw) @@ -96,12 +92,6 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection): super().close() -class AsyncAdaptFallback_aioodbc_connection( - AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aioodbc_connection -): - __slots__ = () - - class AsyncAdapt_aioodbc_dbapi: def __init__(self, aioodbc, pyodbc): self.aioodbc = aioodbc @@ -136,19 +126,12 @@ class AsyncAdapt_aioodbc_dbapi: setattr(self, name, getattr(self.pyodbc, name)) def connect(self, *arg, **kw): - async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.aioodbc.connect) - if util.asbool(async_fallback): - return AsyncAdaptFallback_aioodbc_connection( - self, - await_fallback(creator_fn(*arg, **kw)), - ) - else: - return AsyncAdapt_aioodbc_connection( - self, - await_only(creator_fn(*arg, **kw)), - ) + return AsyncAdapt_aioodbc_connection( + self, + await_(creator_fn(*arg, **kw)), + ) class aiodbcConnector(PyODBCConnector): @@ -170,15 +153,6 @@ class aiodbcConnector(PyODBCConnector): return (), kw - @classmethod - def get_pool_class(cls, url): - async_fallback = url.query.get("async_fallback", False) - - if util.asbool(async_fallback): - return pool.FallbackAsyncAdaptedQueuePool - else: - return pool.AsyncAdaptedQueuePool - def _do_isolation_level(self, connection, autocommit, isolation_level): connection.set_autocommit(autocommit) connection.set_isolation_level(isolation_level) diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index 9358457ceb..f17831068c 100644 --- 
a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -25,8 +25,7 @@ from ..engine import AdaptedConnection from ..engine.interfaces import _DBAPICursorDescription from ..engine.interfaces import _DBAPIMultiExecuteParams from ..engine.interfaces import _DBAPISingleExecuteParams -from ..util.concurrency import await_fallback -from ..util.concurrency import await_only +from ..util.concurrency import await_ from ..util.typing import Self @@ -121,7 +120,6 @@ class AsyncAdapt_dbapi_cursor: __slots__ = ( "_adapt_connection", "_connection", - "await_", "_cursor", "_rows", ) @@ -134,12 +132,11 @@ class AsyncAdapt_dbapi_cursor: def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ cursor = self._make_new_cursor(self._connection) try: - self._cursor = self.await_(cursor.__aenter__()) + self._cursor = await_(cursor.__aenter__()) except Exception as error: self._adapt_connection._handle_exception(error) @@ -181,7 +178,7 @@ class AsyncAdapt_dbapi_cursor: parameters: Optional[_DBAPISingleExecuteParams] = None, ) -> Any: try: - return self.await_(self._execute_async(operation, parameters)) + return await_(self._execute_async(operation, parameters)) except Exception as error: self._adapt_connection._handle_exception(error) @@ -191,7 +188,7 @@ class AsyncAdapt_dbapi_cursor: seq_of_parameters: _DBAPIMultiExecuteParams, ) -> Any: try: - return self.await_( + return await_( self._executemany_async(operation, seq_of_parameters) ) except Exception as error: @@ -223,18 +220,16 @@ class AsyncAdapt_dbapi_cursor: return await self._cursor.executemany(operation, seq_of_parameters) def nextset(self) -> None: - self.await_(self._cursor.nextset()) + await_(self._cursor.nextset()) if self._cursor.description and not self.server_side: - self._rows = collections.deque( - self.await_(self._cursor.fetchall()) - ) + self._rows = collections.deque(await_(self._cursor.fetchall())) def setinputsizes(self, *inputsizes: Any) -> None: # NOTE: this is overrridden in aioodbc due to # see https://github.com/aio-libs/aioodbc/issues/451 # right now - return self.await_(self._cursor.setinputsizes(*inputsizes)) + return await_(self._cursor.setinputsizes(*inputsizes)) def __enter__(self) -> Self: return self @@ -273,24 +268,23 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor): def close(self) -> None: if self._cursor is not None: - self.await_(self._cursor.close()) + await_(self._cursor.close()) self._cursor = None # type: ignore def fetchone(self) -> Optional[Any]: - return self.await_(self._cursor.fetchone()) + return await_(self._cursor.fetchone()) def fetchmany(self, size: Optional[int] = None) -> Any: - return self.await_(self._cursor.fetchmany(size=size)) + return await_(self._cursor.fetchmany(size=size)) def fetchall(self) -> Sequence[Any]: - return self.await_(self._cursor.fetchall()) + return await_(self._cursor.fetchall()) class AsyncAdapt_dbapi_connection(AdaptedConnection): _cursor_cls = AsyncAdapt_dbapi_cursor _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor - await_ = staticmethod(await_only) __slots__ = ("dbapi", "_execute_mutex") _connection: AsyncIODBAPIConnection @@ -323,21 +317,15 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection): def rollback(self) -> None: try: - self.await_(self._connection.rollback()) + await_(self._connection.rollback()) except Exception as error: self._handle_exception(error) def commit(self) -> None: try: - 
self.await_(self._connection.commit()) + await_(self._connection.commit()) except Exception as error: self._handle_exception(error) def close(self) -> None: - self.await_(self._connection.close()) - - -class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection): - __slots__ = () - - await_ = staticmethod(await_fallback) + await_(self._connection.close()) diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 978950b878..f92b1bfaa6 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -28,14 +28,10 @@ This dialect should normally be used only with the """ # noqa from .pymysql import MySQLDialect_pymysql -from ... import pool -from ... import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor -from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection -from ...util.concurrency import await_fallback -from ...util.concurrency import await_only +from ...util.concurrency import await_ class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor): @@ -64,25 +60,19 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection): def ping(self, reconnect): assert not reconnect - return self.await_(self._connection.ping(reconnect)) + return await_(self._connection.ping(reconnect)) def character_set_name(self): return self._connection.character_set_name() def autocommit(self, value): - self.await_(self._connection.autocommit(value)) + await_(self._connection.autocommit(value)) def close(self): # it's not awaitable. self._connection.close() -class AsyncAdaptFallback_aiomysql_connection( - AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiomysql_connection -): - __slots__ = () - - class AsyncAdapt_aiomysql_dbapi: def __init__(self, aiomysql, pymysql): self.aiomysql = aiomysql @@ -118,19 +108,12 @@ class AsyncAdapt_aiomysql_dbapi: setattr(self, name, getattr(self.pymysql, name)) def connect(self, *arg, **kw): - async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect) - if util.asbool(async_fallback): - return AsyncAdaptFallback_aiomysql_connection( - self, - await_fallback(creator_fn(*arg, **kw)), - ) - else: - return AsyncAdapt_aiomysql_connection( - self, - await_only(creator_fn(*arg, **kw)), - ) + return AsyncAdapt_aiomysql_connection( + self, + await_(creator_fn(*arg, **kw)), + ) def _init_cursors_subclasses(self): # suppress unconditional warning emitted by aiomysql @@ -160,15 +143,6 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): __import__("aiomysql"), __import__("pymysql") ) - @classmethod - def get_pool_class(cls, url): - async_fallback = url.query.get("async_fallback", False) - - if util.asbool(async_fallback): - return pool.FallbackAsyncAdaptedQueuePool - else: - return pool.AsyncAdaptedQueuePool - def create_connect_args(self, url): return super().create_connect_args( url, _translate_args=dict(username="user", database="db") diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 3029626fd5..7f2a9979e6 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -28,14 +28,11 @@ This dialect should normally be used only with the from __future__ import annotations from .pymysql import MySQLDialect_pymysql -from ... import pool from ... 
import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor -from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection -from ...util.concurrency import await_fallback -from ...util.concurrency import await_only +from ...util.concurrency import await_ class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor): @@ -69,7 +66,7 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection): def ping(self, reconnect): assert not reconnect - return self.await_(self._do_ping()) + return await_(self._do_ping()) async def _do_ping(self): try: @@ -82,19 +79,13 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection): return self._connection.character_set_name() def autocommit(self, value): - self.await_(self._connection.autocommit(value)) + await_(self._connection.autocommit(value)) def close(self): # it's not awaitable. self._connection.close() -class AsyncAdaptFallback_asyncmy_connection( - AsyncAdaptFallback_dbapi_connection, AsyncAdapt_asyncmy_connection -): - __slots__ = () - - def _Binary(x): """Return x as a binary type.""" return bytes(x) @@ -130,19 +121,12 @@ class AsyncAdapt_asyncmy_dbapi: Binary = staticmethod(_Binary) def connect(self, *arg, **kw): - async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect) - if util.asbool(async_fallback): - return AsyncAdaptFallback_asyncmy_connection( - self, - await_fallback(creator_fn(*arg, **kw)), - ) - else: - return AsyncAdapt_asyncmy_connection( - self, - await_only(creator_fn(*arg, **kw)), - ) + return AsyncAdapt_asyncmy_connection( + self, + await_(creator_fn(*arg, **kw)), + ) class MySQLDialect_asyncmy(MySQLDialect_pymysql): @@ -158,15 +142,6 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): def import_dbapi(cls): return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) - @classmethod - def get_pool_class(cls, url): - async_fallback = url.query.get("async_fallback", False) - - if util.asbool(async_fallback): - return pool.FallbackAsyncAdaptedQueuePool - else: - return pool.AsyncAdaptedQueuePool - def create_connect_args(self, url): return super().create_connect_args( url, _translate_args=dict(username="user", database="db") diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 2ce68acce6..d138c1819a 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -25,17 +25,6 @@ This dialect should normally be used only with the from sqlalchemy.ext.asyncio import create_async_engine engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") -The dialect can also be run as a "synchronous" dialect within the -:func:`_sa.create_engine` function, which will pass "await" calls into -an ad-hoc event loop. This mode of operation is of **limited use** -and is for special testing scenarios only. The mode can be enabled by -adding the SQLAlchemy-specific flag ``async_fallback`` to the URL -in conjunction with :func:`_sa.create_engine`:: - - # for testing purposes only; do not use in production! - engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true") - - .. versionadded:: 1.4 .. note:: @@ -217,15 +206,13 @@ from .types import BIT from .types import BYTEA from .types import CITEXT from ... import exc -from ... import pool from ... 
import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...engine import processors from ...sql import sqltypes -from ...util.concurrency import await_fallback -from ...util.concurrency import await_only +from ...util.concurrency import await_ if TYPE_CHECKING: from ...engine.interfaces import _DBAPICursorDescription @@ -556,7 +543,6 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): def __init__(self, adapt_connection: AsyncAdapt_asyncpg_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ self._cursor = None self._rows = collections.deque() self._description = None @@ -654,14 +640,10 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): self._handle_exception(error) def execute(self, operation, parameters=None): - self._adapt_connection.await_( - self._prepare_and_execute(operation, parameters) - ) + await_(self._prepare_and_execute(operation, parameters)) def executemany(self, operation, seq_of_parameters): - return self._adapt_connection.await_( - self._executemany(operation, seq_of_parameters) - ) + return await_(self._executemany(operation, seq_of_parameters)) def setinputsizes(self, *inputsizes): raise NotImplementedError() @@ -683,7 +665,7 @@ class AsyncAdapt_asyncpg_ss_cursor( def _buffer_rows(self): assert self._cursor is not None - new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) + new_rows = await_(self._cursor.fetch(50)) self._rowbuffer = collections.deque(new_rows) def __aiter__(self): @@ -721,9 +703,7 @@ class AsyncAdapt_asyncpg_ss_cursor( buf = list(self._rowbuffer) lb = len(buf) if size > lb: - buf.extend( - self._adapt_connection.await_(self._cursor.fetch(size - lb)) - ) + buf.extend(await_(self._cursor.fetch(size - lb))) result = buf[0:size] self._rowbuffer = collections.deque(buf[size:]) @@ -732,9 +712,7 @@ class AsyncAdapt_asyncpg_ss_cursor( def fetchall(self): assert self._rowbuffer is not None - ret = list(self._rowbuffer) + list( - self._adapt_connection.await_(self._all()) - ) + ret = list(self._rowbuffer) + list(await_(self._all())) self._rowbuffer.clear() return ret @@ -876,7 +854,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): def ping(self): try: - _ = self.await_(self._async_ping()) + _ = await_(self._async_ping()) except Exception as error: self._handle_exception(error) @@ -918,7 +896,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): if self._started: assert self._transaction is not None try: - self.await_(self._transaction.rollback()) + await_(self._transaction.rollback()) except Exception as error: self._handle_exception(error) finally: @@ -929,7 +907,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): if self._started: assert self._transaction is not None try: - self.await_(self._transaction.commit()) + await_(self._transaction.commit()) except Exception as error: self._handle_exception(error) finally: @@ -939,7 +917,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): def close(self): self.rollback() - self.await_(self._connection.close()) + await_(self._connection.close()) def terminate(self): if util.concurrency.in_greenlet(): @@ -948,7 +926,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): try: # try to gracefully close; see #10717 # timeout added in asyncpg 0.14.0 December 2017 - 
self.await_(self._connection.close(timeout=2)) + await_(self._connection.close(timeout=2)) except asyncio.TimeoutError: # in the case where we are recycling an old connection # that may have already been disconnected, close() will @@ -966,19 +944,12 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): return None -class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): - __slots__ = () - - await_ = staticmethod(await_fallback) - - class AsyncAdapt_asyncpg_dbapi: def __init__(self, asyncpg): self.asyncpg = asyncpg self.paramstyle = "numeric_dollar" def connect(self, *arg, **kw): - async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect) prepared_statement_cache_size = kw.pop( "prepared_statement_cache_size", 100 @@ -987,20 +958,12 @@ class AsyncAdapt_asyncpg_dbapi: "prepared_statement_name_func", None ) - if util.asbool(async_fallback): - return AsyncAdaptFallback_asyncpg_connection( - self, - await_fallback(creator_fn(*arg, **kw)), - prepared_statement_cache_size=prepared_statement_cache_size, - prepared_statement_name_func=prepared_statement_name_func, - ) - else: - return AsyncAdapt_asyncpg_connection( - self, - await_only(creator_fn(*arg, **kw)), - prepared_statement_cache_size=prepared_statement_cache_size, - prepared_statement_name_func=prepared_statement_name_func, - ) + return AsyncAdapt_asyncpg_connection( + self, + await_(creator_fn(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, + prepared_statement_name_func=prepared_statement_name_func, + ) class Error(Exception): pass @@ -1201,15 +1164,6 @@ class PGDialect_asyncpg(PGDialect): dbapi_connection.ping() return True - @classmethod - def get_pool_class(cls, url): - async_fallback = url.query.get("async_fallback", False) - - if util.asbool(async_fallback): - return pool.FallbackAsyncAdaptedQueuePool - else: - return pool.AsyncAdaptedQueuePool - def is_disconnect(self, e, connection, cursor): if connection: return connection._connection.is_closed() @@ -1308,11 +1262,11 @@ class PGDialect_asyncpg(PGDialect): super_connect = super().on_connect() def connect(conn): - conn.await_(self.setup_asyncpg_json_codec(conn)) - conn.await_(self.setup_asyncpg_jsonb_codec(conn)) + await_(self.setup_asyncpg_json_codec(conn)) + await_(self.setup_asyncpg_jsonb_codec(conn)) if self._native_inet_types is False: - conn.await_(self._disable_asyncpg_inet_codecs(conn)) + await_(self._disable_asyncpg_inet_codecs(conn)) if super_connect is not None: super_connect(conn) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 743d038880..690cadb6b3 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -70,15 +70,12 @@ from .json import JSON from .json import JSONB from .json import JSONPathType from .types import CITEXT -from ... import pool from ... 
import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor -from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection from ...sql import sqltypes -from ...util.concurrency import await_fallback -from ...util.concurrency import await_only +from ...util.concurrency import await_ if TYPE_CHECKING: from typing import Iterable @@ -585,7 +582,7 @@ class AsyncAdapt_psycopg_ss_cursor( iterator = self._cursor.__aiter__() while True: try: - yield self.await_(iterator.__anext__()) + yield await_(iterator.__anext__()) except StopAsyncIteration: break @@ -632,16 +629,16 @@ class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection): self.set_autocommit(value) def set_autocommit(self, value): - self.await_(self._connection.set_autocommit(value)) + await_(self._connection.set_autocommit(value)) def set_isolation_level(self, value): - self.await_(self._connection.set_isolation_level(value)) + await_(self._connection.set_isolation_level(value)) def set_read_only(self, value): - self.await_(self._connection.set_read_only(value)) + await_(self._connection.set_read_only(value)) def set_deferrable(self, value): - self.await_(self._connection.set_deferrable(value)) + await_(self._connection.set_deferrable(value)) def cursor(self, name=None, /): if name: @@ -650,12 +647,6 @@ class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection): return AsyncAdapt_psycopg_cursor(self) -class AsyncAdaptFallback_psycopg_connection( - AsyncAdaptFallback_dbapi_connection, AsyncAdapt_psycopg_connection -): - __slots__ = () - - class PsycopgAdaptDBAPI: def __init__(self, psycopg, ExecStatus) -> None: self.psycopg = psycopg @@ -666,18 +657,12 @@ class PsycopgAdaptDBAPI: self.__dict__[k] = v def connect(self, *arg, **kw): - async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop( "async_creator_fn", self.psycopg.AsyncConnection.connect ) - if util.asbool(async_fallback): - return AsyncAdaptFallback_psycopg_connection( - self, await_fallback(creator_fn(*arg, **kw)) - ) - else: - return AsyncAdapt_psycopg_connection( - self, await_only(creator_fn(*arg, **kw)) - ) + return AsyncAdapt_psycopg_connection( + self, await_(creator_fn(*arg, **kw)) + ) class PGDialectAsync_psycopg(PGDialect_psycopg): @@ -691,20 +676,11 @@ class PGDialectAsync_psycopg(PGDialect_psycopg): return PsycopgAdaptDBAPI(psycopg, ExecStatus) - @classmethod - def get_pool_class(cls, url): - async_fallback = url.query.get("async_fallback", False) - - if util.asbool(async_fallback): - return pool.FallbackAsyncAdaptedQueuePool - else: - return pool.AsyncAdaptedQueuePool - def _type_info_fetch(self, connection, name): from psycopg.types import TypeInfo adapted = connection.connection - return adapted.await_(TypeInfo.fetch(adapted.driver_connection, name)) + return await_(TypeInfo.fetch(adapted.driver_connection, name)) def _do_isolation_level(self, connection, autocommit, isolation_level): connection.set_autocommit(autocommit) diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index 7eccf5fb17..05e64ee85d 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -83,13 +83,10 @@ from functools import partial from .base import SQLiteExecutionContext from .pysqlite import SQLiteDialect_pysqlite from ... import pool -from ... 
import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor -from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection -from ...util.concurrency import await_fallback -from ...util.concurrency import await_only +from ...util.concurrency import await_ class AsyncAdapt_aiosqlite_cursor(AsyncAdapt_dbapi_cursor): @@ -126,13 +123,13 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): self._connection._tx.put_nowait((future, function)) try: - return self.await_(future) + return await_(future) except Exception as error: self._handle_exception(error) def create_function(self, *args, **kw): try: - self.await_(self._connection.create_function(*args, **kw)) + await_(self._connection.create_function(*args, **kw)) except Exception as error: self._handle_exception(error) @@ -146,7 +143,7 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): def close(self): try: - self.await_(self._connection.close()) + await_(self._connection.close()) except ValueError: # this is undocumented for aiosqlite, that ValueError # was raised if .close() was called more than once, which is @@ -170,12 +167,6 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): super()._handle_exception(error) -class AsyncAdaptFallback_aiosqlite_connection( - AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiosqlite_connection -): - __slots__ = () - - class AsyncAdapt_aiosqlite_dbapi: def __init__(self, aiosqlite, sqlite): self.aiosqlite = aiosqlite @@ -203,8 +194,6 @@ class AsyncAdapt_aiosqlite_dbapi: setattr(self, name, getattr(self.sqlite, name)) def connect(self, *arg, **kw): - async_fallback = kw.pop("async_fallback", False) - creator_fn = kw.pop("async_creator_fn", None) if creator_fn: connection = creator_fn(*arg, **kw) @@ -213,16 +202,10 @@ class AsyncAdapt_aiosqlite_dbapi: # it's a Thread. 
you'll thank us later connection.daemon = True - if util.asbool(async_fallback): - return AsyncAdaptFallback_aiosqlite_connection( - self, - await_fallback(connection), - ) - else: - return AsyncAdapt_aiosqlite_connection( - self, - await_only(connection), - ) + return AsyncAdapt_aiosqlite_connection( + self, + await_(connection), + ) class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 33e05120e2..6ad1de735a 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -484,7 +484,13 @@ class DefaultDialect(Dialect): @classmethod def get_pool_class(cls, url: URL) -> Type[Pool]: - return getattr(cls, "poolclass", pool.QueuePool) + default: Type[pool.Pool] + if cls.is_async: + default = pool.AsyncAdaptedQueuePool + else: + default = pool.QueuePool + + return getattr(cls, "poolclass", default) def get_dialect_pool_class(self, url: URL) -> Type[Pool]: return self.get_pool_class(url) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index faea997dea..ddd4ceeabb 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -42,7 +42,7 @@ from ..sql.compiler import Compiled # noqa from ..sql.compiler import TypeCompiler as TypeCompiler from ..sql.compiler import TypeCompiler # noqa from ..util import immutabledict -from ..util.concurrency import await_only +from ..util.concurrency import await_ from ..util.typing import Literal from ..util.typing import NotRequired @@ -3400,7 +3400,7 @@ class AdaptedConnection: :ref:`asyncio_events_run_async` """ - return await_only(fn(self._connection)) + return await_(fn(self._connection)) def __repr__(self) -> str: return "" % self._connection diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 251f521254..69d9cce55c 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -182,7 +182,7 @@ class GeneratorStartableContext(StartableContext[_T_co]): # tell if we get the same exception back value = typ() try: - await util.athrow(self.gen, typ, value, traceback) + await self.gen.athrow(value) except StopAsyncIteration as exc: # Suppress StopIteration *unless* it's the same exception that # was passed to throw(). This prevents a StopIteration diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index c25a8f85d8..243862cdc5 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -35,9 +35,6 @@ from .base import reset_none as reset_none from .base import reset_rollback as reset_rollback from .impl import AssertionPool as AssertionPool from .impl import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool -from .impl import ( - FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, -) from .impl import NullPool as NullPool from .impl import QueuePool as QueuePool from .impl import SingletonThreadPool as SingletonThreadPool diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index ced015088c..9616ad2998 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -257,10 +257,6 @@ class AsyncAdaptedQueuePool(QueuePool): _dialect = _AsyncConnDialect() -class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool): - _queue_class = sqla_queue.FallbackAsyncAdaptedQueue - - class NullPool(Pool): """A Pool which does not pool connections. 
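
With the fallback pool removed, ``DefaultDialect.get_pool_class()`` no longer
inspects the URL for ``async_fallback``; it simply defaults to
``AsyncAdaptedQueuePool`` when ``cls.is_async`` is true and to ``QueuePool``
otherwise. A small sketch of the resulting selection, assuming (as with the
stock dialects used here) that no explicit ``poolclass`` attribute is set::

    from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
    from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
    from sqlalchemy.engine import make_url
    from sqlalchemy.pool import AsyncAdaptedQueuePool, QueuePool

    async_url = make_url("postgresql+asyncpg://scott:tiger@localhost/test")
    sync_url = make_url("postgresql+psycopg2://scott:tiger@localhost/test")

    # async dialects now default to the asyncio-adapted queue pool
    assert PGDialect_asyncpg.get_pool_class(async_url) is AsyncAdaptedQueuePool
    # sync dialects keep the plain QueuePool default
    assert PGDialect_psycopg2.get_pool_class(sync_url) is QueuePool
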
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py index 4236dcf92e..1f2bc55912 100644 --- a/lib/sqlalchemy/testing/asyncio.py +++ b/lib/sqlalchemy/testing/asyncio.py @@ -11,29 +11,27 @@ # setup/teardown in an asyncio event loop, conditionally based on the # current DB driver being used for a test. -# note that SQLAlchemy's asyncio integration also supports a method -# of running individual asyncio functions inside of separate event loops -# using "async_fallback" mode; however running whole functions in the event -# loop is a more accurate test for how SQLAlchemy's asyncio features -# would run in the real world. - - from __future__ import annotations from functools import wraps import inspect from . import config -from ..util.concurrency import _util_async_run -from ..util.concurrency import _util_async_run_coroutine_function +from ..util.concurrency import _AsyncUtil # may be set to False if the # --disable-asyncio flag is passed to the test runner. ENABLE_ASYNCIO = True +_async_util = _AsyncUtil() # it has lazy init so just always create one + + +def _shutdown(): + """called when the test finishes""" + _async_util.close() def _run_coroutine_function(fn, *args, **kwargs): - return _util_async_run_coroutine_function(fn, *args, **kwargs) + return _async_util.run(fn, *args, **kwargs) def _assume_async(fn, *args, **kwargs): @@ -50,7 +48,7 @@ def _assume_async(fn, *args, **kwargs): if not ENABLE_ASYNCIO: return fn(*args, **kwargs) - return _util_async_run(fn, *args, **kwargs) + return _async_util.run_in_greenlet(fn, *args, **kwargs) def _maybe_async_provisioning(fn, *args, **kwargs): @@ -69,7 +67,7 @@ def _maybe_async_provisioning(fn, *args, **kwargs): return fn(*args, **kwargs) if config.any_async: - return _util_async_run(fn, *args, **kwargs) + return _async_util.run_in_greenlet(fn, *args, **kwargs) else: return fn(*args, **kwargs) @@ -89,7 +87,7 @@ def _maybe_async(fn, *args, **kwargs): is_async = config._current.is_async if is_async: - return _util_async_run(fn, *args, **kwargs) + return _async_util.run_in_greenlet(fn, *args, **kwargs) else: return fn(*args, **kwargs) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 8430203dee..be22ff5991 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -25,7 +25,6 @@ from typing import Union from . import mock from . import requirements as _requirements from .util import fail -from .. import util # default requirements; this is replaced by plugin_base when pytest # is run @@ -330,9 +329,7 @@ class Config: self.test_schema = "test_schema" self.test_schema_2 = "test_schema_2" - self.is_async = db.dialect.is_async and not util.asbool( - db.url.query.get("async_fallback", False) - ) + self.is_async = db.dialect.is_async _stack = collections.deque() _configs = set() diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 749f9c160e..2bca37b2b8 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -23,7 +23,7 @@ from .util import decorator from .util import gc_collect from .. import event from .. 
import pool -from ..util import await_only +from ..util import await_ from ..util.typing import Literal @@ -112,7 +112,7 @@ class ConnectionKiller: self._safe(proxy_ref._checkin) if hasattr(rec, "sync_engine"): - await_only(rec.dispose()) + await_(rec.dispose()) else: rec.dispose() eng.clear() diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 47644e3d28..290e2cb5a4 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -182,6 +182,12 @@ def pytest_sessionfinish(session): collect_types.dump_stats(session.config.option.dump_pyannotate) +def pytest_unconfigure(config): + from sqlalchemy.testing import asyncio + + asyncio._shutdown() + + def pytest_collection_finish(session): if session.config.option.dump_pyannotate: from pyannotate_runtime import collect_types diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 884d558138..56b8c2972b 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -113,7 +113,7 @@ def generate_db_urls(db_urls, extra_drivers): --dburi postgresql://db1 \ --dburi postgresql://db2 \ --dburi postgresql://db2 \ - --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true + --dbdriver=psycopg2 --dbdriver=asyncpg Noting that the default postgresql driver is psycopg2, the output would be:: @@ -130,11 +130,10 @@ def generate_db_urls(db_urls, extra_drivers): we want to keep it in that dburi. Driver specific query options can be specified by added them to the - driver name. For example, to enable the async fallback option for - asyncpg:: + driver name. For example, to a sample option the asyncpg:: --dburi postgresql://db1 \ - --dbdriver=asyncpg?async_fallback=true + --dbdriver=asyncpg?some_option=a_value """ urls = set() diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index eaba84ecd2..b288cbbaf4 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1595,7 +1595,8 @@ class SuiteRequirements(Requirements): @property def async_dialect(self): - """dialect makes use of await_() to invoke operations on the DBAPI.""" + """dialect makes use of await_() to invoke operations on the + DBAPI.""" return exclusions.closed() diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index caaa657f93..eb17d005e2 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -49,7 +49,6 @@ from ._collections import WeakPopulateDict as WeakPopulateDict from ._collections import WeakSequence as WeakSequence from .compat import anext_ as anext_ from .compat import arm as arm -from .compat import athrow as athrow from .compat import b as b from .compat import b64decode as b64decode from .compat import b64encode as b64encode @@ -69,7 +68,7 @@ from .compat import py312 as py312 from .compat import py39 as py39 from .compat import pypy as pypy from .compat import win32 as win32 -from .concurrency import await_fallback as await_fallback +from .concurrency import await_ as await_ from .concurrency import await_only as await_only from .concurrency import greenlet_spawn as greenlet_spawn from .concurrency import is_exit_exception as is_exit_exception diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 1bc8997031..cd071c3762 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -20,8 +20,6 @@ import platform import sys 
import typing from typing import Any -from typing import AsyncGenerator -from typing import Awaitable from typing import Callable from typing import Dict from typing import Iterable @@ -32,7 +30,6 @@ from typing import Sequence from typing import Set from typing import Tuple from typing import Type -from typing import TypeVar py312 = sys.version_info >= (3, 12) py311 = sys.version_info >= (3, 11) @@ -50,8 +47,6 @@ has_refcount_gc = bool(cpython) dottedgetter = operator.attrgetter -_T_co = TypeVar("_T_co", covariant=True) - class FullArgSpec(typing.NamedTuple): args: List[str] @@ -101,24 +96,6 @@ def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec: ) -if py312: - # we are 95% certain this form of athrow works in former Python - # versions, however we are unable to get confirmation; - # see https://github.com/python/cpython/issues/105269 where have - # been unable to get a straight answer so far - def athrow( # noqa - gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any - ) -> Awaitable[_T_co]: - return gen.athrow(value) - -else: - - def athrow( # noqa - gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any - ) -> Awaitable[_T_co]: - return gen.athrow(typ, value, traceback) - - if py39: # python stubs don't have a public type for this. not worth # making a protocol diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 9e4c6c85da..bcdb928c29 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -9,21 +9,22 @@ from __future__ import annotations import asyncio -from contextvars import Context import sys from typing import Any from typing import Awaitable from typing import Callable from typing import Coroutine from typing import NoReturn -from typing import Optional -from typing import Protocol from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union +from .compat import py311 from .langhelpers import memoized_property +from .typing import Literal +from .typing import Self +from .typing import TypeGuard from .. import exc -from ..util.typing import TypeGuard _T = TypeVar("_T") @@ -44,25 +45,6 @@ _ERROR_MESSAGE = ( ) -if TYPE_CHECKING: - - class greenlet(Protocol): - dead: bool - gr_context: Optional[Context] - - def __init__(self, fn: Callable[..., Any], driver: greenlet): - ... - - def throw(self, *arg: Any) -> Any: - return None - - def switch(self, value: Any) -> Any: - return None - - def getcurrent() -> greenlet: - ... - - def _not_implemented(*arg: Any, **kw: Any) -> NoReturn: raise ImportError(_ERROR_MESSAGE) @@ -71,10 +53,10 @@ class _concurrency_shim_cls: """Late import shim for greenlet""" __slots__ = ( + "_has_greenlet", "greenlet", "_AsyncIoGreenlet", "getcurrent", - "_util_async_run", ) def _initialize(self, *, raise_: bool = True) -> None: @@ -84,7 +66,7 @@ class _concurrency_shim_cls: if not TYPE_CHECKING: global getcurrent, greenlet, _AsyncIoGreenlet - global _has_gr_context, _greenlet_error + global _has_gr_context try: from greenlet import getcurrent @@ -93,73 +75,46 @@ class _concurrency_shim_cls: if not TYPE_CHECKING: # set greenlet in the global scope to prevent re-init greenlet = None - + self._has_greenlet = False self._initialize_no_greenlet() if raise_: raise ImportError(_ERROR_MESSAGE) from e else: - self._initialize_greenlet() - - def _initialize_greenlet(self) -> None: - # If greenlet.gr_context is present in current version of greenlet, - # it will be set with the current context on creation. 
- # Refs: https://github.com/python-greenlet/greenlet/pull/198 - _has_gr_context = hasattr(getcurrent(), "gr_context") + self._has_greenlet = True + # If greenlet.gr_context is present in current version of greenlet, + # it will be set with the current context on creation. + # Refs: https://github.com/python-greenlet/greenlet/pull/198 + _has_gr_context = hasattr(getcurrent(), "gr_context") - # implementation based on snaury gist at - # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef - # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 # noqa: E501 + # implementation based on snaury gist at + # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef + # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 # noqa: E501 - class _AsyncIoGreenlet(greenlet): - dead: bool + class _AsyncIoGreenlet(greenlet): + dead: bool - def __init__(self, fn: Callable[..., Any], driver: greenlet): - greenlet.__init__(self, fn, driver) - self.driver = driver - if _has_gr_context: - self.gr_context = driver.gr_context + def __init__(self, fn: Callable[..., Any], driver: greenlet): + greenlet.__init__(self, fn, driver) + self.driver = driver + if _has_gr_context: + self.gr_context = driver.gr_context - self.greenlet = greenlet - self.getcurrent = getcurrent - self._AsyncIoGreenlet = _AsyncIoGreenlet - self._util_async_run = self._greenlet_util_async_run + self.greenlet = greenlet + self.getcurrent = getcurrent + self._AsyncIoGreenlet = _AsyncIoGreenlet def _initialize_no_greenlet(self): - self._util_async_run = self._no_greenlet_util_async_run self.getcurrent = _not_implemented - self.greenlet = _not_implemented # type: ignore - self._AsyncIoGreenlet = _not_implemented # type: ignore + self.greenlet = _not_implemented # type: ignore[assignment] + self._AsyncIoGreenlet = _not_implemented # type: ignore[assignment] def __getattr__(self, key: str) -> Any: if key in self.__slots__: - self._initialize(raise_=not key.startswith("_util")) + self._initialize() return getattr(self, key) else: raise AttributeError(key) - def _greenlet_util_async_run( - self, fn: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Any: - """for test suite/ util only""" - - loop = get_event_loop() - if not loop.is_running(): - return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs)) - else: - # allow for a wrapped test function to call another - assert isinstance( - _concurrency_shim.getcurrent(), - _concurrency_shim._AsyncIoGreenlet, - ) - return fn(*args, **kwargs) - - def _no_greenlet_util_async_run( - self, fn: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Any: - """for test suite/ util only""" - - return fn(*args, **kwargs) - _concurrency_shim = _concurrency_shim_cls() @@ -187,11 +142,11 @@ def in_greenlet() -> bool: return isinstance(current, _concurrency_shim._AsyncIoGreenlet) -def await_only(awaitable: Awaitable[_T]) -> _T: +def await_(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. - :func:`await_only` calls cannot be nested. + :func:`await_` calls cannot be nested. :param awaitable: The coroutine to call. @@ -202,7 +157,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T: _safe_cancel_awaitable(awaitable) raise exc.MissingGreenlet( - "greenlet_spawn has not been called; can't call await_only() " + "greenlet_spawn has not been called; can't call await_() " "here. Was IO attempted in an unexpected place?" 
) @@ -213,31 +168,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T: return current.driver.switch(awaitable) # type: ignore[no-any-return] -def await_fallback(awaitable: Awaitable[_T]) -> _T: - """Awaits an async function in a sync method. - - The sync method must be inside a :func:`greenlet_spawn` context. - :func:`await_fallback` calls cannot be nested. - - :param awaitable: The coroutine to call. - - """ - - # this is called in the context greenlet while running fn - current = _concurrency_shim.getcurrent() - if not isinstance(current, _concurrency_shim._AsyncIoGreenlet): - loop = get_event_loop() - if loop.is_running(): - _safe_cancel_awaitable(awaitable) - - raise exc.MissingGreenlet( - "greenlet_spawn has not been called and asyncio event " - "loop is already running; can't call await_fallback() here. " - "Was IO attempted in an unexpected place?" - ) - return loop.run_until_complete(awaitable) - - return current.driver.switch(awaitable) # type: ignore[no-any-return] +await_only = await_ # old name. deprecated on 2.2 async def greenlet_spawn( @@ -248,7 +179,7 @@ async def greenlet_spawn( ) -> _T: """Runs a sync function ``fn`` in a new greenlet. - The sync function can then use :func:`await_only` to wait for async + The sync function can then use :func:`await_` to wait for async functions. :param fn: The sync callable to call. @@ -261,7 +192,7 @@ async def greenlet_spawn( fn, _concurrency_shim.getcurrent() ) # runs the function synchronously in gl greenlet. If the execution - # is interrupted by await_only, context is not dead and result is a + # is interrupted by await_, context is not dead and result is a # coroutine to wait. If the context is dead the function has # returned, and its result can be returned. switch_occurred = False @@ -270,7 +201,7 @@ async def greenlet_spawn( while not context.dead: switch_occurred = True try: - # wait for a coroutine from await_only and then return its + # wait for a coroutine from await_ and then return its # result back to it. value = await result except BaseException: @@ -302,42 +233,92 @@ class AsyncAdaptedLock: def __enter__(self) -> bool: # await is used to acquire the lock only after the first calling # coroutine has created the mutex. - return await_fallback(self.mutex.acquire()) + return await_(self.mutex.acquire()) def __exit__(self, *arg: Any, **kw: Any) -> None: self.mutex.release() -def _util_async_run_coroutine_function( - fn: Callable[..., Any], *args: Any, **kwargs: Any -) -> Any: - """for test suite/ util only""" - - loop = get_event_loop() - if loop.is_running(): - raise Exception( - "for async run coroutine we expect that no greenlet or event " - "loop is running when we start out" - ) - return loop.run_until_complete(fn(*args, **kwargs)) - - -def _util_async_run(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: - """for test suite/ util only""" - - _util_async_run = _concurrency_shim._util_async_run - return _util_async_run(fn, *args, **kwargs) - - -def get_event_loop() -> asyncio.AbstractEventLoop: - """vendor asyncio.get_event_loop() for python 3.7 and above. +if py311: + _Runner = asyncio.Runner +else: - Python 3.10 deprecates get_event_loop() as a standalone. 
+ class _Runner: # type: ignore[no-redef] + """Runner implementation for test only""" + + _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]] + + def __init__(self) -> None: + self._loop = None + + def __enter__(self) -> Self: + self._lazy_init() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + def close(self) -> None: + if self._loop: + try: + self._loop.run_until_complete( + self._loop.shutdown_asyncgens() + ) + finally: + self._loop.close() + self._loop = False + + def get_loop(self) -> asyncio.AbstractEventLoop: + """Return embedded event loop.""" + self._lazy_init() + assert self._loop + return self._loop + + def run(self, coro: Coroutine[Any, Any, _T]) -> _T: + self._lazy_init() + assert self._loop + return self._loop.run_until_complete(coro) + + def _lazy_init(self) -> None: + if self._loop is False: + raise RuntimeError("Runner is closed") + if self._loop is None: + self._loop = asyncio.new_event_loop() + + +class _AsyncUtil: + """Asyncio util for test suite/ util only""" + + def __init__(self) -> None: + self.runner = _Runner() # runner it lazy so it can be created here + + def run( + self, + fn: Callable[..., Coroutine[Any, Any, _T]], + *args: Any, + **kwargs: Any, + ) -> _T: + """Run coroutine on the loop""" + return self.runner.run(fn(*args, **kwargs)) + + def run_in_greenlet( + self, fn: Callable[..., _T], *args: Any, **kwargs: Any + ) -> _T: + """Run sync function in greenlet. Support nested calls""" + _concurrency_shim._initialize(raise_=False) + + if _concurrency_shim._has_greenlet: + if self.runner.get_loop().is_running(): + # allow for a wrapped test function to call another + assert isinstance( + _concurrency_shim.getcurrent(), + _concurrency_shim._AsyncIoGreenlet, + ) + return fn(*args, **kwargs) + else: + return self.runner.run(greenlet_spawn(fn, *args, **kwargs)) + else: + return fn(*args, **kwargs) - """ - try: - return asyncio.get_running_loop() - except RuntimeError: - # avoid "During handling of the above exception, another exception..." - pass - return asyncio.get_event_loop_policy().get_event_loop() + def close(self) -> None: + self.runner.close() diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index b641c910c7..a631fa67ea 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -24,16 +24,13 @@ import asyncio from collections import deque import threading from time import time as _time -import typing from typing import Any -from typing import Awaitable from typing import Deque from typing import Generic from typing import Optional from typing import TypeVar -from .concurrency import await_fallback -from .concurrency import await_only +from .concurrency import await_ from .langhelpers import memoized_property @@ -239,15 +236,6 @@ class Queue(QueueCommon[_T]): class AsyncAdaptedQueue(QueueCommon[_T]): - if typing.TYPE_CHECKING: - - @staticmethod - def await_(coroutine: Awaitable[Any]) -> _T: - ... 
- - else: - await_ = staticmethod(await_only) - def __init__(self, maxsize: int = 0, use_lifo: bool = False): self.use_lifo = use_lifo self.maxsize = maxsize @@ -292,9 +280,9 @@ class AsyncAdaptedQueue(QueueCommon[_T]): try: if timeout is not None: - self.await_(asyncio.wait_for(self._queue.put(item), timeout)) + await_(asyncio.wait_for(self._queue.put(item), timeout)) else: - self.await_(self._queue.put(item)) + await_(self._queue.put(item)) except (asyncio.QueueFull, asyncio.TimeoutError) as err: raise Full() from err @@ -310,15 +298,8 @@ class AsyncAdaptedQueue(QueueCommon[_T]): try: if timeout is not None: - return self.await_( - asyncio.wait_for(self._queue.get(), timeout) - ) + return await_(asyncio.wait_for(self._queue.get(), timeout)) else: - return self.await_(self._queue.get()) + return await_(self._queue.get()) except (asyncio.QueueEmpty, asyncio.TimeoutError) as err: raise Empty() from err - - -class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue[_T]): - if not typing.TYPE_CHECKING: - await_ = staticmethod(await_fallback) diff --git a/setup.cfg b/setup.cfg index 890aea977a..129a5aa82d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -64,8 +64,8 @@ postgresql_asyncpg = asyncpg postgresql_psycopg2binary = psycopg2-binary postgresql_psycopg2cffi = psycopg2cffi -postgresql_psycopg = psycopg>=3.0.7 -postgresql_psycopgbinary = psycopg[binary]>=3.0.7 +postgresql_psycopg = psycopg>=3.0.7,!=3.1.15 +postgresql_psycopgbinary = psycopg[binary]>=3.0.7,!=3.1.15 pymysql = pymysql aiomysql = @@ -162,17 +162,13 @@ postgresql = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test psycopg2 = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test psycopg = postgresql+psycopg://scott:tiger@127.0.0.1:5432/test psycopg_async = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test -psycopg_async_fallback = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test?async_fallback=true asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test -asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test mysql = mysql+mysqldb://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 -aiomysql_fallback = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true asyncmy = mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 -asyncmy_fallback = mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true mariadb = mariadb+mysqldb://scott:tiger@127.0.0.1:3306/test mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes diff --git a/test/base/_concurrency_fixtures.py b/test/base/_concurrency_fixtures.py index 587eb644d1..0a7f00c089 100644 --- a/test/base/_concurrency_fixtures.py +++ b/test/base/_concurrency_fixtures.py @@ -13,7 +13,7 @@ def greenlet_not_imported(): import sqlalchemy import sqlalchemy.util.concurrency # noqa: F401 from sqlalchemy.util import greenlet_spawn # noqa: F401 - from sqlalchemy.util.concurrency import await_only # noqa: F401 + from sqlalchemy.util.concurrency import await_ # noqa: F401 assert "greenlet" not in sys.modules diff --git a/test/base/test_concurrency.py 
b/test/base/test_concurrency.py index 1ea61ba7ce..274bcfe7c1 100644 --- a/test/base/test_concurrency.py +++ b/test/base/test_concurrency.py @@ -13,8 +13,7 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_true from sqlalchemy.testing.config import combinations -from sqlalchemy.util import await_fallback -from sqlalchemy.util import await_only +from sqlalchemy.util import await_ from sqlalchemy.util import greenlet_spawn from sqlalchemy.util import queue from ._concurrency_fixtures import greenlet_not_imported @@ -36,7 +35,7 @@ async def run2(): def go(*fns): - return sum(await_only(fn()) for fn in fns) + return sum(await_(fn()) for fn in fns) class TestAsyncioCompat(fixtures.TestBase): @@ -64,7 +63,7 @@ class TestAsyncioCompat(fixtures.TestBase): def sync_meth(): try: - await_only(async_meth_raise()) + await_(async_meth_raise()) except: cleanup.append(True) raise @@ -80,56 +79,29 @@ class TestAsyncioCompat(fixtures.TestBase): @async_test async def test_sync_error(self): def go(): - await_only(run1()) + await_(run1()) raise ValueError("sync error") with expect_raises_message(ValueError, "sync error"): await greenlet_spawn(go) - def test_await_fallback_no_greenlet(self): - to_await = run1() - await_fallback(to_await) - @async_test async def test_await_only_no_greenlet(self): to_await = run1() with expect_raises_message( exc.MissingGreenlet, "greenlet_spawn has not been called; " - r"can't call await_only\(\) here.", + r"can't call await_\(\) here.", ): - await_only(to_await) + await_(to_await) # existing awaitable is done with expect_raises(RuntimeError): - await greenlet_spawn(await_fallback, to_await) + await greenlet_spawn(await_, to_await) # no warning for a new one... 
to_await = run1() - await greenlet_spawn(await_fallback, to_await) - - @async_test - async def test_await_fallback_error(self): - to_await = run1() - - await to_await - - async def inner_await(): - nonlocal to_await - to_await = run1() - await_fallback(to_await) - - def go(): - await_fallback(inner_await()) - - with expect_raises_message( - exc.MissingGreenlet, - "greenlet_spawn has not been called and asyncio event loop", - ): - await greenlet_spawn(go) - - with expect_raises(RuntimeError): - await to_await + await greenlet_spawn(await_, to_await) @async_test async def test_await_only_error(self): @@ -140,15 +112,15 @@ class TestAsyncioCompat(fixtures.TestBase): async def inner_await(): nonlocal to_await to_await = run1() - await_only(to_await) + await_(to_await) def go(): - await_only(inner_await()) + await_(inner_await()) with expect_raises_message( exc.InvalidRequestError, "greenlet_spawn has not been called; " - r"can't call await_only\(\) here.", + r"can't call await_\(\) here.", ): await greenlet_spawn(go) @@ -172,22 +144,22 @@ class TestAsyncioCompat(fixtures.TestBase): var.set(val) def inner(val): - retval = await_only(async_inner(val)) + retval = await_(async_inner(val)) eq_(val, var.get()) eq_(retval, val) # set the value in a sync function newval = val + concurrency var.set(newval) - syncset = await_only(async_inner(newval)) + syncset = await_(async_inner(newval)) eq_(newval, var.get()) eq_(syncset, newval) # set the value in an async function retval = val + 2 * concurrency - await_only(async_set(retval)) + await_(async_set(retval)) eq_(var.get(), retval) - eq_(await_only(async_inner(retval)), retval) + eq_(await_(async_inner(retval)), retval) return retval @@ -304,4 +276,4 @@ class GracefulNoGreenletTest(fixtures.TestBase): "The SQLAlchemy asyncio module requires that the Python " "'greenlet' library is installed", ): - await_only(async_fn()) + await_(async_fn()) diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 49736df9b6..33abfed4d2 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -286,7 +286,6 @@ class PoolTest(PoolTestBase): @testing.combinations( (pool.QueuePool, False), (pool.AsyncAdaptedQueuePool, True), - (pool.FallbackAsyncAdaptedQueuePool, True), (pool.NullPool, None), (pool.SingletonThreadPool, False), (pool.StaticPool, None), @@ -307,7 +306,6 @@ class PoolTest(PoolTestBase): @testing.combinations( (pool.QueuePool, False), (pool.AsyncAdaptedQueuePool, True), - (pool.FallbackAsyncAdaptedQueuePool, True), (pool.NullPool, False), (pool.SingletonThreadPool, False), (pool.StaticPool, False), diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 7289d5494e..adb6b0b6c9 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -351,9 +351,9 @@ class AsyncEngineTest(EngineFixture): pool_connection = await conn.get_raw_connection() return pool_connection - from sqlalchemy.util.concurrency import await_only + from sqlalchemy.util.concurrency import await_ - pool_connection = await_only(go()) + pool_connection = await_(go()) rec = pool_connection._connection_record ref = rec.fairy_ref diff --git a/test/requirements.py b/test/requirements.py index 4a0b365c2b..1626c825f2 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1527,7 +1527,8 @@ class DefaultRequirements(SuiteRequirements): @property def async_dialect(self): - """dialect makes use of await_() to invoke operations on the DBAPI.""" + """dialect makes use of await_() to invoke operations 
on + the DBAPI.""" return self.asyncio + only_on( LambdaPredicate(