]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
type aiosqlite
authorPablo Estevez <pablo22estevez@gmail.com>
Thu, 26 Jun 2025 19:21:55 +0000 (15:21 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Fri, 27 Jun 2025 20:13:56 +0000 (22:13 +0200)
Closes: #12656
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12656
Pull-request-sha: 396e9cfaaccbc56537b967e62decbbb3eb0e036a

Change-Id: I598ee6022616f265824291544750e571eaba413c
(cherry picked from commit 944df50e92bd077b81775b080bfd9347f3baabdc)

lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py

index b8cb8c3819b32e154b3dd6e8a19f1787d4b4db77..3f39d4dbc7db56bbd27408f527f601695773ecd9 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
@@ -78,19 +77,43 @@ based on the kind of SQLite database that's requested:
     :paramref:`_sa.create_engine.poolclass` parameter.
 
 """  # noqa
+from __future__ import annotations
 
 import asyncio
 from collections import deque
 from functools import partial
+from types import ModuleType
+from typing import Any
+from typing import cast
+from typing import Deque
+from typing import Iterator
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .base import SQLiteExecutionContext
 from .pysqlite import SQLiteDialect_pysqlite
 from ... import pool
 from ... import util
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...engine import AdaptedConnection
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
+if TYPE_CHECKING:
+    from ...connectors.asyncio import AsyncIODBAPIConnection
+    from ...connectors.asyncio import AsyncIODBAPICursor
+    from ...engine.interfaces import _DBAPICursorDescription
+    from ...engine.interfaces import _DBAPIMultiExecuteParams
+    from ...engine.interfaces import _DBAPISingleExecuteParams
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.url import URL
+    from ...pool.base import PoolProxiedConnection
+
 
 class AsyncAdapt_aiosqlite_cursor:
     # TODO: base on connectors/asyncio.py
@@ -109,21 +132,26 @@ class AsyncAdapt_aiosqlite_cursor:
 
     server_side = False
 
-    def __init__(self, adapt_connection):
+    def __init__(self, adapt_connection: AsyncAdapt_aiosqlite_connection):
         self._adapt_connection = adapt_connection
         self._connection = adapt_connection._connection
         self.await_ = adapt_connection.await_
         self.arraysize = 1
         self.rowcount = -1
-        self.description = None
-        self._rows = deque()
+        self.description: Optional[_DBAPICursorDescription] = None
+        self._rows: Deque[Any] = deque()
 
-    def close(self):
+    def close(self) -> None:
         self._rows.clear()
 
-    def execute(self, operation, parameters=None):
+    def execute(
+        self,
+        operation: Any,
+        parameters: Optional[_DBAPISingleExecuteParams] = None,
+    ) -> Any:
+
         try:
-            _cursor = self.await_(self._connection.cursor())
+            _cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor())  # type: ignore[arg-type] # noqa: E501
 
             if parameters is None:
                 self.await_(_cursor.execute(operation))
@@ -144,13 +172,17 @@ class AsyncAdapt_aiosqlite_cursor:
             if not self.server_side:
                 self.await_(_cursor.close())
             else:
-                self._cursor = _cursor
+                self._cursor = _cursor  # type: ignore[misc]
         except Exception as error:
             self._adapt_connection._handle_exception(error)
 
-    def executemany(self, operation, seq_of_parameters):
+    def executemany(
+        self,
+        operation: Any,
+        seq_of_parameters: _DBAPIMultiExecuteParams,
+    ) -> Any:
         try:
-            _cursor = self.await_(self._connection.cursor())
+            _cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor())  # type: ignore[arg-type] # noqa: E501
             self.await_(_cursor.executemany(operation, seq_of_parameters))
             self.description = None
             self.lastrowid = _cursor.lastrowid
@@ -159,27 +191,27 @@ class AsyncAdapt_aiosqlite_cursor:
         except Exception as error:
             self._adapt_connection._handle_exception(error)
 
-    def setinputsizes(self, *inputsizes):
+    def setinputsizes(self, *inputsizes: Any) -> None:
         pass
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         while self._rows:
             yield self._rows.popleft()
 
-    def fetchone(self):
+    def fetchone(self) -> Optional[Any]:
         if self._rows:
             return self._rows.popleft()
         else:
             return None
 
-    def fetchmany(self, size=None):
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
         if size is None:
             size = self.arraysize
 
         rr = self._rows
         return [rr.popleft() for _ in range(min(size, len(rr)))]
 
-    def fetchall(self):
+    def fetchall(self) -> Sequence[Any]:
         retval = list(self._rows)
         self._rows.clear()
         return retval
@@ -192,24 +224,27 @@ class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
 
     server_side = True
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any) -> None:
         super().__init__(*arg, **kw)
-        self._cursor = None
+        self._cursor: Optional[AsyncIODBAPICursor] = None
 
-    def close(self):
+    def close(self) -> None:
         if self._cursor is not None:
             self.await_(self._cursor.close())
             self._cursor = None
 
-    def fetchone(self):
+    def fetchone(self) -> Optional[Any]:
+        assert self._cursor is not None
         return self.await_(self._cursor.fetchone())
 
-    def fetchmany(self, size=None):
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
+        assert self._cursor is not None
         if size is None:
             size = self.arraysize
         return self.await_(self._cursor.fetchmany(size=size))
 
-    def fetchall(self):
+    def fetchall(self) -> Sequence[Any]:
+        assert self._cursor is not None
         return self.await_(self._cursor.fetchall())
 
 
@@ -217,22 +252,24 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
     await_ = staticmethod(await_only)
     __slots__ = ("dbapi",)
 
-    def __init__(self, dbapi, connection):
+    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection) -> None:
         self.dbapi = dbapi
         self._connection = connection
 
     @property
-    def isolation_level(self):
-        return self._connection.isolation_level
+    def isolation_level(self) -> Optional[str]:
+        return cast(str, self._connection.isolation_level)
 
     @isolation_level.setter
-    def isolation_level(self, value):
+    def isolation_level(self, value: Optional[str]) -> None:
         # aiosqlite's isolation_level setter works outside the Thread
         # that it's supposed to, necessitating setting check_same_thread=False.
         # for improved stability, we instead invent our own awaitable version
         # using aiosqlite's async queue directly.
 
-        def set_iso(connection, value):
+        def set_iso(
+            connection: AsyncAdapt_aiosqlite_connection, value: Optional[str]
+        ) -> None:
             connection.isolation_level = value
 
         function = partial(set_iso, self._connection._conn, value)
@@ -241,38 +278,38 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
         self._connection._tx.put_nowait((future, function))
 
         try:
-            return self.await_(future)
+            self.await_(future)
         except Exception as error:
             self._handle_exception(error)
 
-    def create_function(self, *args, **kw):
+    def create_function(self, *args: Any, **kw: Any) -> None:
         try:
             self.await_(self._connection.create_function(*args, **kw))
         except Exception as error:
             self._handle_exception(error)
 
-    def cursor(self, server_side=False):
+    def cursor(self, server_side: bool = False) -> AsyncAdapt_aiosqlite_cursor:
         if server_side:
             return AsyncAdapt_aiosqlite_ss_cursor(self)
         else:
             return AsyncAdapt_aiosqlite_cursor(self)
 
-    def execute(self, *args, **kw):
+    def execute(self, *args: Any, **kw: Any) -> Any:
         return self.await_(self._connection.execute(*args, **kw))
 
-    def rollback(self):
+    def rollback(self) -> None:
         try:
             self.await_(self._connection.rollback())
         except Exception as error:
             self._handle_exception(error)
 
-    def commit(self):
+    def commit(self) -> None:
         try:
             self.await_(self._connection.commit())
         except Exception as error:
             self._handle_exception(error)
 
-    def close(self):
+    def close(self) -> None:
         try:
             self.await_(self._connection.close())
         except ValueError:
@@ -288,7 +325,7 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
         except Exception as error:
             self._handle_exception(error)
 
-    def _handle_exception(self, error):
+    def _handle_exception(self, error: Exception) -> NoReturn:
         if (
             isinstance(error, ValueError)
             and error.args[0] == "no active connection"
@@ -306,14 +343,14 @@ class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
     await_ = staticmethod(await_fallback)
 
 
-class AsyncAdapt_aiosqlite_dbapi:
-    def __init__(self, aiosqlite, sqlite):
+class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
+    def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType):
         self.aiosqlite = aiosqlite
         self.sqlite = sqlite
         self.paramstyle = "qmark"
         self._init_dbapi_attributes()
 
-    def _init_dbapi_attributes(self):
+    def _init_dbapi_attributes(self) -> None:
         for name in (
             "DatabaseError",
             "Error",
@@ -332,7 +369,7 @@ class AsyncAdapt_aiosqlite_dbapi:
         for name in ("Binary",):
             setattr(self, name, getattr(self.sqlite, name))
 
-    def connect(self, *arg, **kw):
+    def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection:
         async_fallback = kw.pop("async_fallback", False)
 
         creator_fn = kw.pop("async_creator_fn", None)
@@ -356,7 +393,7 @@ class AsyncAdapt_aiosqlite_dbapi:
 
 
 class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor(server_side=True)
 
 
@@ -371,19 +408,25 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
     execution_ctx_cls = SQLiteExecutionContext_aiosqlite
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi:
         return AsyncAdapt_aiosqlite_dbapi(
             __import__("aiosqlite"), __import__("sqlite3")
         )
 
     @classmethod
-    def get_pool_class(cls, url):
+    def get_pool_class(cls, url: URL) -> type[pool.Pool]:
         if cls._is_url_file_db(url):
             return pool.AsyncAdaptedQueuePool
         else:
             return pool.StaticPool
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
+        self.dbapi = cast("DBAPIModule", self.dbapi)
         if isinstance(
             e, self.dbapi.OperationalError
         ) and "no active connection" in str(e):
@@ -391,8 +434,10 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
 
         return super().is_disconnect(e, connection, cursor)
 
-    def get_driver_connection(self, connection):
-        return connection._connection
+    def get_driver_connection(
+        self, connection: DBAPIConnection
+    ) -> AsyncIODBAPIConnection:
+        return connection._connection  # type: ignore[no-any-return]
 
 
 dialect = SQLiteDialect_aiosqlite
index 4a777e3b81d981fb7e9a698593d859d9d6b67266..2f23886b54ded8d18b7dccd38f187ac6f59359ff 100644 (file)
@@ -393,10 +393,15 @@ connection when it is created. That is accomplished with an event listener::
             print(conn.scalar(text("SELECT UDF()")))
 
 """  # noqa
+from __future__ import annotations
 
 import math
 import os
 import re
+from typing import cast
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .base import DATE
 from .base import DATETIME
@@ -406,6 +411,13 @@ from ... import pool
 from ... import types as sqltypes
 from ... import util
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.url import URL
+    from ...pool.base import PoolProxiedConnection
+
 
 class _SQLite_pysqliteTimeStamp(DATETIME):
     def bind_processor(self, dialect):
@@ -459,7 +471,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
         return sqlite
 
     @classmethod
-    def _is_url_file_db(cls, url):
+    def _is_url_file_db(cls, url: URL):
         if (url.database and url.database != ":memory:") and (
             url.query.get("mode", None) != "memory"
         ):
@@ -589,7 +601,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
 
         return ([filename], pysqlite_opts)
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
+        self.dbapi = cast("DBAPIModule", self.dbapi)
         return isinstance(
             e, self.dbapi.ProgrammingError
         ) and "Cannot operate on a closed database." in str(e)