From e833474db064df32a7c02e809893e71ac4df4996 Mon Sep 17 00:00:00 2001 From: Pablo Estevez Date: Sat, 23 Aug 2025 08:33:47 -0400 Subject: [PATCH] type pysqlite type pysqlite from dialects. type some related code on pysqlite.py related to #6810 ### Description ### Checklist This pull request is: - [ ] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #12789 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12789 Pull-request-sha: 2c1ea7283d534dd625c8f0e4270247d2cc5ed40c Change-Id: I4d691c4bb334957029cd47289463555034ebd866 --- lib/sqlalchemy/dialects/sqlite/base.py | 40 ++++++-- lib/sqlalchemy/dialects/sqlite/pysqlite.py | 104 +++++++++++++-------- 2 files changed, 97 insertions(+), 47 deletions(-) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 0f9cef6004..edd973fb51 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -989,7 +989,10 @@ from __future__ import annotations import datetime import numbers import re +from typing import Any +from typing import Callable from typing import Optional +from typing import TYPE_CHECKING from .json import JSON from .json import JSONIndexType @@ -1022,6 +1025,13 @@ from ...types import TEXT # noqa from ...types import TIMESTAMP # noqa from ...types import VARCHAR # noqa +if TYPE_CHECKING: + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import Dialect + from ...engine.interfaces import IsolationLevel + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _ResultProcessorType + class _SQliteJson(JSON): def result_processor(self, dialect, coltype): @@ -1160,7 +1170,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): "%(hour)02d:%(minute)02d:%(second)02d" ) - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Any]]: datetime_datetime = datetime.datetime datetime_date = datetime.date format_ = self._storage_format @@ -1196,7 +1208,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): return process - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if self._reg: return processors.str_to_datetime_processor_factory( self._reg, datetime.datetime @@ -1251,7 +1265,9 @@ class DATE(_DateTimeMixin, sqltypes.Date): _storage_format = "%(year)04d-%(month)02d-%(day)02d" - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Any]]: datetime_date = datetime.date format_ = self._storage_format @@ -1272,7 +1288,9 @@ class DATE(_DateTimeMixin, sqltypes.Date): return process - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if self._reg: return processors.str_to_datetime_processor_factory( self._reg, datetime.date @@ -2123,11 +2141,11 @@ class SQLiteDialect(default.DefaultDialect): def __init__( self, - native_datetime=False, - json_serializer=None, - json_deserializer=None, - **kwargs, - ): + native_datetime: bool = False, + json_serializer: Optional[Callable[..., Any]] = None, + json_deserializer: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> None: default.DefaultDialect.__init__(self, **kwargs) self._json_serializer = json_serializer @@ -2191,7 +2209,9 @@ class SQLiteDialect(default.DefaultDialect): def get_isolation_level_values(self, dbapi_connection): return list(self._isolation_lookup) - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: isolation_level = self._isolation_lookup[level] cursor = dbapi_connection.cursor() diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index ea2c6a8765..116c89e8b6 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" @@ -396,9 +395,13 @@ from __future__ import annotations import math import os import re +from typing import Any +from typing import Callable from typing import cast from typing import Optional +from typing import Pattern from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from .base import DATE @@ -408,23 +411,33 @@ from ... import exc from ... import pool from ... import types as sqltypes from ... import util +from ...util.typing import Self if TYPE_CHECKING: + from ...engine.interfaces import ConnectArgsType from ...engine.interfaces import DBAPIConnection from ...engine.interfaces import DBAPICursor from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import VersionInfoType from ...engine.url import URL from ...pool.base import PoolProxiedConnection + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _ResultProcessorType class _SQLite_pysqliteTimeStamp(DATETIME): - def bind_processor(self, dialect): + def bind_processor( # type: ignore[override] + self, dialect: SQLiteDialect + ) -> Optional[_BindProcessorType[Any]]: if dialect.native_datetime: return None else: return DATETIME.bind_processor(self, dialect) - def result_processor(self, dialect, coltype): + def result_processor( # type: ignore[override] + self, dialect: SQLiteDialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if dialect.native_datetime: return None else: @@ -432,13 +445,17 @@ class _SQLite_pysqliteTimeStamp(DATETIME): class _SQLite_pysqliteDate(DATE): - def bind_processor(self, dialect): + def bind_processor( # type: ignore[override] + self, dialect: SQLiteDialect + ) -> Optional[_BindProcessorType[Any]]: if dialect.native_datetime: return None else: return DATE.bind_processor(self, dialect) - def result_processor(self, dialect, coltype): + def result_processor( # type: ignore[override] + self, dialect: SQLiteDialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if dialect.native_datetime: return None else: @@ -463,13 +480,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): driver = "pysqlite" @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> DBAPIModule: from sqlite3 import dbapi2 as sqlite - return sqlite + return cast("DBAPIModule", sqlite) @classmethod - def _is_url_file_db(cls, url: URL): + def _is_url_file_db(cls, url: URL) -> bool: if (url.database and url.database != ":memory:") and ( url.query.get("mode", None) != "memory" ): @@ -478,14 +495,14 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return False @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> type[pool.Pool]: if cls._is_url_file_db(url): return pool.QueuePool else: return pool.SingletonThreadPool - def _get_server_version_info(self, connection): - return self.dbapi.sqlite_version_info + def _get_server_version_info(self, connection: Any) -> VersionInfoType: + return self.dbapi.sqlite_version_info # type: ignore _isolation_lookup = SQLiteDialect._isolation_lookup.union( { @@ -493,18 +510,20 @@ class SQLiteDialect_pysqlite(SQLiteDialect): } ) - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: if level == "AUTOCOMMIT": dbapi_connection.isolation_level = None else: dbapi_connection.isolation_level = "" return super().set_isolation_level(dbapi_connection, level) - def detect_autocommit_setting(self, dbapi_connection): - return dbapi_connection.isolation_level is None + def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool: + return dbapi_conn.isolation_level is None - def on_connect(self): - def regexp(a, b): + def on_connect(self) -> Callable[[DBAPIConnection], None]: + def regexp(a: str, b: Optional[str]) -> Optional[bool]: if b is None: return None return re.search(a, b) is not None @@ -518,12 +537,12 @@ class SQLiteDialect_pysqlite(SQLiteDialect): else: create_func_kw = {} - def set_regexp(dbapi_connection): + def set_regexp(dbapi_connection: DBAPIConnection) -> None: dbapi_connection.create_function( "regexp", 2, regexp, **create_func_kw ) - def floor_func(dbapi_connection): + def floor_func(dbapi_connection: DBAPIConnection) -> None: # NOTE: floor is optionally present in sqlite 3.35+ , however # as it is normally non-present we deliver floor() unconditionally # for now. @@ -534,13 +553,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): fns = [set_regexp, floor_func] - def connect(conn): + def connect(conn: DBAPIConnection) -> None: for fn in fns: fn(conn) return connect - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: if url.username or url.password or url.host or url.port: raise exc.ArgumentError( "Invalid SQLite URL: %s\n" @@ -565,7 +584,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): ("cached_statements", int), ] opts = url.query - pysqlite_opts = {} + pysqlite_opts: dict[str, Any] = {} for key, type_ in pysqlite_args: util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts) @@ -582,7 +601,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): # to adjust for that here. for key, type_ in pysqlite_args: uri_opts.pop(key, None) - filename = url.database + filename: str = url.database # type: ignore[assignment] if uri_opts: # sorting of keys is for unit test support filename += "?" + ( @@ -630,34 +649,36 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): driver = "pysqlite_numeric" _first_bind = ":1" - _not_in_statement_regexp = None + _not_in_statement_regexp: Optional[Pattern[str]] = None - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any) -> None: kw.setdefault("paramstyle", "numeric") super().__init__(*arg, **kw) - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: arg, opts = super().create_connect_args(url) opts["factory"] = self._fix_sqlite_issue_99953() return arg, opts - def _fix_sqlite_issue_99953(self): + def _fix_sqlite_issue_99953(self) -> Any: import sqlite3 first_bind = self._first_bind if self._not_in_statement_regexp: nis = self._not_in_statement_regexp - def _test_sql(sql): + def _test_sql(sql: str) -> None: m = nis.search(sql) assert not m, f"Found {nis.pattern!r} in {sql!r}" else: - def _test_sql(sql): + def _test_sql(sql: str) -> None: pass - def _numeric_param_as_dict(parameters): + def _numeric_param_as_dict( + parameters: Any, + ) -> Union[dict[str, Any], tuple[Any, ...]]: if parameters: assert isinstance(parameters, tuple) return { @@ -667,13 +688,13 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): return () class SQLiteFix99953Cursor(sqlite3.Cursor): - def execute(self, sql, parameters=()): + def execute(self, sql: str, parameters: Any = ()) -> Self: _test_sql(sql) if first_bind in sql: parameters = _numeric_param_as_dict(parameters) return super().execute(sql, parameters) - def executemany(self, sql, parameters): + def executemany(self, sql: str, parameters: Any) -> Self: _test_sql(sql) if first_bind in sql: parameters = [ @@ -682,18 +703,27 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): return super().executemany(sql, parameters) class SQLiteFix99953Connection(sqlite3.Connection): - def cursor(self, factory=None): + _CursorT = TypeVar("_CursorT", bound=sqlite3.Cursor) + + def cursor( + self, + factory: Optional[ + Callable[[sqlite3.Connection], _CursorT] + ] = None, + ) -> _CursorT: if factory is None: - factory = SQLiteFix99953Cursor - return super().cursor(factory=factory) + factory = SQLiteFix99953Cursor # type: ignore[assignment] + return super().cursor(factory=factory) # type: ignore[return-value] # noqa[E501] - def execute(self, sql, parameters=()): + def execute( + self, sql: str, parameters: Any = () + ) -> sqlite3.Cursor: _test_sql(sql) if first_bind in sql: parameters = _numeric_param_as_dict(parameters) return super().execute(sql, parameters) - def executemany(self, sql, parameters): + def executemany(self, sql: str, parameters: Any) -> sqlite3.Cursor: _test_sql(sql) if first_bind in sql: parameters = [ @@ -719,6 +749,6 @@ class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric): _first_bind = "$1" _not_in_statement_regexp = re.compile(r"[^\d]:\d+") - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any) -> None: kw.setdefault("paramstyle", "numeric_dollar") super().__init__(*arg, **kw) -- 2.47.3