]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
adapt all asyncio dialects to asyncio connector
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Nov 2023 20:52:43 +0000 (15:52 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Nov 2023 16:44:47 +0000 (11:44 -0500)
Adapted all asyncio dialects, including aiosqlite, aiomysql, asyncmy,
psycopg, asyncpg to use the generic asyncio connection adapter first added
in :ticket:`6521` for the aioodbc DBAPI, allowing these dialects to take
advantage of a common framework.

Fixes: #10415
Change-Id: I24123175aa787f3a2c550d9e02d3827173794e3b

doc/build/changelog/unreleased_21/10415.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/aioodbc.py
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py

diff --git a/doc/build/changelog/unreleased_21/10415.rst b/doc/build/changelog/unreleased_21/10415.rst
new file mode 100644 (file)
index 0000000..ee96c2d
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: change, asyncio
+    :tickets: 10415
+
+    Adapted all asyncio dialects, including aiosqlite, aiomysql, asyncmy,
+    psycopg, asyncpg to use the generic asyncio connection adapter first added
+    in :ticket:`6521` for the aioodbc DBAPI, allowing these dialects to take
+    advantage of a common framework.
index c6986366e1c3f9b4b6c83d7952cbdcdc497f9a31..e0f5f55474fe460397295ecc834fb2888d14145a 100644 (file)
@@ -58,6 +58,15 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection):
 
         self._connection._conn.autocommit = value
 
+    def ping(self, reconnect):
+        return self.await_(self._connection.ping(reconnect))
+
+    def add_output_converter(self, *arg, **kw):
+        self._connection.add_output_converter(*arg, **kw)
+
+    def character_set_name(self):
+        return self._connection.character_set_name()
+
     def cursor(self, server_side=False):
         # aioodbc sets connection=None when closed and just fails with
         # AttributeError here.  Here we use the same ProgrammingError +
index 997407ccd58429ba671dc7685ea22b1eadf7d961..9358457ceb26caf3e0d8541ca2a2f9f575317d46 100644 (file)
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """generic asyncio-adapted versions of DBAPI connection and cursor"""
 
 from __future__ import annotations
 
+import asyncio
 import collections
 import itertools
+import sys
+from typing import Any
+from typing import Deque
+from typing import Iterator
+from typing import NoReturn
+from typing import Optional
+from typing import Protocol
+from typing import Sequence
 
 from ..engine import AdaptedConnection
-from ..util.concurrency import asyncio
+from ..engine.interfaces import _DBAPICursorDescription
+from ..engine.interfaces import _DBAPIMultiExecuteParams
+from ..engine.interfaces import _DBAPISingleExecuteParams
 from ..util.concurrency import await_fallback
 from ..util.concurrency import await_only
+from ..util.typing import Self
+
+
+class AsyncIODBAPIConnection(Protocol):
+    """protocol representing an async adapted version of a
+    :pep:`249` database connection.
+
+
+    """
+
+    async def close(self) -> None:
+        ...
+
+    async def commit(self) -> None:
+        ...
+
+    def cursor(self) -> AsyncIODBAPICursor:
+        ...
+
+    async def rollback(self) -> None:
+        ...
+
+
+class AsyncIODBAPICursor(Protocol):
+    """protocol representing an async adapted version
+    of a :pep:`249` database cursor.
+
+
+    """
+
+    def __aenter__(self) -> Any:
+        ...
+
+    @property
+    def description(
+        self,
+    ) -> _DBAPICursorDescription:
+        """The description attribute of the Cursor."""
+        ...
+
+    @property
+    def rowcount(self) -> int:
+        ...
+
+    arraysize: int
+
+    lastrowid: int
+
+    async def close(self) -> None:
+        ...
+
+    async def execute(
+        self,
+        operation: Any,
+        parameters: Optional[_DBAPISingleExecuteParams] = None,
+    ) -> Any:
+        ...
+
+    async def executemany(
+        self,
+        operation: Any,
+        parameters: _DBAPIMultiExecuteParams,
+    ) -> Any:
+        ...
+
+    async def fetchone(self) -> Optional[Any]:
+        ...
+
+    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]:
+        ...
+
+    async def fetchall(self) -> Sequence[Any]:
+        ...
+
+    async def setinputsizes(self, sizes: Sequence[Any]) -> None:
+        ...
+
+    def setoutputsize(self, size: Any, column: Any) -> None:
+        ...
+
+    async def callproc(
+        self, procname: str, parameters: Sequence[Any] = ...
+    ) -> Any:
+        ...
+
+    async def nextset(self) -> Optional[bool]:
+        ...
 
 
 class AsyncAdapt_dbapi_cursor:
@@ -29,52 +126,85 @@ class AsyncAdapt_dbapi_cursor:
         "_rows",
     )
 
-    def __init__(self, adapt_connection):
+    _cursor: AsyncIODBAPICursor
+    _adapt_connection: AsyncAdapt_dbapi_connection
+    _connection: AsyncIODBAPIConnection
+    _rows: Deque[Any]
+
+    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
         self._adapt_connection = adapt_connection
         self._connection = adapt_connection._connection
         self.await_ = adapt_connection.await_
 
-        cursor = self._connection.cursor()
+        cursor = self._make_new_cursor(self._connection)
+
+        try:
+            self._cursor = self.await_(cursor.__aenter__())
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)
 
-        self._cursor = self.await_(cursor.__aenter__())
         self._rows = collections.deque()
 
+    def _make_new_cursor(
+        self, connection: AsyncIODBAPIConnection
+    ) -> AsyncIODBAPICursor:
+        return connection.cursor()
+
     @property
-    def description(self):
+    def description(self) -> Optional[_DBAPICursorDescription]:
         return self._cursor.description
 
     @property
-    def rowcount(self):
+    def rowcount(self) -> int:
         return self._cursor.rowcount
 
     @property
-    def arraysize(self):
+    def arraysize(self) -> int:
         return self._cursor.arraysize
 
     @arraysize.setter
-    def arraysize(self, value):
+    def arraysize(self, value: int) -> None:
         self._cursor.arraysize = value
 
     @property
-    def lastrowid(self):
+    def lastrowid(self) -> int:
         return self._cursor.lastrowid
 
-    def close(self):
+    def close(self) -> None:
         # note we aren't actually closing the cursor here,
         # we are just letting GC do it.  see notes in aiomysql dialect
         self._rows.clear()
 
-    def execute(self, operation, parameters=None):
-        return self.await_(self._execute_async(operation, parameters))
-
-    def executemany(self, operation, seq_of_parameters):
-        return self.await_(
-            self._executemany_async(operation, seq_of_parameters)
-        )
+    def execute(
+        self,
+        operation: Any,
+        parameters: Optional[_DBAPISingleExecuteParams] = None,
+    ) -> Any:
+        try:
+            return self.await_(self._execute_async(operation, parameters))
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)
+
+    def executemany(
+        self,
+        operation: Any,
+        seq_of_parameters: _DBAPIMultiExecuteParams,
+    ) -> Any:
+        try:
+            return self.await_(
+                self._executemany_async(operation, seq_of_parameters)
+            )
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)
 
-    async def _execute_async(self, operation, parameters):
+    async def _execute_async(
+        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
+    ) -> Any:
         async with self._adapt_connection._execute_mutex:
-            result = await self._cursor.execute(operation, parameters or ())
+            if parameters is None:
+                result = await self._cursor.execute(operation)
+            else:
+                result = await self._cursor.execute(operation, parameters)
 
             if self._cursor.description and not self.server_side:
                 # aioodbc has a "fake" async result, so we have to pull it out
@@ -84,35 +214,45 @@ class AsyncAdapt_dbapi_cursor:
                 self._rows = collections.deque(await self._cursor.fetchall())
             return result
 
-    async def _executemany_async(self, operation, seq_of_parameters):
+    async def _executemany_async(
+        self,
+        operation: Any,
+        seq_of_parameters: _DBAPIMultiExecuteParams,
+    ) -> Any:
         async with self._adapt_connection._execute_mutex:
             return await self._cursor.executemany(operation, seq_of_parameters)
 
-    def nextset(self):
+    def nextset(self) -> None:
         self.await_(self._cursor.nextset())
         if self._cursor.description and not self.server_side:
             self._rows = collections.deque(
                 self.await_(self._cursor.fetchall())
             )
 
-    def setinputsizes(self, *inputsizes):
+    def setinputsizes(self, *inputsizes: Any) -> None:
         # NOTE: this is overrridden in aioodbc due to
         # see https://github.com/aio-libs/aioodbc/issues/451
         # right now
 
         return self.await_(self._cursor.setinputsizes(*inputsizes))
 
-    def __iter__(self):
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
+        self.close()
+
+    def __iter__(self) -> Iterator[Any]:
         while self._rows:
             yield self._rows.popleft()
 
-    def fetchone(self):
+    def fetchone(self) -> Optional[Any]:
         if self._rows:
             return self._rows.popleft()
         else:
             return None
 
-    def fetchmany(self, size=None):
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
         if size is None:
             size = self.arraysize
 
@@ -121,7 +261,7 @@ class AsyncAdapt_dbapi_cursor:
         self._rows = collections.deque(rr)
         return retval
 
-    def fetchall(self):
+    def fetchall(self) -> Sequence[Any]:
         retval = list(self._rows)
         self._rows.clear()
         return retval
@@ -131,27 +271,18 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = ()
     server_side = True
 
-    def __init__(self, adapt_connection):
-        self._adapt_connection = adapt_connection
-        self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
-
-        cursor = self._connection.cursor()
-
-        self._cursor = self.await_(cursor.__aenter__())
-
-    def close(self):
+    def close(self) -> None:
         if self._cursor is not None:
             self.await_(self._cursor.close())
-            self._cursor = None
+            self._cursor = None  # type: ignore
 
-    def fetchone(self):
+    def fetchone(self) -> Optional[Any]:
         return self.await_(self._cursor.fetchone())
 
-    def fetchmany(self, size=None):
+    def fetchmany(self, size: Optional[int] = None) -> Any:
         return self.await_(self._cursor.fetchmany(size=size))
 
-    def fetchall(self):
+    def fetchall(self) -> Sequence[Any]:
         return self.await_(self._cursor.fetchall())
 
 
@@ -162,44 +293,47 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection):
     await_ = staticmethod(await_only)
     __slots__ = ("dbapi", "_execute_mutex")
 
-    def __init__(self, dbapi, connection):
+    _connection: AsyncIODBAPIConnection
+
+    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
         self.dbapi = dbapi
         self._connection = connection
         self._execute_mutex = asyncio.Lock()
 
-    def ping(self, reconnect):
-        return self.await_(self._connection.ping(reconnect))
-
-    def add_output_converter(self, *arg, **kw):
-        self._connection.add_output_converter(*arg, **kw)
-
-    def character_set_name(self):
-        return self._connection.character_set_name()
-
-    @property
-    def autocommit(self):
-        return self._connection.autocommit
-
-    @autocommit.setter
-    def autocommit(self, value):
-        # https://github.com/aio-libs/aioodbc/issues/448
-        # self._connection.autocommit = value
-
-        self._connection._conn.autocommit = value
-
-    def cursor(self, server_side=False):
+    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
         if server_side:
             return self._ss_cursor_cls(self)
         else:
             return self._cursor_cls(self)
 
-    def rollback(self):
-        self.await_(self._connection.rollback())
-
-    def commit(self):
-        self.await_(self._connection.commit())
-
-    def close(self):
+    def execute(
+        self,
+        operation: Any,
+        parameters: Optional[_DBAPISingleExecuteParams] = None,
+    ) -> Any:
+        """lots of DBAPIs seem to provide this, so include it"""
+        cursor = self.cursor()
+        cursor.execute(operation, parameters)
+        return cursor
+
+    def _handle_exception(self, error: Exception) -> NoReturn:
+        exc_info = sys.exc_info()
+
+        raise error.with_traceback(exc_info[2])
+
+    def rollback(self) -> None:
+        try:
+            self.await_(self._connection.rollback())
+        except Exception as error:
+            self._handle_exception(error)
+
+    def commit(self) -> None:
+        try:
+            self.await_(self._connection.commit())
+        except Exception as error:
+            self._handle_exception(error)
+
+    def close(self) -> None:
         self.await_(self._connection.close())
 
 
index 2a0c6ba7832eaa3cb799ec12339acf340aec15bc..41f4c09e9329a385422c3837d4ae83ea5427b338 100644 (file)
@@ -30,158 +30,40 @@ This dialect should normally be used only with the
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
 from ... import util
-from ...engine import AdaptedConnection
-from ...util.concurrency import asyncio
+from ...connectors.asyncio import AsyncAdapt_dbapi_connection
+from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
 
-class AsyncAdapt_aiomysql_cursor:
-    # TODO: base on connectors/asyncio.py
-    # see #10415
-    server_side = False
-    __slots__ = (
-        "_adapt_connection",
-        "_connection",
-        "await_",
-        "_cursor",
-        "_rows",
-    )
-
-    def __init__(self, adapt_connection):
-        self._adapt_connection = adapt_connection
-        self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
-
-        cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)
-
-        # see https://github.com/aio-libs/aiomysql/issues/543
-        self._cursor = self.await_(cursor.__aenter__())
-        self._rows = []
-
-    @property
-    def description(self):
-        return self._cursor.description
-
-    @property
-    def rowcount(self):
-        return self._cursor.rowcount
+class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
+    __slots__ = ()
 
-    @property
-    def arraysize(self):
-        return self._cursor.arraysize
+    def _make_new_cursor(self, connection):
+        return connection.cursor(self._adapt_connection.dbapi.Cursor)
 
-    @arraysize.setter
-    def arraysize(self, value):
-        self._cursor.arraysize = value
 
-    @property
-    def lastrowid(self):
-        return self._cursor.lastrowid
+class AsyncAdapt_aiomysql_ss_cursor(
+    AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor
+):
+    __slots__ = ()
 
-    def close(self):
-        # note we aren't actually closing the cursor here,
-        # we are just letting GC do it.   to allow this to be async
-        # we would need the Result to change how it does "Safe close cursor".
-        # MySQL "cursors" don't actually have state to be "closed" besides
-        # exhausting rows, which we already have done for sync cursor.
-        # another option would be to emulate aiosqlite dialect and assign
-        # cursor only if we are doing server side cursor operation.
-        self._rows[:] = []
-
-    def execute(self, operation, parameters=None):
-        return self.await_(self._execute_async(operation, parameters))
-
-    def executemany(self, operation, seq_of_parameters):
-        return self.await_(
-            self._executemany_async(operation, seq_of_parameters)
+    def _make_new_cursor(self, connection):
+        return connection.cursor(
+            self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
         )
 
-    async def _execute_async(self, operation, parameters):
-        async with self._adapt_connection._execute_mutex:
-            result = await self._cursor.execute(operation, parameters)
-
-            if not self.server_side:
-                # aiomysql has a "fake" async result, so we have to pull it out
-                # of that here since our default result is not async.
-                # we could just as easily grab "_rows" here and be done with it
-                # but this is safer.
-                self._rows = list(await self._cursor.fetchall())
-            return result
-
-    async def _executemany_async(self, operation, seq_of_parameters):
-        async with self._adapt_connection._execute_mutex:
-            return await self._cursor.executemany(operation, seq_of_parameters)
-
-    def setinputsizes(self, *inputsizes):
-        pass
-
-    def __iter__(self):
-        while self._rows:
-            yield self._rows.pop(0)
-
-    def fetchone(self):
-        if self._rows:
-            return self._rows.pop(0)
-        else:
-            return None
-
-    def fetchmany(self, size=None):
-        if size is None:
-            size = self.arraysize
-
-        retval = self._rows[0:size]
-        self._rows[:] = self._rows[size:]
-        return retval
-
-    def fetchall(self):
-        retval = self._rows[:]
-        self._rows[:] = []
-        return retval
 
-
-class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
-    # TODO: base on connectors/asyncio.py
-    # see #10415
+class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
     __slots__ = ()
-    server_side = True
-
-    def __init__(self, adapt_connection):
-        self._adapt_connection = adapt_connection
-        self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
-
-        cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)
-
-        self._cursor = self.await_(cursor.__aenter__())
-
-    def close(self):
-        if self._cursor is not None:
-            self.await_(self._cursor.close())
-            self._cursor = None
-
-    def fetchone(self):
-        return self.await_(self._cursor.fetchone())
-
-    def fetchmany(self, size=None):
-        return self.await_(self._cursor.fetchmany(size=size))
 
-    def fetchall(self):
-        return self.await_(self._cursor.fetchall())
-
-
-class AsyncAdapt_aiomysql_connection(AdaptedConnection):
-    # TODO: base on connectors/asyncio.py
-    # see #10415
-    await_ = staticmethod(await_only)
-    __slots__ = ("dbapi", "_execute_mutex")
-
-    def __init__(self, dbapi, connection):
-        self.dbapi = dbapi
-        self._connection = connection
-        self._execute_mutex = asyncio.Lock()
+    _cursor_cls = AsyncAdapt_aiomysql_cursor
+    _ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
 
     def ping(self, reconnect):
+        assert not reconnect
         return self.await_(self._connection.ping(reconnect))
 
     def character_set_name(self):
@@ -190,30 +72,16 @@ class AsyncAdapt_aiomysql_connection(AdaptedConnection):
     def autocommit(self, value):
         self.await_(self._connection.autocommit(value))
 
-    def cursor(self, server_side=False):
-        if server_side:
-            return AsyncAdapt_aiomysql_ss_cursor(self)
-        else:
-            return AsyncAdapt_aiomysql_cursor(self)
-
-    def rollback(self):
-        self.await_(self._connection.rollback())
-
-    def commit(self):
-        self.await_(self._connection.commit())
-
     def close(self):
         # it's not awaitable.
         self._connection.close()
 
 
-class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
-    # TODO: base on connectors/asyncio.py
-    # see #10415
+class AsyncAdaptFallback_aiomysql_connection(
+    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiomysql_connection
+):
     __slots__ = ()
 
-    await_ = staticmethod(await_fallback)
-
 
 class AsyncAdapt_aiomysql_dbapi:
     def __init__(self, aiomysql, pymysql):
index 92058d60dd39ab2f95b02d4d7351e9767bb43dfd..c5caf79d3ab2c64691cf89396664c379c9380da2 100644 (file)
@@ -25,183 +25,58 @@ This dialect should normally be used only with the
 
 
 """  # noqa
-from contextlib import asynccontextmanager
+from __future__ import annotations
 
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
 from ... import util
-from ...engine import AdaptedConnection
-from ...util.concurrency import asyncio
+from ...connectors.asyncio import AsyncAdapt_dbapi_connection
+from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
 
-class AsyncAdapt_asyncmy_cursor:
-    # TODO: base on connectors/asyncio.py
-    # see #10415
-    server_side = False
-    __slots__ = (
-        "_adapt_connection",
-        "_connection",
-        "await_",
-        "_cursor",
-        "_rows",
-    )
-
-    def __init__(self, adapt_connection):
-        self._adapt_connection = adapt_connection
-        self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
-
-        cursor = self._connection.cursor()
-
-        self._cursor = self.await_(cursor.__aenter__())
-        self._rows = []
-
-    @property
-    def description(self):
-        return self._cursor.description
-
-    @property
-    def rowcount(self):
-        return self._cursor.rowcount
-
-    @property
-    def arraysize(self):
-        return self._cursor.arraysize
-
-    @arraysize.setter
-    def arraysize(self, value):
-        self._cursor.arraysize = value
-
-    @property
-    def lastrowid(self):
-        return self._cursor.lastrowid
-
-    def close(self):
-        # note we aren't actually closing the cursor here,
-        # we are just letting GC do it.   to allow this to be async
-        # we would need the Result to change how it does "Safe close cursor".
-        # MySQL "cursors" don't actually have state to be "closed" besides
-        # exhausting rows, which we already have done for sync cursor.
-        # another option would be to emulate aiosqlite dialect and assign
-        # cursor only if we are doing server side cursor operation.
-        self._rows[:] = []
-
-    def execute(self, operation, parameters=None):
-        return self.await_(self._execute_async(operation, parameters))
-
-    def executemany(self, operation, seq_of_parameters):
-        return self.await_(
-            self._executemany_async(operation, seq_of_parameters)
-        )
-
-    async def _execute_async(self, operation, parameters):
-        async with self._adapt_connection._mutex_and_adapt_errors():
-            if parameters is None:
-                result = await self._cursor.execute(operation)
-            else:
-                result = await self._cursor.execute(operation, parameters)
-
-            if not self.server_side:
-                # asyncmy has a "fake" async result, so we have to pull it out
-                # of that here since our default result is not async.
-                # we could just as easily grab "_rows" here and be done with it
-                # but this is safer.
-                self._rows = list(await self._cursor.fetchall())
-            return result
-
-    async def _executemany_async(self, operation, seq_of_parameters):
-        async with self._adapt_connection._mutex_and_adapt_errors():
-            return await self._cursor.executemany(operation, seq_of_parameters)
-
-    def setinputsizes(self, *inputsizes):
-        pass
-
-    def __iter__(self):
-        while self._rows:
-            yield self._rows.pop(0)
-
-    def fetchone(self):
-        if self._rows:
-            return self._rows.pop(0)
-        else:
-            return None
-
-    def fetchmany(self, size=None):
-        if size is None:
-            size = self.arraysize
-
-        retval = self._rows[0:size]
-        self._rows[:] = self._rows[size:]
-        return retval
-
-    def fetchall(self):
-        retval = self._rows[:]
-        self._rows[:] = []
-        return retval
+class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
+    __slots__ = ()
 
 
-class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
-    # TODO: base on connectors/asyncio.py
-    # see #10415
+class AsyncAdapt_asyncmy_ss_cursor(
+    AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor
+):
     __slots__ = ()
-    server_side = True
 
-    def __init__(self, adapt_connection):
-        self._adapt_connection = adapt_connection
-        self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
-
-        cursor = self._connection.cursor(
-            adapt_connection.dbapi.asyncmy.cursors.SSCursor
+    def _make_new_cursor(self, connection):
+        return connection.cursor(
+            self._adapt_connection.dbapi.asyncmy.cursors.SSCursor
         )
 
-        self._cursor = self.await_(cursor.__aenter__())
-
-    def close(self):
-        if self._cursor is not None:
-            self.await_(self._cursor.close())
-            self._cursor = None
-
-    def fetchone(self):
-        return self.await_(self._cursor.fetchone())
-
-    def fetchmany(self, size=None):
-        return self.await_(self._cursor.fetchmany(size=size))
-
-    def fetchall(self):
-        return self.await_(self._cursor.fetchall())
 
+class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
+    __slots__ = ()
 
-class AsyncAdapt_asyncmy_connection(AdaptedConnection):
-    # TODO: base on connectors/asyncio.py
-    # see #10415
-    await_ = staticmethod(await_only)
-    __slots__ = ("dbapi", "_execute_mutex")
+    _cursor_cls = AsyncAdapt_asyncmy_cursor
+    _ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
 
-    def __init__(self, dbapi, connection):
-        self.dbapi = dbapi
-        self._connection = connection
-        self._execute_mutex = asyncio.Lock()
+    def _handle_exception(self, error):
+        if isinstance(error, AttributeError):
+            raise self.dbapi.InternalError(
+                "network operation failed due to asyncmy attribute error"
+            )
 
-    @asynccontextmanager
-    async def _mutex_and_adapt_errors(self):
-        async with self._execute_mutex:
-            try:
-                yield
-            except AttributeError:
-                raise self.dbapi.InternalError(
-                    "network operation failed due to asyncmy attribute error"
-                )
+        raise error
 
     def ping(self, reconnect):
         assert not reconnect
         return self.await_(self._do_ping())
 
     async def _do_ping(self):
-        async with self._mutex_and_adapt_errors():
-            return await self._connection.ping(False)
+        try:
+            async with self._execute_mutex:
+                return await self._connection.ping(False)
+        except Exception as error:
+            self._handle_exception(error)
 
     def character_set_name(self):
         return self._connection.character_set_name()
@@ -209,28 +84,16 @@ class AsyncAdapt_asyncmy_connection(AdaptedConnection):
     def autocommit(self, value):
         self.await_(self._connection.autocommit(value))
 
-    def cursor(self, server_side=False):
-        if server_side:
-            return AsyncAdapt_asyncmy_ss_cursor(self)
-        else:
-            return AsyncAdapt_asyncmy_cursor(self)
-
-    def rollback(self):
-        self.await_(self._connection.rollback())
-
-    def commit(self):
-        self.await_(self._connection.commit())
-
     def close(self):
         # it's not awaitable.
         self._connection.close()
 
 
-class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
+class AsyncAdaptFallback_asyncmy_connection(
+    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_asyncmy_connection
+):
     __slots__ = ()
 
-    await_ = staticmethod(await_fallback)
-
 
 def _Binary(x):
     """Return x as a binary type."""
index ca35bf96075d249d679243e9df922bdd2f4cfc28..d57c94a170f9c3d085be8cc4e8ba9a473b808b27 100644 (file)
@@ -187,7 +187,14 @@ import decimal
 import json as _py_json
 import re
 import time
+from typing import Any
 from typing import cast
+from typing import Iterable
+from typing import NoReturn
+from typing import Optional
+from typing import Protocol
+from typing import Sequence
+from typing import Tuple
 from typing import TYPE_CHECKING
 
 from . import json
@@ -211,15 +218,16 @@ from .types import CITEXT
 from ... import exc
 from ... import pool
 from ... import util
-from ...engine import AdaptedConnection
+from ...connectors.asyncio import AsyncAdapt_dbapi_connection
+from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
 from ...engine import processors
 from ...sql import sqltypes
-from ...util.concurrency import asyncio
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
 if TYPE_CHECKING:
-    from typing import Iterable
+    from ...engine.interfaces import _DBAPICursorDescription
 
 
 class AsyncpgARRAY(PGARRAY):
@@ -489,33 +497,72 @@ class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
     pass
 
 
-class AsyncAdapt_asyncpg_cursor:
+class _AsyncpgConnection(Protocol):
+    async def executemany(
+        self, operation: Any, seq_of_parameters: Sequence[Tuple[Any, ...]]
+    ) -> Any:
+        ...
+
+    async def reload_schema_state(self) -> None:
+        ...
+
+    async def prepare(
+        self, operation: Any, *, name: Optional[str] = None
+    ) -> Any:
+        ...
+
+    def is_closed(self) -> bool:
+        ...
+
+    def transaction(
+        self,
+        *,
+        isolation: Optional[str] = None,
+        readonly: bool = False,
+        deferrable: bool = False,
+    ) -> Any:
+        ...
+
+    def fetchrow(self, operation: str) -> Any:
+        ...
+
+    async def close(self) -> None:
+        ...
+
+    def terminate(self) -> None:
+        ...
+
+
+class _AsyncpgCursor(Protocol):
+    def fetch(self, size: int) -> Any:
+        ...
+
+
+class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = (
-        "_adapt_connection",
-        "_connection",
-        "_rows",
-        "description",
-        "arraysize",
-        "rowcount",
-        "_cursor",
+        "_description",
+        "_arraysize",
+        "_rowcount",
         "_invalidate_schema_cache_asof",
     )
 
     server_side = False
 
-    def __init__(self, adapt_connection):
+    _adapt_connection: AsyncAdapt_asyncpg_connection
+    _connection: _AsyncpgConnection
+    _cursor: Optional[_AsyncpgCursor]
+
+    def __init__(self, adapt_connection: AsyncAdapt_asyncpg_connection):
         self._adapt_connection = adapt_connection
         self._connection = adapt_connection._connection
-        self._rows = []
+        self.await_ = adapt_connection.await_
         self._cursor = None
-        self.description = None
-        self.arraysize = 1
-        self.rowcount = -1
+        self._rows = collections.deque()
+        self._description = None
+        self._arraysize = 1
+        self._rowcount = -1
         self._invalidate_schema_cache_asof = 0
 
-    def close(self):
-        self._rows[:] = []
-
     def _handle_exception(self, error):
         self._adapt_connection._handle_exception(error)
 
@@ -535,7 +582,7 @@ class AsyncAdapt_asyncpg_cursor:
                 )
 
                 if attributes:
-                    self.description = [
+                    self._description = [
                         (
                             attr.name,
                             attr.type.oid,
@@ -548,30 +595,48 @@ class AsyncAdapt_asyncpg_cursor:
                         for attr in attributes
                     ]
                 else:
-                    self.description = None
+                    self._description = None
 
                 if self.server_side:
                     self._cursor = await prepared_stmt.cursor(*parameters)
-                    self.rowcount = -1
+                    self._rowcount = -1
                 else:
-                    self._rows = await prepared_stmt.fetch(*parameters)
+                    self._rows = collections.deque(
+                        await prepared_stmt.fetch(*parameters)
+                    )
                     status = prepared_stmt.get_statusmsg()
 
                     reg = re.match(
                         r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status
                     )
                     if reg:
-                        self.rowcount = int(reg.group(1))
+                        self._rowcount = int(reg.group(1))
                     else:
-                        self.rowcount = -1
+                        self._rowcount = -1
 
             except Exception as error:
                 self._handle_exception(error)
 
+    @property
+    def description(self) -> Optional[_DBAPICursorDescription]:
+        return self._description
+
+    @property
+    def rowcount(self) -> int:
+        return self._rowcount
+
+    @property
+    def arraysize(self) -> int:
+        return self._arraysize
+
+    @arraysize.setter
+    def arraysize(self, value: int) -> None:
+        self._arraysize = value
+
     async def _executemany(self, operation, seq_of_parameters):
         adapt_connection = self._adapt_connection
 
-        self.description = None
+        self._description = None
         async with adapt_connection._execute_mutex:
             await adapt_connection._check_type_cache_invalidation(
                 self._invalidate_schema_cache_asof
@@ -600,31 +665,10 @@ class AsyncAdapt_asyncpg_cursor:
     def setinputsizes(self, *inputsizes):
         raise NotImplementedError()
 
-    def __iter__(self):
-        while self._rows:
-            yield self._rows.pop(0)
-
-    def fetchone(self):
-        if self._rows:
-            return self._rows.pop(0)
-        else:
-            return None
-
-    def fetchmany(self, size=None):
-        if size is None:
-            size = self.arraysize
-
-        retval = self._rows[0:size]
-        self._rows[:] = self._rows[size:]
-        return retval
-
-    def fetchall(self):
-        retval = self._rows[:]
-        self._rows[:] = []
-        return retval
 
-
-class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
+class AsyncAdapt_asyncpg_ss_cursor(
+    AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncpg_cursor
+):
     server_side = True
     __slots__ = ("_rowbuffer",)
 
@@ -637,6 +681,7 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
         self._rowbuffer = None
 
     def _buffer_rows(self):
+        assert self._cursor is not None
         new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
         self._rowbuffer = collections.deque(new_rows)
 
@@ -669,6 +714,9 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
         if not self._rowbuffer:
             self._buffer_rows()
 
+        assert self._rowbuffer is not None
+        assert self._cursor is not None
+
         buf = list(self._rowbuffer)
         lb = len(buf)
         if size > lb:
@@ -681,6 +729,8 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
         return result
 
     def fetchall(self):
+        assert self._rowbuffer is not None
+
         ret = list(self._rowbuffer) + list(
             self._adapt_connection.await_(self._all())
         )
@@ -690,6 +740,8 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
     async def _all(self):
         rows = []
 
+        assert self._cursor is not None
+
         # TODO: looks like we have to hand-roll some kind of batching here.
         # hardcoding for the moment but this should be improved.
         while True:
@@ -707,9 +759,13 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
         )
 
 
-class AsyncAdapt_asyncpg_connection(AdaptedConnection):
+class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
+    _cursor_cls = AsyncAdapt_asyncpg_cursor
+    _ss_cursor_cls = AsyncAdapt_asyncpg_ss_cursor
+
+    _connection: _AsyncpgConnection
+
     __slots__ = (
-        "dbapi",
         "isolation_level",
         "_isolation_setting",
         "readonly",
@@ -719,11 +775,8 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
         "_prepared_statement_cache",
         "_prepared_statement_name_func",
         "_invalidate_schema_cache_asof",
-        "_execute_mutex",
     )
 
-    await_ = staticmethod(await_only)
-
     def __init__(
         self,
         dbapi,
@@ -731,15 +784,13 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
         prepared_statement_cache_size=100,
         prepared_statement_name_func=None,
     ):
-        self.dbapi = dbapi
-        self._connection = connection
+        super().__init__(dbapi, connection)
         self.isolation_level = self._isolation_setting = "read_committed"
         self.readonly = False
         self.deferrable = False
         self._transaction = None
         self._started = False
         self._invalidate_schema_cache_asof = time.time()
-        self._execute_mutex = asyncio.Lock()
 
         if prepared_statement_cache_size:
             self._prepared_statement_cache = util.LRUCache(
@@ -789,7 +840,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
 
         return prepared_stmt, attributes
 
-    def _handle_exception(self, error):
+    def _handle_exception(self, error: Exception) -> NoReturn:
         if self._connection.is_closed():
             self._transaction = None
             self._started = False
@@ -807,9 +858,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
                     ) = getattr(error, "sqlstate", None)
                     raise translated_error from error
             else:
-                raise error
+                super()._handle_exception(error)
         else:
-            raise error
+            super()._handle_exception(error)
 
     @property
     def autocommit(self):
@@ -862,14 +913,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
         else:
             self._started = True
 
-    def cursor(self, server_side=False):
-        if server_side:
-            return AsyncAdapt_asyncpg_ss_cursor(self)
-        else:
-            return AsyncAdapt_asyncpg_cursor(self)
-
     def rollback(self):
         if self._started:
+            assert self._transaction is not None
             try:
                 self.await_(self._transaction.rollback())
             except Exception as error:
@@ -880,6 +926,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
 
     def commit(self):
         if self._started:
+            assert self._transaction is not None
             try:
                 self.await_(self._transaction.commit())
             except Exception as error:
index dcd69ce6631fc8a1d58ddcd4f699015adf0ae12d..485687638050bc33c2075cead3bb6275b643b696 100644 (file)
@@ -53,6 +53,7 @@ The asyncio version of the dialect may also be specified explicitly using the
 """  # noqa
 from __future__ import annotations
 
+import collections
 import logging
 import re
 from typing import cast
@@ -71,7 +72,10 @@ from .json import JSONPathType
 from .types import CITEXT
 from ... import pool
 from ... import util
-from ...engine import AdaptedConnection
+from ...connectors.asyncio import AsyncAdapt_dbapi_connection
+from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
 from ...sql import sqltypes
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
@@ -492,7 +496,8 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
         try:
             if not before_autocommit:
                 self._do_autocommit(dbapi_conn, True)
-            dbapi_conn.execute(command)
+            with dbapi_conn.cursor() as cursor:
+                cursor.execute(command)
         finally:
             if not before_autocommit:
                 self._do_autocommit(dbapi_conn, before_autocommit)
@@ -522,93 +527,60 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
         return ";"
 
 
-class AsyncAdapt_psycopg_cursor:
-    __slots__ = ("_cursor", "await_", "_rows")
-
-    _psycopg_ExecStatus = None
-
-    def __init__(self, cursor, await_) -> None:
-        self._cursor = cursor
-        self.await_ = await_
-        self._rows = []
-
-    def __getattr__(self, name):
-        return getattr(self._cursor, name)
-
-    @property
-    def arraysize(self):
-        return self._cursor.arraysize
-
-    @arraysize.setter
-    def arraysize(self, value):
-        self._cursor.arraysize = value
+class AsyncAdapt_psycopg_cursor(AsyncAdapt_dbapi_cursor):
+    __slots__ = ()
 
     def close(self):
         self._rows.clear()
         # Normal cursor just call _close() in a non-sync way.
         self._cursor._close()
 
-    def execute(self, query, params=None, **kw):
-        result = self.await_(self._cursor.execute(query, params, **kw))
+    async def _execute_async(self, operation, parameters):
+        # override to not use mutex, psycopg3 already has mutex
+
+        if parameters is None:
+            result = await self._cursor.execute(operation)
+        else:
+            result = await self._cursor.execute(operation, parameters)
+
         # sqlalchemy result is not async, so need to pull all rows here
+        # (assuming not a server side cursor)
         res = self._cursor.pgresult
 
         # don't rely on psycopg providing enum symbols, compare with
         # eq/ne
-        if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
-            rows = self.await_(self._cursor.fetchall())
-            if not isinstance(rows, list):
-                self._rows = list(rows)
-            else:
-                self._rows = rows
+        if (
+            not self.server_side
+            and res
+            and res.status == self._adapt_connection.dbapi.ExecStatus.TUPLES_OK
+        ):
+            self._rows = collections.deque(await self._cursor.fetchall())
         return result
 
-    def executemany(self, query, params_seq):
-        return self.await_(self._cursor.executemany(query, params_seq))
-
-    def __iter__(self):
-        # TODO: try to avoid pop(0) on a list
-        while self._rows:
-            yield self._rows.pop(0)
-
-    def fetchone(self):
-        if self._rows:
-            # TODO: try to avoid pop(0) on a list
-            return self._rows.pop(0)
-        else:
-            return None
-
-    def fetchmany(self, size=None):
-        if size is None:
-            size = self._cursor.arraysize
-
-        retval = self._rows[0:size]
-        self._rows = self._rows[size:]
-        return retval
-
-    def fetchall(self):
-        retval = self._rows
-        self._rows = []
-        return retval
-
+    async def _executemany_async(
+        self,
+        operation,
+        seq_of_parameters,
+    ):
+        # override to not use mutex, psycopg3 already has mutex
+        return await self._cursor.executemany(operation, seq_of_parameters)
 
-class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):
-    def execute(self, query, params=None, **kw):
-        self.await_(self._cursor.execute(query, params, **kw))
-        return self
 
-    def close(self):
-        self.await_(self._cursor.close())
+class AsyncAdapt_psycopg_ss_cursor(
+    AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_psycopg_cursor
+):
+    __slots__ = ("name",)
 
-    def fetchone(self):
-        return self.await_(self._cursor.fetchone())
+    name: str
 
-    def fetchmany(self, size=0):
-        return self.await_(self._cursor.fetchmany(size))
+    def __init__(self, adapt_connection, name):
+        self.name = name
+        super().__init__(adapt_connection)
 
-    def fetchall(self):
-        return self.await_(self._cursor.fetchall())
+    def _make_new_cursor(self, connection):
+        return connection.cursor(self.name)
 
+    # TODO: should this be on the base asyncio adapter?
     def __iter__(self):
         iterator = self._cursor.__aiter__()
         while True:
@@ -618,35 +590,38 @@ class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):
                 break
 
 
-class AsyncAdapt_psycopg_connection(AdaptedConnection):
+class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
     __slots__ = ()
-    await_ = staticmethod(await_only)
 
-    def __init__(self, connection) -> None:
-        self._connection = connection
+    _cursor_cls = AsyncAdapt_psycopg_cursor
+    _ss_cursor_cls = AsyncAdapt_psycopg_ss_cursor
 
-    def __getattr__(self, name):
-        return getattr(self._connection, name)
+    def add_notice_handler(self, handler):
+        self._connection.add_notice_handler(handler)
 
-    def execute(self, query, params=None, **kw):
-        cursor = self.await_(self._connection.execute(query, params, **kw))
-        return AsyncAdapt_psycopg_cursor(cursor, self.await_)
+    @property
+    def info(self):
+        return self._connection.info
 
-    def cursor(self, *args, **kw):
-        cursor = self._connection.cursor(*args, **kw)
-        if hasattr(cursor, "name"):
-            return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_)
-        else:
-            return AsyncAdapt_psycopg_cursor(cursor, self.await_)
+    @property
+    def adapters(self):
+        return self._connection.adapters
+
+    @property
+    def closed(self):
+        return self._connection.closed
 
-    def commit(self):
-        self.await_(self._connection.commit())
+    @property
+    def broken(self):
+        return self._connection.broken
 
-    def rollback(self):
-        self.await_(self._connection.rollback())
+    @property
+    def read_only(self):
+        return self._connection.read_only
 
-    def close(self):
-        self.await_(self._connection.close())
+    @property
+    def deferrable(self):
+        return self._connection.deferrable
 
     @property
     def autocommit(self):
@@ -668,15 +643,23 @@ class AsyncAdapt_psycopg_connection(AdaptedConnection):
     def set_deferrable(self, value):
         self.await_(self._connection.set_deferrable(value))
 
+    def cursor(self, name=None, /):
+        if name:
+            return AsyncAdapt_psycopg_ss_cursor(self, name)
+        else:
+            return AsyncAdapt_psycopg_cursor(self)
+
 
-class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection):
+class AsyncAdaptFallback_psycopg_connection(
+    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_psycopg_connection
+):
     __slots__ = ()
-    await_ = staticmethod(await_fallback)
 
 
 class PsycopgAdaptDBAPI:
-    def __init__(self, psycopg) -> None:
+    def __init__(self, psycopg, ExecStatus) -> None:
         self.psycopg = psycopg
+        self.ExecStatus = ExecStatus
 
         for k, v in self.psycopg.__dict__.items():
             if k != "connect":
@@ -689,11 +672,11 @@ class PsycopgAdaptDBAPI:
         )
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_psycopg_connection(
-                await_fallback(creator_fn(*arg, **kw))
+                self, await_fallback(creator_fn(*arg, **kw))
             )
         else:
             return AsyncAdapt_psycopg_connection(
-                await_only(creator_fn(*arg, **kw))
+                self, await_only(creator_fn(*arg, **kw))
             )
 
 
@@ -706,9 +689,7 @@ class PGDialectAsync_psycopg(PGDialect_psycopg):
         import psycopg
         from psycopg.pq import ExecStatus
 
-        AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus
-
-        return PsycopgAdaptDBAPI(psycopg)
+        return PsycopgAdaptDBAPI(psycopg, ExecStatus)
 
     @classmethod
     def get_pool_class(cls, url):
index d9438d1880e3a58d367e5f8bd5db6640ea21a42c..41e406164e3f1257a10866f386aae1020db0bc62 100644 (file)
@@ -84,140 +84,27 @@ from .base import SQLiteExecutionContext
 from .pysqlite import SQLiteDialect_pysqlite
 from ... import pool
 from ... import util
-from ...engine import AdaptedConnection
+from ...connectors.asyncio import AsyncAdapt_dbapi_connection
+from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
 
-class AsyncAdapt_aiosqlite_cursor:
-    # TODO: base on connectors/asyncio.py
-    # see #10415
-
-    __slots__ = (
-        "_adapt_connection",
-        "_connection",
-        "description",
-        "await_",
-        "_rows",
-        "arraysize",
-        "rowcount",
-        "lastrowid",
-    )
-
-    server_side = False
-
-    def __init__(self, adapt_connection):
-        self._adapt_connection = adapt_connection
-        self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
-        self.arraysize = 1
-        self.rowcount = -1
-        self.description = None
-        self._rows = []
-
-    def close(self):
-        self._rows[:] = []
-
-    def execute(self, operation, parameters=None):
-        try:
-            _cursor = self.await_(self._connection.cursor())
-
-            if parameters is None:
-                self.await_(_cursor.execute(operation))
-            else:
-                self.await_(_cursor.execute(operation, parameters))
-
-            if _cursor.description:
-                self.description = _cursor.description
-                self.lastrowid = self.rowcount = -1
-
-                if not self.server_side:
-                    self._rows = self.await_(_cursor.fetchall())
-            else:
-                self.description = None
-                self.lastrowid = _cursor.lastrowid
-                self.rowcount = _cursor.rowcount
-
-            if not self.server_side:
-                self.await_(_cursor.close())
-            else:
-                self._cursor = _cursor
-        except Exception as error:
-            self._adapt_connection._handle_exception(error)
-
-    def executemany(self, operation, seq_of_parameters):
-        try:
-            _cursor = self.await_(self._connection.cursor())
-            self.await_(_cursor.executemany(operation, seq_of_parameters))
-            self.description = None
-            self.lastrowid = _cursor.lastrowid
-            self.rowcount = _cursor.rowcount
-            self.await_(_cursor.close())
-        except Exception as error:
-            self._adapt_connection._handle_exception(error)
-
-    def setinputsizes(self, *inputsizes):
-        pass
-
-    def __iter__(self):
-        while self._rows:
-            yield self._rows.pop(0)
-
-    def fetchone(self):
-        if self._rows:
-            return self._rows.pop(0)
-        else:
-            return None
-
-    def fetchmany(self, size=None):
-        if size is None:
-            size = self.arraysize
-
-        retval = self._rows[0:size]
-        self._rows[:] = self._rows[size:]
-        return retval
-
-    def fetchall(self):
-        retval = self._rows[:]
-        self._rows[:] = []
-        return retval
-
-
-class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
-    # TODO: base on connectors/asyncio.py
-    # see #10415
-    __slots__ = "_cursor"
-
-    server_side = True
-
-    def __init__(self, *arg, **kw):
-        super().__init__(*arg, **kw)
-        self._cursor = None
-
-    def close(self):
-        if self._cursor is not None:
-            self.await_(self._cursor.close())
-            self._cursor = None
-
-    def fetchone(self):
-        return self.await_(self._cursor.fetchone())
+class AsyncAdapt_aiosqlite_cursor(AsyncAdapt_dbapi_cursor):
+    __slots__ = ()
 
-    def fetchmany(self, size=None):
-        if size is None:
-            size = self.arraysize
-        return self.await_(self._cursor.fetchmany(size=size))
 
-    def fetchall(self):
-        return self.await_(self._cursor.fetchall())
+class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_dbapi_ss_cursor):
+    __slots__ = ()
 
 
-class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
-    await_ = staticmethod(await_only)
-    __slots__ = ("dbapi",)
+class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
+    __slots__ = ()
 
-    def __init__(self, dbapi, connection):
-        self.dbapi = dbapi
-        self._connection = connection
+    _cursor_cls = AsyncAdapt_aiosqlite_cursor
+    _ss_cursor_cls = AsyncAdapt_aiosqlite_ss_cursor
 
     @property
     def isolation_level(self):
@@ -249,26 +136,13 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
         except Exception as error:
             self._handle_exception(error)
 
-    def cursor(self, server_side=False):
-        if server_side:
-            return AsyncAdapt_aiosqlite_ss_cursor(self)
-        else:
-            return AsyncAdapt_aiosqlite_cursor(self)
-
-    def execute(self, *args, **kw):
-        return self.await_(self._connection.execute(*args, **kw))
-
     def rollback(self):
-        try:
-            self.await_(self._connection.rollback())
-        except Exception as error:
-            self._handle_exception(error)
+        if self._connection._connection:
+            super().rollback()
 
     def commit(self):
-        try:
-            self.await_(self._connection.commit())
-        except Exception as error:
-            self._handle_exception(error)
+        if self._connection._connection:
+            super().commit()
 
     def close(self):
         try:
@@ -287,22 +161,20 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
             self._handle_exception(error)
 
     def _handle_exception(self, error):
-        if (
-            isinstance(error, ValueError)
-            and error.args[0] == "no active connection"
+        if isinstance(error, ValueError) and error.args[0].lower() in (
+            "no active connection",
+            "connection closed",
         ):
-            raise self.dbapi.sqlite.OperationalError(
-                "no active connection"
-            ) from error
+            raise self.dbapi.sqlite.OperationalError(error.args[0]) from error
         else:
-            raise error
+            super()._handle_exception(error)
 
 
-class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
+class AsyncAdaptFallback_aiosqlite_connection(
+    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aiosqlite_connection
+):
     __slots__ = ()
 
-    await_ = staticmethod(await_fallback)
-
 
 class AsyncAdapt_aiosqlite_dbapi:
     def __init__(self, aiosqlite, sqlite):
@@ -382,10 +254,13 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
             return pool.StaticPool
 
     def is_disconnect(self, e, connection, cursor):
-        if isinstance(
-            e, self.dbapi.OperationalError
-        ) and "no active connection" in str(e):
-            return True
+        if isinstance(e, self.dbapi.OperationalError):
+            err_lower = str(e).lower()
+            if (
+                "no active connection" in err_lower
+                or "connection closed" in err_lower
+            ):
+                return True
 
         return super().is_disconnect(e, connection, cursor)