]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
type aiosqlite
authorPablo Estevez <pablo22estevez@gmail.com>
Thu, 26 Jun 2025 19:21:55 +0000 (15:21 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 26 Jun 2025 19:29:29 +0000 (21:29 +0200)
Closes: #12656
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12656
Pull-request-sha: 396e9cfaaccbc56537b967e62decbbb3eb0e036a

Change-Id: I598ee6022616f265824291544750e571eaba413c

lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py

index ad718a4ae8b8424c3d159e19911e5167b67a24a1..cf8726c1f34b681e269992235151f1921b17e180 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
@@ -78,18 +77,35 @@ based on the kind of SQLite database that's requested:
     :paramref:`_sa.create_engine.poolclass` parameter.
 
 """  # noqa
+from __future__ import annotations
 
 import asyncio
 from functools import partial
+from types import ModuleType
+from typing import Any
+from typing import cast
+from typing import NoReturn
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .base import SQLiteExecutionContext
 from .pysqlite import SQLiteDialect_pysqlite
 from ... import pool
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
 from ...util.concurrency import await_
 
+if TYPE_CHECKING:
+    from ...connectors.asyncio import AsyncIODBAPIConnection
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.url import URL
+    from ...pool.base import PoolProxiedConnection
+
 
 class AsyncAdapt_aiosqlite_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = ()
@@ -106,17 +122,19 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
     _ss_cursor_cls = AsyncAdapt_aiosqlite_ss_cursor
 
     @property
-    def isolation_level(self):
-        return self._connection.isolation_level
+    def isolation_level(self) -> Optional[str]:
+        return cast(str, self._connection.isolation_level)
 
     @isolation_level.setter
-    def isolation_level(self, value):
+    def isolation_level(self, value: Optional[str]) -> None:
         # aiosqlite's isolation_level setter works outside the Thread
         # that it's supposed to, necessitating setting check_same_thread=False.
         # for improved stability, we instead invent our own awaitable version
         # using aiosqlite's async queue directly.
 
-        def set_iso(connection, value):
+        def set_iso(
+            connection: AsyncAdapt_aiosqlite_connection, value: Optional[str]
+        ) -> None:
             connection.isolation_level = value
 
         function = partial(set_iso, self._connection._conn, value)
@@ -125,25 +143,25 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
         self._connection._tx.put_nowait((future, function))
 
         try:
-            return await_(future)
+            await_(future)
         except Exception as error:
             self._handle_exception(error)
 
-    def create_function(self, *args, **kw):
+    def create_function(self, *args: Any, **kw: Any) -> None:
         try:
             await_(self._connection.create_function(*args, **kw))
         except Exception as error:
             self._handle_exception(error)
 
-    def rollback(self):
+    def rollback(self) -> None:
         if self._connection._connection:
             super().rollback()
 
-    def commit(self):
+    def commit(self) -> None:
         if self._connection._connection:
             super().commit()
 
-    def close(self):
+    def close(self) -> None:
         try:
             await_(self._connection.close())
         except ValueError:
@@ -159,7 +177,7 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
         except Exception as error:
             self._handle_exception(error)
 
-    def _handle_exception(self, error):
+    def _handle_exception(self, error: Exception) -> NoReturn:
         if isinstance(error, ValueError) and error.args[0].lower() in (
             "no active connection",
             "connection closed",
@@ -169,14 +187,14 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
             super()._handle_exception(error)
 
 
-class AsyncAdapt_aiosqlite_dbapi:
-    def __init__(self, aiosqlite, sqlite):
+class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
+    def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType):
         self.aiosqlite = aiosqlite
         self.sqlite = sqlite
         self.paramstyle = "qmark"
         self._init_dbapi_attributes()
 
-    def _init_dbapi_attributes(self):
+    def _init_dbapi_attributes(self) -> None:
         for name in (
             "DatabaseError",
             "Error",
@@ -195,7 +213,7 @@ class AsyncAdapt_aiosqlite_dbapi:
         for name in ("Binary",):
             setattr(self, name, getattr(self.sqlite, name))
 
-    def connect(self, *arg, **kw):
+    def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection:
         creator_fn = kw.pop("async_creator_fn", None)
         if creator_fn:
             connection = creator_fn(*arg, **kw)
@@ -211,7 +229,7 @@ class AsyncAdapt_aiosqlite_dbapi:
 
 
 class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor(server_side=True)
 
 
@@ -226,19 +244,25 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
     execution_ctx_cls = SQLiteExecutionContext_aiosqlite
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi:
         return AsyncAdapt_aiosqlite_dbapi(
             __import__("aiosqlite"), __import__("sqlite3")
         )
 
     @classmethod
-    def get_pool_class(cls, url):
+    def get_pool_class(cls, url: URL) -> type[pool.Pool]:
         if cls._is_url_file_db(url):
             return pool.AsyncAdaptedQueuePool
         else:
             return pool.StaticPool
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
+        self.dbapi = cast("DBAPIModule", self.dbapi)
         if isinstance(e, self.dbapi.OperationalError):
             err_lower = str(e).lower()
             if (
@@ -249,8 +273,10 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
 
         return super().is_disconnect(e, connection, cursor)
 
-    def get_driver_connection(self, connection):
-        return connection._connection
+    def get_driver_connection(
+        self, connection: DBAPIConnection
+    ) -> AsyncIODBAPIConnection:
+        return connection._connection  # type: ignore[no-any-return]
 
 
 dialect = SQLiteDialect_aiosqlite
index d4b1518a3efc87c1d3a4465bf9459a58b079e29e..c6fd69225c6d135f635e03bfbd7a561d1b6513f6 100644 (file)
@@ -391,10 +391,15 @@ connection when it is created. That is accomplished with an event listener::
             print(conn.scalar(text("SELECT UDF()")))
 
 """  # noqa
+from __future__ import annotations
 
 import math
 import os
 import re
+from typing import cast
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .base import DATE
 from .base import DATETIME
@@ -404,6 +409,13 @@ from ... import pool
 from ... import types as sqltypes
 from ... import util
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.url import URL
+    from ...pool.base import PoolProxiedConnection
+
 
 class _SQLite_pysqliteTimeStamp(DATETIME):
     def bind_processor(self, dialect):
@@ -457,7 +469,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
         return sqlite
 
     @classmethod
-    def _is_url_file_db(cls, url):
+    def _is_url_file_db(cls, url: URL):
         if (url.database and url.database != ":memory:") and (
             url.query.get("mode", None) != "memory"
         ):
@@ -587,7 +599,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
 
         return ([filename], pysqlite_opts)
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
+        self.dbapi = cast("DBAPIModule", self.dbapi)
         return isinstance(
             e, self.dbapi.ProgrammingError
         ) and "Cannot operate on a closed database." in str(e)