From: Pablo Estevez Date: Thu, 26 Jun 2025 19:21:55 +0000 (-0400) Subject: type aiosqlite X-Git-Tag: rel_2_0_42~5^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=688f136fa2c342176c42cb6da3f23241665bbba0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git type aiosqlite Closes: #12656 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12656 Pull-request-sha: 396e9cfaaccbc56537b967e62decbbb3eb0e036a Change-Id: I598ee6022616f265824291544750e571eaba413c (cherry picked from commit 944df50e92bd077b81775b080bfd9347f3baabdc) --- diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index b8cb8c3819..3f39d4dbc7 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" @@ -78,19 +77,43 @@ based on the kind of SQLite database that's requested: :paramref:`_sa.create_engine.poolclass` parameter. """ # noqa +from __future__ import annotations import asyncio from collections import deque from functools import partial +from types import ModuleType +from typing import Any +from typing import cast +from typing import Deque +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import TYPE_CHECKING +from typing import Union from .base import SQLiteExecutionContext from .pysqlite import SQLiteDialect_pysqlite from ... import pool from ... import util +from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...engine import AdaptedConnection from ...util.concurrency import await_fallback from ...util.concurrency import await_only +if TYPE_CHECKING: + from ...connectors.asyncio import AsyncIODBAPIConnection + from ...connectors.asyncio import AsyncIODBAPICursor + from ...engine.interfaces import _DBAPICursorDescription + from ...engine.interfaces import _DBAPIMultiExecuteParams + from ...engine.interfaces import _DBAPISingleExecuteParams + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.url import URL + from ...pool.base import PoolProxiedConnection + class AsyncAdapt_aiosqlite_cursor: # TODO: base on connectors/asyncio.py @@ -109,21 +132,26 @@ class AsyncAdapt_aiosqlite_cursor: server_side = False - def __init__(self, adapt_connection): + def __init__(self, adapt_connection: AsyncAdapt_aiosqlite_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection self.await_ = adapt_connection.await_ self.arraysize = 1 self.rowcount = -1 - self.description = None - self._rows = deque() + self.description: Optional[_DBAPICursorDescription] = None + self._rows: Deque[Any] = deque() - def close(self): + def close(self) -> None: self._rows.clear() - def execute(self, operation, parameters=None): + def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: + try: - _cursor = self.await_(self._connection.cursor()) + _cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor()) # type: ignore[arg-type] # noqa: E501 if parameters is None: self.await_(_cursor.execute(operation)) @@ -144,13 +172,17 @@ class AsyncAdapt_aiosqlite_cursor: if not self.server_side: self.await_(_cursor.close()) else: - self._cursor = _cursor + self._cursor = _cursor # type: ignore[misc] except Exception as error: self._adapt_connection._handle_exception(error) - def executemany(self, operation, seq_of_parameters): + def executemany( + self, + operation: Any, + seq_of_parameters: _DBAPIMultiExecuteParams, + ) -> Any: try: - _cursor = self.await_(self._connection.cursor()) + _cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor()) # type: ignore[arg-type] # noqa: E501 self.await_(_cursor.executemany(operation, seq_of_parameters)) self.description = None self.lastrowid = _cursor.lastrowid @@ -159,27 +191,27 @@ class AsyncAdapt_aiosqlite_cursor: except Exception as error: self._adapt_connection._handle_exception(error) - def setinputsizes(self, *inputsizes): + def setinputsizes(self, *inputsizes: Any) -> None: pass - def __iter__(self): + def __iter__(self) -> Iterator[Any]: while self._rows: yield self._rows.popleft() - def fetchone(self): + def fetchone(self) -> Optional[Any]: if self._rows: return self._rows.popleft() else: return None - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]: if size is None: size = self.arraysize rr = self._rows return [rr.popleft() for _ in range(min(size, len(rr)))] - def fetchall(self): + def fetchall(self) -> Sequence[Any]: retval = list(self._rows) self._rows.clear() return retval @@ -192,24 +224,27 @@ class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor): server_side = True - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any) -> None: super().__init__(*arg, **kw) - self._cursor = None + self._cursor: Optional[AsyncIODBAPICursor] = None - def close(self): + def close(self) -> None: if self._cursor is not None: self.await_(self._cursor.close()) self._cursor = None - def fetchone(self): + def fetchone(self) -> Optional[Any]: + assert self._cursor is not None return self.await_(self._cursor.fetchone()) - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]: + assert self._cursor is not None if size is None: size = self.arraysize return self.await_(self._cursor.fetchmany(size=size)) - def fetchall(self): + def fetchall(self) -> Sequence[Any]: + assert self._cursor is not None return self.await_(self._cursor.fetchall()) @@ -217,22 +252,24 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): await_ = staticmethod(await_only) __slots__ = ("dbapi",) - def __init__(self, dbapi, connection): + def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection) -> None: self.dbapi = dbapi self._connection = connection @property - def isolation_level(self): - return self._connection.isolation_level + def isolation_level(self) -> Optional[str]: + return cast(str, self._connection.isolation_level) @isolation_level.setter - def isolation_level(self, value): + def isolation_level(self, value: Optional[str]) -> None: # aiosqlite's isolation_level setter works outside the Thread # that it's supposed to, necessitating setting check_same_thread=False. # for improved stability, we instead invent our own awaitable version # using aiosqlite's async queue directly. - def set_iso(connection, value): + def set_iso( + connection: AsyncAdapt_aiosqlite_connection, value: Optional[str] + ) -> None: connection.isolation_level = value function = partial(set_iso, self._connection._conn, value) @@ -241,38 +278,38 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): self._connection._tx.put_nowait((future, function)) try: - return self.await_(future) + self.await_(future) except Exception as error: self._handle_exception(error) - def create_function(self, *args, **kw): + def create_function(self, *args: Any, **kw: Any) -> None: try: self.await_(self._connection.create_function(*args, **kw)) except Exception as error: self._handle_exception(error) - def cursor(self, server_side=False): + def cursor(self, server_side: bool = False) -> AsyncAdapt_aiosqlite_cursor: if server_side: return AsyncAdapt_aiosqlite_ss_cursor(self) else: return AsyncAdapt_aiosqlite_cursor(self) - def execute(self, *args, **kw): + def execute(self, *args: Any, **kw: Any) -> Any: return self.await_(self._connection.execute(*args, **kw)) - def rollback(self): + def rollback(self) -> None: try: self.await_(self._connection.rollback()) except Exception as error: self._handle_exception(error) - def commit(self): + def commit(self) -> None: try: self.await_(self._connection.commit()) except Exception as error: self._handle_exception(error) - def close(self): + def close(self) -> None: try: self.await_(self._connection.close()) except ValueError: @@ -288,7 +325,7 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): except Exception as error: self._handle_exception(error) - def _handle_exception(self, error): + def _handle_exception(self, error: Exception) -> NoReturn: if ( isinstance(error, ValueError) and error.args[0] == "no active connection" @@ -306,14 +343,14 @@ class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection): await_ = staticmethod(await_fallback) -class AsyncAdapt_aiosqlite_dbapi: - def __init__(self, aiosqlite, sqlite): +class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module): + def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType): self.aiosqlite = aiosqlite self.sqlite = sqlite self.paramstyle = "qmark" self._init_dbapi_attributes() - def _init_dbapi_attributes(self): + def _init_dbapi_attributes(self) -> None: for name in ( "DatabaseError", "Error", @@ -332,7 +369,7 @@ class AsyncAdapt_aiosqlite_dbapi: for name in ("Binary",): setattr(self, name, getattr(self.sqlite, name)) - def connect(self, *arg, **kw): + def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection: async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", None) @@ -356,7 +393,7 @@ class AsyncAdapt_aiosqlite_dbapi: class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): - def create_server_side_cursor(self): + def create_server_side_cursor(self) -> DBAPICursor: return self._dbapi_connection.cursor(server_side=True) @@ -371,19 +408,25 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): execution_ctx_cls = SQLiteExecutionContext_aiosqlite @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi: return AsyncAdapt_aiosqlite_dbapi( __import__("aiosqlite"), __import__("sqlite3") ) @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> type[pool.Pool]: if cls._is_url_file_db(url): return pool.AsyncAdaptedQueuePool else: return pool.StaticPool - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: + self.dbapi = cast("DBAPIModule", self.dbapi) if isinstance( e, self.dbapi.OperationalError ) and "no active connection" in str(e): @@ -391,8 +434,10 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): return super().is_disconnect(e, connection, cursor) - def get_driver_connection(self, connection): - return connection._connection + def get_driver_connection( + self, connection: DBAPIConnection + ) -> AsyncIODBAPIConnection: + return connection._connection # type: ignore[no-any-return] dialect = SQLiteDialect_aiosqlite diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 4a777e3b81..2f23886b54 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -393,10 +393,15 @@ connection when it is created. That is accomplished with an event listener:: print(conn.scalar(text("SELECT UDF()"))) """ # noqa +from __future__ import annotations import math import os import re +from typing import cast +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from .base import DATE from .base import DATETIME @@ -406,6 +411,13 @@ from ... import pool from ... import types as sqltypes from ... import util +if TYPE_CHECKING: + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.url import URL + from ...pool.base import PoolProxiedConnection + class _SQLite_pysqliteTimeStamp(DATETIME): def bind_processor(self, dialect): @@ -459,7 +471,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return sqlite @classmethod - def _is_url_file_db(cls, url): + def _is_url_file_db(cls, url: URL): if (url.database and url.database != ":memory:") and ( url.query.get("mode", None) != "memory" ): @@ -589,7 +601,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return ([filename], pysqlite_opts) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: + self.dbapi = cast("DBAPIModule", self.dbapi) return isinstance( e, self.dbapi.ProgrammingError ) and "Cannot operate on a closed database." in str(e)