]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
type aiosqlite
authorPablo Estevez <pablo22estevez@gmail.com>
Sun, 8 Jun 2025 21:13:36 +0000 (21:13 +0000)
committerPablo Estevez <pablo22estevez@gmail.com>
Sun, 8 Jun 2025 21:13:36 +0000 (21:13 +0000)
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py

index ad718a4ae8b8424c3d159e19911e5167b67a24a1..5cb8c2534f74368cb25979c1daf1c5931915de6f 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
@@ -78,9 +77,17 @@ based on the kind of SQLite database that's requested:
     :paramref:`_sa.create_engine.poolclass` parameter.
 
 """  # noqa
+from __future__ import annotations
 
 import asyncio
 from functools import partial
+from types import ModuleType
+from typing import Any
+from typing import cast
+from typing import NoReturn
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .base import SQLiteExecutionContext
 from .pysqlite import SQLiteDialect_pysqlite
@@ -88,8 +95,16 @@ from ... import pool
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...engine.interfaces import DBAPIModule
 from ...util.concurrency import await_
 
+if TYPE_CHECKING:
+    from ...connectors.asyncio import AsyncIODBAPIConnection
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.url import URL
+    from ...pool.base import PoolProxiedConnection
+
 
 class AsyncAdapt_aiosqlite_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = ()
@@ -106,17 +121,19 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
     _ss_cursor_cls = AsyncAdapt_aiosqlite_ss_cursor
 
     @property
-    def isolation_level(self):
-        return self._connection.isolation_level
+    def isolation_level(self) -> Optional[str]:
+        return cast(str, self._connection.isolation_level)
 
     @isolation_level.setter
-    def isolation_level(self, value):
+    def isolation_level(self, value: Optional[str]) -> None:
         # aiosqlite's isolation_level setter works outside the Thread
         # that it's supposed to, necessitating setting check_same_thread=False.
         # for improved stability, we instead invent our own awaitable version
         # using aiosqlite's async queue directly.
 
-        def set_iso(connection, value):
+        def set_iso(
+            connection: AsyncAdapt_aiosqlite_connection, value: Optional[str]
+        ) -> None:
             connection.isolation_level = value
 
         function = partial(set_iso, self._connection._conn, value)
@@ -125,25 +142,25 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
         self._connection._tx.put_nowait((future, function))
 
         try:
-            return await_(future)
+            return cast(None, await_(future))
         except Exception as error:
             self._handle_exception(error)
 
-    def create_function(self, *args, **kw):
+    def create_function(self, *args: Any, **kw: Any) -> None:
         try:
-            await_(self._connection.create_function(*args, **kw))
+            cast(None, await_(self._connection.create_function(*args, **kw)))
         except Exception as error:
             self._handle_exception(error)
 
-    def rollback(self):
+    def rollback(self) -> None:
         if self._connection._connection:
             super().rollback()
 
-    def commit(self):
+    def commit(self) -> None:
         if self._connection._connection:
             super().commit()
 
-    def close(self):
+    def close(self) -> None:
         try:
             await_(self._connection.close())
         except ValueError:
@@ -159,7 +176,7 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
         except Exception as error:
             self._handle_exception(error)
 
-    def _handle_exception(self, error):
+    def _handle_exception(self, error: Exception) -> NoReturn:
         if isinstance(error, ValueError) and error.args[0].lower() in (
             "no active connection",
             "connection closed",
@@ -170,13 +187,13 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
 
 
 class AsyncAdapt_aiosqlite_dbapi:
-    def __init__(self, aiosqlite, sqlite):
+    def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType):
         self.aiosqlite = aiosqlite
         self.sqlite = sqlite
         self.paramstyle = "qmark"
         self._init_dbapi_attributes()
 
-    def _init_dbapi_attributes(self):
+    def _init_dbapi_attributes(self) -> None:
         for name in (
             "DatabaseError",
             "Error",
@@ -195,7 +212,7 @@ class AsyncAdapt_aiosqlite_dbapi:
         for name in ("Binary",):
             setattr(self, name, getattr(self.sqlite, name))
 
-    def connect(self, *arg, **kw):
+    def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection:
         creator_fn = kw.pop("async_creator_fn", None)
         if creator_fn:
             connection = creator_fn(*arg, **kw)
@@ -209,9 +226,11 @@ class AsyncAdapt_aiosqlite_dbapi:
             await_(connection),
         )
 
+    def __getattr__(self, key: str) -> Any: ...
+
 
 class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor(server_side=True)
 
 
@@ -226,19 +245,25 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
     execution_ctx_cls = SQLiteExecutionContext_aiosqlite
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi:
         return AsyncAdapt_aiosqlite_dbapi(
             __import__("aiosqlite"), __import__("sqlite3")
         )
 
     @classmethod
-    def get_pool_class(cls, url):
+    def get_pool_class(cls, url: URL) -> type[pool.Pool]:
         if cls._is_url_file_db(url):
             return pool.AsyncAdaptedQueuePool
         else:
             return pool.StaticPool
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
+        self.dbapi = cast(DBAPIModule, self.dbapi)
         if isinstance(e, self.dbapi.OperationalError):
             err_lower = str(e).lower()
             if (
@@ -249,8 +274,10 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
 
         return super().is_disconnect(e, connection, cursor)
 
-    def get_driver_connection(self, connection):
-        return connection._connection
+    def get_driver_connection(
+        self, connection: DBAPIConnection
+    ) -> AsyncIODBAPIConnection:
+        return connection._connection  # type: ignore[no-any-return]
 
 
 dialect = SQLiteDialect_aiosqlite
index d4b1518a3efc87c1d3a4465bf9459a58b079e29e..fa9b100c88dabba745c82be09da8d174ce7e487e 100644 (file)
@@ -391,10 +391,15 @@ connection when it is created. That is accomplished with an event listener::
             print(conn.scalar(text("SELECT UDF()")))
 
 """  # noqa
+from __future__ import annotations
 
 import math
 import os
 import re
+from typing import cast
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .base import DATE
 from .base import DATETIME
@@ -404,6 +409,13 @@ from ... import pool
 from ... import types as sqltypes
 from ... import util
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.url import URL
+    from ...pool.base import PoolProxiedConnection
+
 
 class _SQLite_pysqliteTimeStamp(DATETIME):
     def bind_processor(self, dialect):
@@ -457,7 +469,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
         return sqlite
 
     @classmethod
-    def _is_url_file_db(cls, url):
+    def _is_url_file_db(cls, url: URL):
         if (url.database and url.database != ":memory:") and (
             url.query.get("mode", None) != "memory"
         ):
@@ -587,7 +599,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
 
         return ([filename], pysqlite_opts)
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
+        self.dbapi = cast(DBAPIModule, self.dbapi)
         return isinstance(
             e, self.dbapi.ProgrammingError
         ) and "Cannot operate on a closed database." in str(e)