From: Pablo Estevez Date: Wed, 6 Aug 2025 00:46:56 +0000 (+0000) Subject: type pysqlite X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0535ceeb3264c23967e7164c1db6a671afabe0db;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git type pysqlite --- diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index b78423d329..e8d14333d6 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -981,7 +981,10 @@ from __future__ import annotations import datetime import numbers import re +from typing import Any +from typing import Callable from typing import Optional +from typing import TYPE_CHECKING from .json import JSON from .json import JSONIndexType @@ -1014,6 +1017,13 @@ from ...types import TEXT # noqa from ...types import TIMESTAMP # noqa from ...types import VARCHAR # noqa +if TYPE_CHECKING: + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import Dialect + from ...engine.interfaces import IsolationLevel + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _ResultProcessorType + class _SQliteJson(JSON): def result_processor(self, dialect, coltype): @@ -1152,7 +1162,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): "%(hour)02d:%(minute)02d:%(second)02d" ) - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Any]]: datetime_datetime = datetime.datetime datetime_date = datetime.date format_ = self._storage_format @@ -1188,7 +1200,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): return process - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if self._reg: return processors.str_to_datetime_processor_factory( self._reg, datetime.datetime @@ -1243,7 +1257,9 @@ class DATE(_DateTimeMixin, sqltypes.Date): _storage_format = "%(year)04d-%(month)02d-%(day)02d" - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Any]]: datetime_date = datetime.date format_ = self._storage_format @@ -1264,7 +1280,9 @@ class DATE(_DateTimeMixin, sqltypes.Date): return process - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if self._reg: return processors.str_to_datetime_processor_factory( self._reg, datetime.date @@ -2095,11 +2113,11 @@ class SQLiteDialect(default.DefaultDialect): def __init__( self, - native_datetime=False, - json_serializer=None, - json_deserializer=None, - **kwargs, - ): + native_datetime: bool = False, + json_serializer: Optional[Callable[..., Any]] = None, + json_deserializer: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> None: default.DefaultDialect.__init__(self, **kwargs) self._json_serializer = json_serializer @@ -2163,7 +2181,9 @@ class SQLiteDialect(default.DefaultDialect): def get_isolation_level_values(self, dbapi_connection): return list(self._isolation_lookup) - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: isolation_level = self._isolation_lookup[level] cursor = dbapi_connection.cursor() diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index c6fd69225c..a344d3ec4e 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" @@ -396,9 +395,14 @@ from __future__ import annotations import math import os import re +from typing import Any +from typing import Callable from typing import cast from typing import Optional +from typing import Pattern +from typing import Self from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from .base import DATE @@ -410,21 +414,30 @@ from ... import types as sqltypes from ... import util if TYPE_CHECKING: + from ...engine.interfaces import ConnectArgsType from ...engine.interfaces import DBAPIConnection from ...engine.interfaces import DBAPICursor from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import VersionInfoType from ...engine.url import URL from ...pool.base import PoolProxiedConnection + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _ResultProcessorType class _SQLite_pysqliteTimeStamp(DATETIME): - def bind_processor(self, dialect): + def bind_processor( # type: ignore[override] + self, dialect: SQLiteDialect + ) -> Optional[_BindProcessorType[Any]]: if dialect.native_datetime: return None else: return DATETIME.bind_processor(self, dialect) - def result_processor(self, dialect, coltype): + def result_processor( # type: ignore[override] + self, dialect: SQLiteDialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if dialect.native_datetime: return None else: @@ -432,13 +445,17 @@ class _SQLite_pysqliteTimeStamp(DATETIME): class _SQLite_pysqliteDate(DATE): - def bind_processor(self, dialect): + def bind_processor( # type: ignore[override] + self, dialect: SQLiteDialect + ) -> Optional[_BindProcessorType[Any]]: if dialect.native_datetime: return None else: return DATE.bind_processor(self, dialect) - def result_processor(self, dialect, coltype): + def result_processor( # type: ignore[override] + self, dialect: SQLiteDialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if dialect.native_datetime: return None else: @@ -463,13 +480,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): driver = "pysqlite" @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> DBAPIModule: from sqlite3 import dbapi2 as sqlite - return sqlite + return cast("DBAPIModule", sqlite) @classmethod - def _is_url_file_db(cls, url: URL): + def _is_url_file_db(cls, url: URL) -> bool: if (url.database and url.database != ":memory:") and ( url.query.get("mode", None) != "memory" ): @@ -478,14 +495,14 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return False @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> type[pool.Pool]: if cls._is_url_file_db(url): return pool.QueuePool else: return pool.SingletonThreadPool - def _get_server_version_info(self, connection): - return self.dbapi.sqlite_version_info + def _get_server_version_info(self, connection: Any) -> VersionInfoType: + return self.dbapi.sqlite_version_info # type: ignore _isolation_lookup = SQLiteDialect._isolation_lookup.union( { @@ -493,15 +510,17 @@ class SQLiteDialect_pysqlite(SQLiteDialect): } ) - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: if level == "AUTOCOMMIT": dbapi_connection.isolation_level = None else: dbapi_connection.isolation_level = "" return super().set_isolation_level(dbapi_connection, level) - def on_connect(self): - def regexp(a, b): + def on_connect(self) -> Callable[[DBAPIConnection], None]: + def regexp(a: str, b: Optional[str]) -> Optional[bool]: if b is None: return None return re.search(a, b) is not None @@ -515,12 +534,12 @@ class SQLiteDialect_pysqlite(SQLiteDialect): else: create_func_kw = {} - def set_regexp(dbapi_connection): + def set_regexp(dbapi_connection: DBAPIConnection) -> None: dbapi_connection.create_function( "regexp", 2, regexp, **create_func_kw ) - def floor_func(dbapi_connection): + def floor_func(dbapi_connection: DBAPIConnection) -> None: # NOTE: floor is optionally present in sqlite 3.35+ , however # as it is normally non-present we deliver floor() unconditionally # for now. @@ -531,13 +550,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): fns = [set_regexp, floor_func] - def connect(conn): + def connect(conn: DBAPIConnection) -> None: for fn in fns: fn(conn) return connect - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: if url.username or url.password or url.host or url.port: raise exc.ArgumentError( "Invalid SQLite URL: %s\n" @@ -562,7 +581,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): ("cached_statements", int), ] opts = url.query - pysqlite_opts = {} + pysqlite_opts: dict[str, Any] = {} for key, type_ in pysqlite_args: util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts) @@ -579,7 +598,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): # to adjust for that here. for key, type_ in pysqlite_args: uri_opts.pop(key, None) - filename = url.database + filename: str = url.database # type: ignore[assignment] if uri_opts: # sorting of keys is for unit test support filename += "?" + ( @@ -627,34 +646,36 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): driver = "pysqlite_numeric" _first_bind = ":1" - _not_in_statement_regexp = None + _not_in_statement_regexp: Optional[Pattern[str]] = None - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any) -> None: kw.setdefault("paramstyle", "numeric") super().__init__(*arg, **kw) - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: arg, opts = super().create_connect_args(url) opts["factory"] = self._fix_sqlite_issue_99953() return arg, opts - def _fix_sqlite_issue_99953(self): + def _fix_sqlite_issue_99953(self) -> Any: import sqlite3 first_bind = self._first_bind if self._not_in_statement_regexp: nis = self._not_in_statement_regexp - def _test_sql(sql): + def _test_sql(sql: str) -> None: m = nis.search(sql) assert not m, f"Found {nis.pattern!r} in {sql!r}" else: - def _test_sql(sql): + def _test_sql(sql: str) -> None: pass - def _numeric_param_as_dict(parameters): + def _numeric_param_as_dict( + parameters: Any, + ) -> Union[dict[str, Any], tuple[Any, ...]]: if parameters: assert isinstance(parameters, tuple) return { @@ -664,13 +685,13 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): return () class SQLiteFix99953Cursor(sqlite3.Cursor): - def execute(self, sql, parameters=()): + def execute(self, sql: str, parameters: Any = ()) -> Self: _test_sql(sql) if first_bind in sql: parameters = _numeric_param_as_dict(parameters) return super().execute(sql, parameters) - def executemany(self, sql, parameters): + def executemany(self, sql: str, parameters: Any) -> Self: _test_sql(sql) if first_bind in sql: parameters = [ @@ -679,18 +700,27 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): return super().executemany(sql, parameters) class SQLiteFix99953Connection(sqlite3.Connection): - def cursor(self, factory=None): + _CursorT = TypeVar("_CursorT", bound=sqlite3.Cursor) + + def cursor( + self, + factory: Optional[ + Callable[[sqlite3.Connection], _CursorT] + ] = None, + ) -> _CursorT: if factory is None: - factory = SQLiteFix99953Cursor - return super().cursor(factory=factory) + factory = SQLiteFix99953Cursor # type: ignore[assignment] + return super().cursor(factory=factory) # type: ignore[return-value] # noqa[E501] - def execute(self, sql, parameters=()): + def execute( + self, sql: str, parameters: Any = () + ) -> sqlite3.Cursor: _test_sql(sql) if first_bind in sql: parameters = _numeric_param_as_dict(parameters) return super().execute(sql, parameters) - def executemany(self, sql, parameters): + def executemany(self, sql: str, parameters: Any) -> sqlite3.Cursor: _test_sql(sql) if first_bind in sql: parameters = [ @@ -716,6 +746,6 @@ class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric): _first_bind = "$1" _not_in_statement_regexp = re.compile(r"[^\d]:\d+") - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any) -> None: kw.setdefault("paramstyle", "numeric_dollar") super().__init__(*arg, **kw)