From 944df50e92bd077b81775b080bfd9347f3baabdc Mon Sep 17 00:00:00 2001 From: Pablo Estevez Date: Thu, 26 Jun 2025 15:21:55 -0400 Subject: [PATCH] type aiosqlite Closes: #12656 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12656 Pull-request-sha: 396e9cfaaccbc56537b967e62decbbb3eb0e036a Change-Id: I598ee6022616f265824291544750e571eaba413c --- lib/sqlalchemy/dialects/sqlite/aiosqlite.py | 68 ++++++++++++++------- lib/sqlalchemy/dialects/sqlite/pysqlite.py | 22 ++++++- 2 files changed, 67 insertions(+), 23 deletions(-) diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index ad718a4ae8..cf8726c1f3 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" @@ -78,18 +77,35 @@ based on the kind of SQLite database that's requested: :paramref:`_sa.create_engine.poolclass` parameter. """ # noqa +from __future__ import annotations import asyncio from functools import partial +from types import ModuleType +from typing import Any +from typing import cast +from typing import NoReturn +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from .base import SQLiteExecutionContext from .pysqlite import SQLiteDialect_pysqlite from ... import pool from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...util.concurrency import await_ +if TYPE_CHECKING: + from ...connectors.asyncio import AsyncIODBAPIConnection + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.url import URL + from ...pool.base import PoolProxiedConnection + class AsyncAdapt_aiosqlite_cursor(AsyncAdapt_dbapi_cursor): __slots__ = () @@ -106,17 +122,19 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): _ss_cursor_cls = AsyncAdapt_aiosqlite_ss_cursor @property - def isolation_level(self): - return self._connection.isolation_level + def isolation_level(self) -> Optional[str]: + return cast(str, self._connection.isolation_level) @isolation_level.setter - def isolation_level(self, value): + def isolation_level(self, value: Optional[str]) -> None: # aiosqlite's isolation_level setter works outside the Thread # that it's supposed to, necessitating setting check_same_thread=False. # for improved stability, we instead invent our own awaitable version # using aiosqlite's async queue directly. - def set_iso(connection, value): + def set_iso( + connection: AsyncAdapt_aiosqlite_connection, value: Optional[str] + ) -> None: connection.isolation_level = value function = partial(set_iso, self._connection._conn, value) @@ -125,25 +143,25 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): self._connection._tx.put_nowait((future, function)) try: - return await_(future) + await_(future) except Exception as error: self._handle_exception(error) - def create_function(self, *args, **kw): + def create_function(self, *args: Any, **kw: Any) -> None: try: await_(self._connection.create_function(*args, **kw)) except Exception as error: self._handle_exception(error) - def rollback(self): + def rollback(self) -> None: if self._connection._connection: super().rollback() - def commit(self): + def commit(self) -> None: if self._connection._connection: super().commit() - def close(self): + def close(self) -> None: try: await_(self._connection.close()) except ValueError: @@ -159,7 +177,7 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): except Exception as error: self._handle_exception(error) - def _handle_exception(self, error): + def _handle_exception(self, error: Exception) -> NoReturn: if isinstance(error, ValueError) and error.args[0].lower() in ( "no active connection", "connection closed", @@ -169,14 +187,14 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): super()._handle_exception(error) -class AsyncAdapt_aiosqlite_dbapi: - def __init__(self, aiosqlite, sqlite): +class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module): + def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType): self.aiosqlite = aiosqlite self.sqlite = sqlite self.paramstyle = "qmark" self._init_dbapi_attributes() - def _init_dbapi_attributes(self): + def _init_dbapi_attributes(self) -> None: for name in ( "DatabaseError", "Error", @@ -195,7 +213,7 @@ class AsyncAdapt_aiosqlite_dbapi: for name in ("Binary",): setattr(self, name, getattr(self.sqlite, name)) - def connect(self, *arg, **kw): + def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection: creator_fn = kw.pop("async_creator_fn", None) if creator_fn: connection = creator_fn(*arg, **kw) @@ -211,7 +229,7 @@ class AsyncAdapt_aiosqlite_dbapi: class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): - def create_server_side_cursor(self): + def create_server_side_cursor(self) -> DBAPICursor: return self._dbapi_connection.cursor(server_side=True) @@ -226,19 +244,25 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): execution_ctx_cls = SQLiteExecutionContext_aiosqlite @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi: return AsyncAdapt_aiosqlite_dbapi( __import__("aiosqlite"), __import__("sqlite3") ) @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> type[pool.Pool]: if cls._is_url_file_db(url): return pool.AsyncAdaptedQueuePool else: return pool.StaticPool - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: + self.dbapi = cast("DBAPIModule", self.dbapi) if isinstance(e, self.dbapi.OperationalError): err_lower = str(e).lower() if ( @@ -249,8 +273,10 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): return super().is_disconnect(e, connection, cursor) - def get_driver_connection(self, connection): - return connection._connection + def get_driver_connection( + self, connection: DBAPIConnection + ) -> AsyncIODBAPIConnection: + return connection._connection # type: ignore[no-any-return] dialect = SQLiteDialect_aiosqlite diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index d4b1518a3e..c6fd69225c 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -391,10 +391,15 @@ connection when it is created. That is accomplished with an event listener:: print(conn.scalar(text("SELECT UDF()"))) """ # noqa +from __future__ import annotations import math import os import re +from typing import cast +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from .base import DATE from .base import DATETIME @@ -404,6 +409,13 @@ from ... import pool from ... import types as sqltypes from ... import util +if TYPE_CHECKING: + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.url import URL + from ...pool.base import PoolProxiedConnection + class _SQLite_pysqliteTimeStamp(DATETIME): def bind_processor(self, dialect): @@ -457,7 +469,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return sqlite @classmethod - def _is_url_file_db(cls, url): + def _is_url_file_db(cls, url: URL): if (url.database and url.database != ":memory:") and ( url.query.get("mode", None) != "memory" ): @@ -587,7 +599,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return ([filename], pysqlite_opts) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: + self.dbapi = cast("DBAPIModule", self.dbapi) return isinstance( e, self.dbapi.ProgrammingError ) and "Cannot operate on a closed database." in str(e) -- 2.47.2