From: Mike Bayer Date: Mon, 13 Nov 2023 20:52:43 +0000 (-0500) Subject: adapt all asyncio dialects to asyncio connector X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=900d13acb4f19de955eb609dea52a755f0d11acb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git adapt all asyncio dialects to asyncio connector Adapted all asyncio dialects, including aiosqlite, aiomysql, asyncmy, psycopg, asyncpg to use the generic asyncio connection adapter first added in :ticket:`6521` for the aioodbc DBAPI, allowing these dialects to take advantage of a common framework. Fixes: #10415 Change-Id: I24123175aa787f3a2c550d9e02d3827173794e3b --- diff --git a/doc/build/changelog/unreleased_21/10415.rst b/doc/build/changelog/unreleased_21/10415.rst new file mode 100644 index 0000000000..ee96c2df5a --- /dev/null +++ b/doc/build/changelog/unreleased_21/10415.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: change, asyncio + :tickets: 10415 + + Adapted all asyncio dialects, including aiosqlite, aiomysql, asyncmy, + psycopg, asyncpg to use the generic asyncio connection adapter first added + in :ticket:`6521` for the aioodbc DBAPI, allowing these dialects to take + advantage of a common framework. diff --git a/lib/sqlalchemy/connectors/aioodbc.py b/lib/sqlalchemy/connectors/aioodbc.py index c6986366e1..e0f5f55474 100644 --- a/lib/sqlalchemy/connectors/aioodbc.py +++ b/lib/sqlalchemy/connectors/aioodbc.py @@ -58,6 +58,15 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection): self._connection._conn.autocommit = value + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def add_output_converter(self, *arg, **kw): + self._connection.add_output_converter(*arg, **kw) + + def character_set_name(self): + return self._connection.character_set_name() + def cursor(self, server_side=False): # aioodbc sets connection=None when closed and just fails with # AttributeError here. Here we use the same ProgrammingError + diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index 997407ccd5..9358457ceb 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -4,19 +4,116 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """generic asyncio-adapted versions of DBAPI connection and cursor""" from __future__ import annotations +import asyncio import collections import itertools +import sys +from typing import Any +from typing import Deque +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Protocol +from typing import Sequence from ..engine import AdaptedConnection -from ..util.concurrency import asyncio +from ..engine.interfaces import _DBAPICursorDescription +from ..engine.interfaces import _DBAPIMultiExecuteParams +from ..engine.interfaces import _DBAPISingleExecuteParams from ..util.concurrency import await_fallback from ..util.concurrency import await_only +from ..util.typing import Self + + +class AsyncIODBAPIConnection(Protocol): + """protocol representing an async adapted version of a + :pep:`249` database connection. + + + """ + + async def close(self) -> None: + ... + + async def commit(self) -> None: + ... + + def cursor(self) -> AsyncIODBAPICursor: + ... + + async def rollback(self) -> None: + ... + + +class AsyncIODBAPICursor(Protocol): + """protocol representing an async adapted version + of a :pep:`249` database cursor. + + + """ + + def __aenter__(self) -> Any: + ... + + @property + def description( + self, + ) -> _DBAPICursorDescription: + """The description attribute of the Cursor.""" + ... + + @property + def rowcount(self) -> int: + ... + + arraysize: int + + lastrowid: int + + async def close(self) -> None: + ... + + async def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: + ... + + async def executemany( + self, + operation: Any, + parameters: _DBAPIMultiExecuteParams, + ) -> Any: + ... + + async def fetchone(self) -> Optional[Any]: + ... + + async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: + ... + + async def fetchall(self) -> Sequence[Any]: + ... + + async def setinputsizes(self, sizes: Sequence[Any]) -> None: + ... + + def setoutputsize(self, size: Any, column: Any) -> None: + ... + + async def callproc( + self, procname: str, parameters: Sequence[Any] = ... + ) -> Any: + ... + + async def nextset(self) -> Optional[bool]: + ... class AsyncAdapt_dbapi_cursor: @@ -29,52 +126,85 @@ class AsyncAdapt_dbapi_cursor: "_rows", ) - def __init__(self, adapt_connection): + _cursor: AsyncIODBAPICursor + _adapt_connection: AsyncAdapt_dbapi_connection + _connection: AsyncIODBAPIConnection + _rows: Deque[Any] + + def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection self.await_ = adapt_connection.await_ - cursor = self._connection.cursor() + cursor = self._make_new_cursor(self._connection) + + try: + self._cursor = self.await_(cursor.__aenter__()) + except Exception as error: + self._adapt_connection._handle_exception(error) - self._cursor = self.await_(cursor.__aenter__()) self._rows = collections.deque() + def _make_new_cursor( + self, connection: AsyncIODBAPIConnection + ) -> AsyncIODBAPICursor: + return connection.cursor() + @property - def description(self): + def description(self) -> Optional[_DBAPICursorDescription]: return self._cursor.description @property - def rowcount(self): + def rowcount(self) -> int: return self._cursor.rowcount @property - def arraysize(self): + def arraysize(self) -> int: return self._cursor.arraysize @arraysize.setter - def arraysize(self, value): + def arraysize(self, value: int) -> None: self._cursor.arraysize = value @property - def lastrowid(self): + def lastrowid(self) -> int: return self._cursor.lastrowid - def close(self): + def close(self) -> None: # note we aren't actually closing the cursor here, # we are just letting GC do it. see notes in aiomysql dialect self._rows.clear() - def execute(self, operation, parameters=None): - return self.await_(self._execute_async(operation, parameters)) - - def executemany(self, operation, seq_of_parameters): - return self.await_( - self._executemany_async(operation, seq_of_parameters) - ) + def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: + try: + return self.await_(self._execute_async(operation, parameters)) + except Exception as error: + self._adapt_connection._handle_exception(error) + + def executemany( + self, + operation: Any, + seq_of_parameters: _DBAPIMultiExecuteParams, + ) -> Any: + try: + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + except Exception as error: + self._adapt_connection._handle_exception(error) - async def _execute_async(self, operation, parameters): + async def _execute_async( + self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams] + ) -> Any: async with self._adapt_connection._execute_mutex: - result = await self._cursor.execute(operation, parameters or ()) + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) if self._cursor.description and not self.server_side: # aioodbc has a "fake" async result, so we have to pull it out @@ -84,35 +214,45 @@ class AsyncAdapt_dbapi_cursor: self._rows = collections.deque(await self._cursor.fetchall()) return result - async def _executemany_async(self, operation, seq_of_parameters): + async def _executemany_async( + self, + operation: Any, + seq_of_parameters: _DBAPIMultiExecuteParams, + ) -> Any: async with self._adapt_connection._execute_mutex: return await self._cursor.executemany(operation, seq_of_parameters) - def nextset(self): + def nextset(self) -> None: self.await_(self._cursor.nextset()) if self._cursor.description and not self.server_side: self._rows = collections.deque( self.await_(self._cursor.fetchall()) ) - def setinputsizes(self, *inputsizes): + def setinputsizes(self, *inputsizes: Any) -> None: # NOTE: this is overrridden in aioodbc due to # see https://github.com/aio-libs/aioodbc/issues/451 # right now return self.await_(self._cursor.setinputsizes(*inputsizes)) - def __iter__(self): + def __enter__(self) -> Self: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + def __iter__(self) -> Iterator[Any]: while self._rows: yield self._rows.popleft() - def fetchone(self): + def fetchone(self) -> Optional[Any]: if self._rows: return self._rows.popleft() else: return None - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]: if size is None: size = self.arraysize @@ -121,7 +261,7 @@ class AsyncAdapt_dbapi_cursor: self._rows = collections.deque(rr) return retval - def fetchall(self): + def fetchall(self) -> Sequence[Any]: retval = list(self._rows) self._rows.clear() return retval @@ -131,27 +271,18 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor): __slots__ = () server_side = True - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor() - - self._cursor = self.await_(cursor.__aenter__()) - - def close(self): + def close(self) -> None: if self._cursor is not None: self.await_(self._cursor.close()) - self._cursor = None + self._cursor = None # type: ignore - def fetchone(self): + def fetchone(self) -> Optional[Any]: return self.await_(self._cursor.fetchone()) - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> Any: return self.await_(self._cursor.fetchmany(size=size)) - def fetchall(self): + def fetchall(self) -> Sequence[Any]: return self.await_(self._cursor.fetchall()) @@ -162,44 +293,47 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection): await_ = staticmethod(await_only) __slots__ = ("dbapi", "_execute_mutex") - def __init__(self, dbapi, connection): + _connection: AsyncIODBAPIConnection + + def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection): self.dbapi = dbapi self._connection = connection self._execute_mutex = asyncio.Lock() - def ping(self, reconnect): - return self.await_(self._connection.ping(reconnect)) - - def add_output_converter(self, *arg, **kw): - self._connection.add_output_converter(*arg, **kw) - - def character_set_name(self): - return self._connection.character_set_name() - - @property - def autocommit(self): - return self._connection.autocommit - - @autocommit.setter - def autocommit(self, value): - # https://github.com/aio-libs/aioodbc/issues/448 - # self._connection.autocommit = value - - self._connection._conn.autocommit = value - - def cursor(self, server_side=False): + def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor: if server_side: return self._ss_cursor_cls(self) else: return self._cursor_cls(self) - def rollback(self): - self.await_(self._connection.rollback()) - - def commit(self): - self.await_(self._connection.commit()) - - def close(self): + def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: + """lots of DBAPIs seem to provide this, so include it""" + cursor = self.cursor() + cursor.execute(operation, parameters) + return cursor + + def _handle_exception(self, error: Exception) -> NoReturn: + exc_info = sys.exc_info() + + raise error.with_traceback(exc_info[2]) + + def rollback(self) -> None: + try: + self.await_(self._connection.rollback()) + except Exception as error: + self._handle_exception(error) + + def commit(self) -> None: + try: + self.await_(self._connection.commit()) + except Exception as error: + self._handle_exception(error) + + def close(self) -> None: self.await_(self._connection.close()) diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 2a0c6ba783..41f4c09e93 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -30,158 +30,40 @@ This dialect should normally be used only with the from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util -from ...engine import AdaptedConnection -from ...util.concurrency import asyncio +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection from ...util.concurrency import await_fallback from ...util.concurrency import await_only -class AsyncAdapt_aiomysql_cursor: - # TODO: base on connectors/asyncio.py - # see #10415 - server_side = False - __slots__ = ( - "_adapt_connection", - "_connection", - "await_", - "_cursor", - "_rows", - ) - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor(adapt_connection.dbapi.Cursor) - - # see https://github.com/aio-libs/aiomysql/issues/543 - self._cursor = self.await_(cursor.__aenter__()) - self._rows = [] - - @property - def description(self): - return self._cursor.description - - @property - def rowcount(self): - return self._cursor.rowcount +class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor): + __slots__ = () - @property - def arraysize(self): - return self._cursor.arraysize + def _make_new_cursor(self, connection): + return connection.cursor(self._adapt_connection.dbapi.Cursor) - @arraysize.setter - def arraysize(self, value): - self._cursor.arraysize = value - @property - def lastrowid(self): - return self._cursor.lastrowid +class AsyncAdapt_aiomysql_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor +): + __slots__ = () - def close(self): - # note we aren't actually closing the cursor here, - # we are just letting GC do it. to allow this to be async - # we would need the Result to change how it does "Safe close cursor". - # MySQL "cursors" don't actually have state to be "closed" besides - # exhausting rows, which we already have done for sync cursor. - # another option would be to emulate aiosqlite dialect and assign - # cursor only if we are doing server side cursor operation. - self._rows[:] = [] - - def execute(self, operation, parameters=None): - return self.await_(self._execute_async(operation, parameters)) - - def executemany(self, operation, seq_of_parameters): - return self.await_( - self._executemany_async(operation, seq_of_parameters) + def _make_new_cursor(self, connection): + return connection.cursor( + self._adapt_connection.dbapi.aiomysql.cursors.SSCursor ) - async def _execute_async(self, operation, parameters): - async with self._adapt_connection._execute_mutex: - result = await self._cursor.execute(operation, parameters) - - if not self.server_side: - # aiomysql has a "fake" async result, so we have to pull it out - # of that here since our default result is not async. - # we could just as easily grab "_rows" here and be done with it - # but this is safer. - self._rows = list(await self._cursor.fetchall()) - return result - - async def _executemany_async(self, operation, seq_of_parameters): - async with self._adapt_connection._execute_mutex: - return await self._cursor.executemany(operation, seq_of_parameters) - - def setinputsizes(self, *inputsizes): - pass - - def __iter__(self): - while self._rows: - yield self._rows.pop(0) - - def fetchone(self): - if self._rows: - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size=None): - if size is None: - size = self.arraysize - - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval - - def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] - return retval - -class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): - # TODO: base on connectors/asyncio.py - # see #10415 +class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection): __slots__ = () - server_side = True - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor) - - self._cursor = self.await_(cursor.__aenter__()) - - def close(self): - if self._cursor is not None: - self.await_(self._cursor.close()) - self._cursor = None - - def fetchone(self): - return self.await_(self._cursor.fetchone()) - - def fetchmany(self, size=None): - return self.await_(self._cursor.fetchmany(size=size)) - def fetchall(self): - return self.await_(self._cursor.fetchall()) - - -class AsyncAdapt_aiomysql_connection(AdaptedConnection): - # TODO: base on connectors/asyncio.py - # see #10415 - await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_execute_mutex") - - def __init__(self, dbapi, connection): - self.dbapi = dbapi - self._connection = connection - self._execute_mutex = asyncio.Lock() + _cursor_cls = AsyncAdapt_aiomysql_cursor + _ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor def ping(self, reconnect): + assert not reconnect return self.await_(self._connection.ping(reconnect)) def character_set_name(self): @@ -190,30 +72,16 @@ class AsyncAdapt_aiomysql_connection(AdaptedConnection): def autocommit(self, value): self.await_(self._connection.autocommit(value)) - def cursor(self, server_side=False): - if server_side: - return AsyncAdapt_aiomysql_ss_cursor(self) - else: - return AsyncAdapt_aiomysql_cursor(self) - - def rollback(self): - self.await_(self._connection.rollback()) - - def commit(self): - self.await_(self._connection.commit()) - def close(self): # it's not awaitable. self._connection.close() -class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): - # TODO: base on connectors/asyncio.py - # see #10415 +class AsyncAdaptFallback_aiomysql_connection( + AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiomysql_connection +): __slots__ = () - await_ = staticmethod(await_fallback) - class AsyncAdapt_aiomysql_dbapi: def __init__(self, aiomysql, pymysql): diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 92058d60dd..c5caf79d3a 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -25,183 +25,58 @@ This dialect should normally be used only with the """ # noqa -from contextlib import asynccontextmanager +from __future__ import annotations from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util -from ...engine import AdaptedConnection -from ...util.concurrency import asyncio +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection from ...util.concurrency import await_fallback from ...util.concurrency import await_only -class AsyncAdapt_asyncmy_cursor: - # TODO: base on connectors/asyncio.py - # see #10415 - server_side = False - __slots__ = ( - "_adapt_connection", - "_connection", - "await_", - "_cursor", - "_rows", - ) - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor() - - self._cursor = self.await_(cursor.__aenter__()) - self._rows = [] - - @property - def description(self): - return self._cursor.description - - @property - def rowcount(self): - return self._cursor.rowcount - - @property - def arraysize(self): - return self._cursor.arraysize - - @arraysize.setter - def arraysize(self, value): - self._cursor.arraysize = value - - @property - def lastrowid(self): - return self._cursor.lastrowid - - def close(self): - # note we aren't actually closing the cursor here, - # we are just letting GC do it. to allow this to be async - # we would need the Result to change how it does "Safe close cursor". - # MySQL "cursors" don't actually have state to be "closed" besides - # exhausting rows, which we already have done for sync cursor. - # another option would be to emulate aiosqlite dialect and assign - # cursor only if we are doing server side cursor operation. - self._rows[:] = [] - - def execute(self, operation, parameters=None): - return self.await_(self._execute_async(operation, parameters)) - - def executemany(self, operation, seq_of_parameters): - return self.await_( - self._executemany_async(operation, seq_of_parameters) - ) - - async def _execute_async(self, operation, parameters): - async with self._adapt_connection._mutex_and_adapt_errors(): - if parameters is None: - result = await self._cursor.execute(operation) - else: - result = await self._cursor.execute(operation, parameters) - - if not self.server_side: - # asyncmy has a "fake" async result, so we have to pull it out - # of that here since our default result is not async. - # we could just as easily grab "_rows" here and be done with it - # but this is safer. - self._rows = list(await self._cursor.fetchall()) - return result - - async def _executemany_async(self, operation, seq_of_parameters): - async with self._adapt_connection._mutex_and_adapt_errors(): - return await self._cursor.executemany(operation, seq_of_parameters) - - def setinputsizes(self, *inputsizes): - pass - - def __iter__(self): - while self._rows: - yield self._rows.pop(0) - - def fetchone(self): - if self._rows: - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size=None): - if size is None: - size = self.arraysize - - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval - - def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] - return retval +class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor): + __slots__ = () -class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): - # TODO: base on connectors/asyncio.py - # see #10415 +class AsyncAdapt_asyncmy_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor +): __slots__ = () - server_side = True - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor( - adapt_connection.dbapi.asyncmy.cursors.SSCursor + def _make_new_cursor(self, connection): + return connection.cursor( + self._adapt_connection.dbapi.asyncmy.cursors.SSCursor ) - self._cursor = self.await_(cursor.__aenter__()) - - def close(self): - if self._cursor is not None: - self.await_(self._cursor.close()) - self._cursor = None - - def fetchone(self): - return self.await_(self._cursor.fetchone()) - - def fetchmany(self, size=None): - return self.await_(self._cursor.fetchmany(size=size)) - - def fetchall(self): - return self.await_(self._cursor.fetchall()) +class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection): + __slots__ = () -class AsyncAdapt_asyncmy_connection(AdaptedConnection): - # TODO: base on connectors/asyncio.py - # see #10415 - await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_execute_mutex") + _cursor_cls = AsyncAdapt_asyncmy_cursor + _ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor - def __init__(self, dbapi, connection): - self.dbapi = dbapi - self._connection = connection - self._execute_mutex = asyncio.Lock() + def _handle_exception(self, error): + if isinstance(error, AttributeError): + raise self.dbapi.InternalError( + "network operation failed due to asyncmy attribute error" + ) - @asynccontextmanager - async def _mutex_and_adapt_errors(self): - async with self._execute_mutex: - try: - yield - except AttributeError: - raise self.dbapi.InternalError( - "network operation failed due to asyncmy attribute error" - ) + raise error def ping(self, reconnect): assert not reconnect return self.await_(self._do_ping()) async def _do_ping(self): - async with self._mutex_and_adapt_errors(): - return await self._connection.ping(False) + try: + async with self._execute_mutex: + return await self._connection.ping(False) + except Exception as error: + self._handle_exception(error) def character_set_name(self): return self._connection.character_set_name() @@ -209,28 +84,16 @@ class AsyncAdapt_asyncmy_connection(AdaptedConnection): def autocommit(self, value): self.await_(self._connection.autocommit(value)) - def cursor(self, server_side=False): - if server_side: - return AsyncAdapt_asyncmy_ss_cursor(self) - else: - return AsyncAdapt_asyncmy_cursor(self) - - def rollback(self): - self.await_(self._connection.rollback()) - - def commit(self): - self.await_(self._connection.commit()) - def close(self): # it's not awaitable. self._connection.close() -class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): +class AsyncAdaptFallback_asyncmy_connection( + AsyncAdaptFallback_dbapi_connection, AsyncAdapt_asyncmy_connection +): __slots__ = () - await_ = staticmethod(await_fallback) - def _Binary(x): """Return x as a binary type.""" diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index ca35bf9607..d57c94a170 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -187,7 +187,14 @@ import decimal import json as _py_json import re import time +from typing import Any from typing import cast +from typing import Iterable +from typing import NoReturn +from typing import Optional +from typing import Protocol +from typing import Sequence +from typing import Tuple from typing import TYPE_CHECKING from . import json @@ -211,15 +218,16 @@ from .types import CITEXT from ... import exc from ... import pool from ... import util -from ...engine import AdaptedConnection +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...engine import processors from ...sql import sqltypes -from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only if TYPE_CHECKING: - from typing import Iterable + from ...engine.interfaces import _DBAPICursorDescription class AsyncpgARRAY(PGARRAY): @@ -489,33 +497,72 @@ class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): pass -class AsyncAdapt_asyncpg_cursor: +class _AsyncpgConnection(Protocol): + async def executemany( + self, operation: Any, seq_of_parameters: Sequence[Tuple[Any, ...]] + ) -> Any: + ... + + async def reload_schema_state(self) -> None: + ... + + async def prepare( + self, operation: Any, *, name: Optional[str] = None + ) -> Any: + ... + + def is_closed(self) -> bool: + ... + + def transaction( + self, + *, + isolation: Optional[str] = None, + readonly: bool = False, + deferrable: bool = False, + ) -> Any: + ... + + def fetchrow(self, operation: str) -> Any: + ... + + async def close(self) -> None: + ... + + def terminate(self) -> None: + ... + + +class _AsyncpgCursor(Protocol): + def fetch(self, size: int) -> Any: + ... + + +class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): __slots__ = ( - "_adapt_connection", - "_connection", - "_rows", - "description", - "arraysize", - "rowcount", - "_cursor", + "_description", + "_arraysize", + "_rowcount", "_invalidate_schema_cache_asof", ) server_side = False - def __init__(self, adapt_connection): + _adapt_connection: AsyncAdapt_asyncpg_connection + _connection: _AsyncpgConnection + _cursor: Optional[_AsyncpgCursor] + + def __init__(self, adapt_connection: AsyncAdapt_asyncpg_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection - self._rows = [] + self.await_ = adapt_connection.await_ self._cursor = None - self.description = None - self.arraysize = 1 - self.rowcount = -1 + self._rows = collections.deque() + self._description = None + self._arraysize = 1 + self._rowcount = -1 self._invalidate_schema_cache_asof = 0 - def close(self): - self._rows[:] = [] - def _handle_exception(self, error): self._adapt_connection._handle_exception(error) @@ -535,7 +582,7 @@ class AsyncAdapt_asyncpg_cursor: ) if attributes: - self.description = [ + self._description = [ ( attr.name, attr.type.oid, @@ -548,30 +595,48 @@ class AsyncAdapt_asyncpg_cursor: for attr in attributes ] else: - self.description = None + self._description = None if self.server_side: self._cursor = await prepared_stmt.cursor(*parameters) - self.rowcount = -1 + self._rowcount = -1 else: - self._rows = await prepared_stmt.fetch(*parameters) + self._rows = collections.deque( + await prepared_stmt.fetch(*parameters) + ) status = prepared_stmt.get_statusmsg() reg = re.match( r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status ) if reg: - self.rowcount = int(reg.group(1)) + self._rowcount = int(reg.group(1)) else: - self.rowcount = -1 + self._rowcount = -1 except Exception as error: self._handle_exception(error) + @property + def description(self) -> Optional[_DBAPICursorDescription]: + return self._description + + @property + def rowcount(self) -> int: + return self._rowcount + + @property + def arraysize(self) -> int: + return self._arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + self._arraysize = value + async def _executemany(self, operation, seq_of_parameters): adapt_connection = self._adapt_connection - self.description = None + self._description = None async with adapt_connection._execute_mutex: await adapt_connection._check_type_cache_invalidation( self._invalidate_schema_cache_asof @@ -600,31 +665,10 @@ class AsyncAdapt_asyncpg_cursor: def setinputsizes(self, *inputsizes): raise NotImplementedError() - def __iter__(self): - while self._rows: - yield self._rows.pop(0) - - def fetchone(self): - if self._rows: - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size=None): - if size is None: - size = self.arraysize - - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval - - def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] - return retval - -class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): +class AsyncAdapt_asyncpg_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncpg_cursor +): server_side = True __slots__ = ("_rowbuffer",) @@ -637,6 +681,7 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): self._rowbuffer = None def _buffer_rows(self): + assert self._cursor is not None new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) self._rowbuffer = collections.deque(new_rows) @@ -669,6 +714,9 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): if not self._rowbuffer: self._buffer_rows() + assert self._rowbuffer is not None + assert self._cursor is not None + buf = list(self._rowbuffer) lb = len(buf) if size > lb: @@ -681,6 +729,8 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): return result def fetchall(self): + assert self._rowbuffer is not None + ret = list(self._rowbuffer) + list( self._adapt_connection.await_(self._all()) ) @@ -690,6 +740,8 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): async def _all(self): rows = [] + assert self._cursor is not None + # TODO: looks like we have to hand-roll some kind of batching here. # hardcoding for the moment but this should be improved. while True: @@ -707,9 +759,13 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): ) -class AsyncAdapt_asyncpg_connection(AdaptedConnection): +class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): + _cursor_cls = AsyncAdapt_asyncpg_cursor + _ss_cursor_cls = AsyncAdapt_asyncpg_ss_cursor + + _connection: _AsyncpgConnection + __slots__ = ( - "dbapi", "isolation_level", "_isolation_setting", "readonly", @@ -719,11 +775,8 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): "_prepared_statement_cache", "_prepared_statement_name_func", "_invalidate_schema_cache_asof", - "_execute_mutex", ) - await_ = staticmethod(await_only) - def __init__( self, dbapi, @@ -731,15 +784,13 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): prepared_statement_cache_size=100, prepared_statement_name_func=None, ): - self.dbapi = dbapi - self._connection = connection + super().__init__(dbapi, connection) self.isolation_level = self._isolation_setting = "read_committed" self.readonly = False self.deferrable = False self._transaction = None self._started = False self._invalidate_schema_cache_asof = time.time() - self._execute_mutex = asyncio.Lock() if prepared_statement_cache_size: self._prepared_statement_cache = util.LRUCache( @@ -789,7 +840,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): return prepared_stmt, attributes - def _handle_exception(self, error): + def _handle_exception(self, error: Exception) -> NoReturn: if self._connection.is_closed(): self._transaction = None self._started = False @@ -807,9 +858,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): ) = getattr(error, "sqlstate", None) raise translated_error from error else: - raise error + super()._handle_exception(error) else: - raise error + super()._handle_exception(error) @property def autocommit(self): @@ -862,14 +913,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): else: self._started = True - def cursor(self, server_side=False): - if server_side: - return AsyncAdapt_asyncpg_ss_cursor(self) - else: - return AsyncAdapt_asyncpg_cursor(self) - def rollback(self): if self._started: + assert self._transaction is not None try: self.await_(self._transaction.rollback()) except Exception as error: @@ -880,6 +926,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): def commit(self): if self._started: + assert self._transaction is not None try: self.await_(self._transaction.commit()) except Exception as error: diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index dcd69ce663..4856876380 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -53,6 +53,7 @@ The asyncio version of the dialect may also be specified explicitly using the """ # noqa from __future__ import annotations +import collections import logging import re from typing import cast @@ -71,7 +72,10 @@ from .json import JSONPathType from .types import CITEXT from ... import pool from ... import util -from ...engine import AdaptedConnection +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection from ...sql import sqltypes from ...util.concurrency import await_fallback from ...util.concurrency import await_only @@ -492,7 +496,8 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): try: if not before_autocommit: self._do_autocommit(dbapi_conn, True) - dbapi_conn.execute(command) + with dbapi_conn.cursor() as cursor: + cursor.execute(command) finally: if not before_autocommit: self._do_autocommit(dbapi_conn, before_autocommit) @@ -522,93 +527,60 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): return ";" -class AsyncAdapt_psycopg_cursor: - __slots__ = ("_cursor", "await_", "_rows") - - _psycopg_ExecStatus = None - - def __init__(self, cursor, await_) -> None: - self._cursor = cursor - self.await_ = await_ - self._rows = [] - - def __getattr__(self, name): - return getattr(self._cursor, name) - - @property - def arraysize(self): - return self._cursor.arraysize - - @arraysize.setter - def arraysize(self, value): - self._cursor.arraysize = value +class AsyncAdapt_psycopg_cursor(AsyncAdapt_dbapi_cursor): + __slots__ = () def close(self): self._rows.clear() # Normal cursor just call _close() in a non-sync way. self._cursor._close() - def execute(self, query, params=None, **kw): - result = self.await_(self._cursor.execute(query, params, **kw)) + async def _execute_async(self, operation, parameters): + # override to not use mutex, psycopg3 already has mutex + + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + # sqlalchemy result is not async, so need to pull all rows here + # (assuming not a server side cursor) res = self._cursor.pgresult # don't rely on psycopg providing enum symbols, compare with # eq/ne - if res and res.status == self._psycopg_ExecStatus.TUPLES_OK: - rows = self.await_(self._cursor.fetchall()) - if not isinstance(rows, list): - self._rows = list(rows) - else: - self._rows = rows + if ( + not self.server_side + and res + and res.status == self._adapt_connection.dbapi.ExecStatus.TUPLES_OK + ): + self._rows = collections.deque(await self._cursor.fetchall()) return result - def executemany(self, query, params_seq): - return self.await_(self._cursor.executemany(query, params_seq)) - - def __iter__(self): - # TODO: try to avoid pop(0) on a list - while self._rows: - yield self._rows.pop(0) - - def fetchone(self): - if self._rows: - # TODO: try to avoid pop(0) on a list - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size=None): - if size is None: - size = self._cursor.arraysize - - retval = self._rows[0:size] - self._rows = self._rows[size:] - return retval - - def fetchall(self): - retval = self._rows - self._rows = [] - return retval - + async def _executemany_async( + self, + operation, + seq_of_parameters, + ): + # override to not use mutex, psycopg3 already has mutex + return await self._cursor.executemany(operation, seq_of_parameters) -class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor): - def execute(self, query, params=None, **kw): - self.await_(self._cursor.execute(query, params, **kw)) - return self - def close(self): - self.await_(self._cursor.close()) +class AsyncAdapt_psycopg_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_psycopg_cursor +): + __slots__ = ("name",) - def fetchone(self): - return self.await_(self._cursor.fetchone()) + name: str - def fetchmany(self, size=0): - return self.await_(self._cursor.fetchmany(size)) + def __init__(self, adapt_connection, name): + self.name = name + super().__init__(adapt_connection) - def fetchall(self): - return self.await_(self._cursor.fetchall()) + def _make_new_cursor(self, connection): + return connection.cursor(self.name) + # TODO: should this be on the base asyncio adapter? def __iter__(self): iterator = self._cursor.__aiter__() while True: @@ -618,35 +590,38 @@ class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor): break -class AsyncAdapt_psycopg_connection(AdaptedConnection): +class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection): __slots__ = () - await_ = staticmethod(await_only) - def __init__(self, connection) -> None: - self._connection = connection + _cursor_cls = AsyncAdapt_psycopg_cursor + _ss_cursor_cls = AsyncAdapt_psycopg_ss_cursor - def __getattr__(self, name): - return getattr(self._connection, name) + def add_notice_handler(self, handler): + self._connection.add_notice_handler(handler) - def execute(self, query, params=None, **kw): - cursor = self.await_(self._connection.execute(query, params, **kw)) - return AsyncAdapt_psycopg_cursor(cursor, self.await_) + @property + def info(self): + return self._connection.info - def cursor(self, *args, **kw): - cursor = self._connection.cursor(*args, **kw) - if hasattr(cursor, "name"): - return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_) - else: - return AsyncAdapt_psycopg_cursor(cursor, self.await_) + @property + def adapters(self): + return self._connection.adapters + + @property + def closed(self): + return self._connection.closed - def commit(self): - self.await_(self._connection.commit()) + @property + def broken(self): + return self._connection.broken - def rollback(self): - self.await_(self._connection.rollback()) + @property + def read_only(self): + return self._connection.read_only - def close(self): - self.await_(self._connection.close()) + @property + def deferrable(self): + return self._connection.deferrable @property def autocommit(self): @@ -668,15 +643,23 @@ class AsyncAdapt_psycopg_connection(AdaptedConnection): def set_deferrable(self, value): self.await_(self._connection.set_deferrable(value)) + def cursor(self, name=None, /): + if name: + return AsyncAdapt_psycopg_ss_cursor(self, name) + else: + return AsyncAdapt_psycopg_cursor(self) + -class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection): +class AsyncAdaptFallback_psycopg_connection( + AsyncAdaptFallback_dbapi_connection, AsyncAdapt_psycopg_connection +): __slots__ = () - await_ = staticmethod(await_fallback) class PsycopgAdaptDBAPI: - def __init__(self, psycopg) -> None: + def __init__(self, psycopg, ExecStatus) -> None: self.psycopg = psycopg + self.ExecStatus = ExecStatus for k, v in self.psycopg.__dict__.items(): if k != "connect": @@ -689,11 +672,11 @@ class PsycopgAdaptDBAPI: ) if util.asbool(async_fallback): return AsyncAdaptFallback_psycopg_connection( - await_fallback(creator_fn(*arg, **kw)) + self, await_fallback(creator_fn(*arg, **kw)) ) else: return AsyncAdapt_psycopg_connection( - await_only(creator_fn(*arg, **kw)) + self, await_only(creator_fn(*arg, **kw)) ) @@ -706,9 +689,7 @@ class PGDialectAsync_psycopg(PGDialect_psycopg): import psycopg from psycopg.pq import ExecStatus - AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus - - return PsycopgAdaptDBAPI(psycopg) + return PsycopgAdaptDBAPI(psycopg, ExecStatus) @classmethod def get_pool_class(cls, url): diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index d9438d1880..41e406164e 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -84,140 +84,27 @@ from .base import SQLiteExecutionContext from .pysqlite import SQLiteDialect_pysqlite from ... import pool from ... import util -from ...engine import AdaptedConnection +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection from ...util.concurrency import await_fallback from ...util.concurrency import await_only -class AsyncAdapt_aiosqlite_cursor: - # TODO: base on connectors/asyncio.py - # see #10415 - - __slots__ = ( - "_adapt_connection", - "_connection", - "description", - "await_", - "_rows", - "arraysize", - "rowcount", - "lastrowid", - ) - - server_side = False - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - self.arraysize = 1 - self.rowcount = -1 - self.description = None - self._rows = [] - - def close(self): - self._rows[:] = [] - - def execute(self, operation, parameters=None): - try: - _cursor = self.await_(self._connection.cursor()) - - if parameters is None: - self.await_(_cursor.execute(operation)) - else: - self.await_(_cursor.execute(operation, parameters)) - - if _cursor.description: - self.description = _cursor.description - self.lastrowid = self.rowcount = -1 - - if not self.server_side: - self._rows = self.await_(_cursor.fetchall()) - else: - self.description = None - self.lastrowid = _cursor.lastrowid - self.rowcount = _cursor.rowcount - - if not self.server_side: - self.await_(_cursor.close()) - else: - self._cursor = _cursor - except Exception as error: - self._adapt_connection._handle_exception(error) - - def executemany(self, operation, seq_of_parameters): - try: - _cursor = self.await_(self._connection.cursor()) - self.await_(_cursor.executemany(operation, seq_of_parameters)) - self.description = None - self.lastrowid = _cursor.lastrowid - self.rowcount = _cursor.rowcount - self.await_(_cursor.close()) - except Exception as error: - self._adapt_connection._handle_exception(error) - - def setinputsizes(self, *inputsizes): - pass - - def __iter__(self): - while self._rows: - yield self._rows.pop(0) - - def fetchone(self): - if self._rows: - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size=None): - if size is None: - size = self.arraysize - - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval - - def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] - return retval - - -class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor): - # TODO: base on connectors/asyncio.py - # see #10415 - __slots__ = "_cursor" - - server_side = True - - def __init__(self, *arg, **kw): - super().__init__(*arg, **kw) - self._cursor = None - - def close(self): - if self._cursor is not None: - self.await_(self._cursor.close()) - self._cursor = None - - def fetchone(self): - return self.await_(self._cursor.fetchone()) +class AsyncAdapt_aiosqlite_cursor(AsyncAdapt_dbapi_cursor): + __slots__ = () - def fetchmany(self, size=None): - if size is None: - size = self.arraysize - return self.await_(self._cursor.fetchmany(size=size)) - def fetchall(self): - return self.await_(self._cursor.fetchall()) +class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_dbapi_ss_cursor): + __slots__ = () -class AsyncAdapt_aiosqlite_connection(AdaptedConnection): - await_ = staticmethod(await_only) - __slots__ = ("dbapi",) +class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): + __slots__ = () - def __init__(self, dbapi, connection): - self.dbapi = dbapi - self._connection = connection + _cursor_cls = AsyncAdapt_aiosqlite_cursor + _ss_cursor_cls = AsyncAdapt_aiosqlite_ss_cursor @property def isolation_level(self): @@ -249,26 +136,13 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): except Exception as error: self._handle_exception(error) - def cursor(self, server_side=False): - if server_side: - return AsyncAdapt_aiosqlite_ss_cursor(self) - else: - return AsyncAdapt_aiosqlite_cursor(self) - - def execute(self, *args, **kw): - return self.await_(self._connection.execute(*args, **kw)) - def rollback(self): - try: - self.await_(self._connection.rollback()) - except Exception as error: - self._handle_exception(error) + if self._connection._connection: + super().rollback() def commit(self): - try: - self.await_(self._connection.commit()) - except Exception as error: - self._handle_exception(error) + if self._connection._connection: + super().commit() def close(self): try: @@ -287,22 +161,20 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): self._handle_exception(error) def _handle_exception(self, error): - if ( - isinstance(error, ValueError) - and error.args[0] == "no active connection" + if isinstance(error, ValueError) and error.args[0].lower() in ( + "no active connection", + "connection closed", ): - raise self.dbapi.sqlite.OperationalError( - "no active connection" - ) from error + raise self.dbapi.sqlite.OperationalError(error.args[0]) from error else: - raise error + super()._handle_exception(error) -class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection): +class AsyncAdaptFallback_aiosqlite_connection( + AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiosqlite_connection +): __slots__ = () - await_ = staticmethod(await_fallback) - class AsyncAdapt_aiosqlite_dbapi: def __init__(self, aiosqlite, sqlite): @@ -382,10 +254,13 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): return pool.StaticPool def is_disconnect(self, e, connection, cursor): - if isinstance( - e, self.dbapi.OperationalError - ) and "no active connection" in str(e): - return True + if isinstance(e, self.dbapi.OperationalError): + err_lower = str(e).lower() + if ( + "no active connection" in err_lower + or "connection closed" in err_lower + ): + return True return super().is_disconnect(e, connection, cursor)