]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
type pysqlite main
authorPablo Estevez <pablo22estevez@gmail.com>
Sat, 23 Aug 2025 12:33:47 +0000 (08:33 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 1 Oct 2025 20:07:12 +0000 (22:07 +0200)
<!-- Provide a general summary of your proposed changes in the Title field above -->
type pysqlite from dialects.
type some related code on pysqlite.py

related to #6810

### Description
<!-- Describe your changes in detail -->

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #12789
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12789
Pull-request-sha: 2c1ea7283d534dd625c8f0e4270247d2cc5ed40c

Change-Id: I4d691c4bb334957029cd47289463555034ebd866

lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py

index 0f9cef6004a1b8246761f0a1c867464689118e3f..edd973fb512149ad23f663acd5ca41f1abfb167a 100644 (file)
@@ -989,7 +989,10 @@ from __future__ import annotations
 import datetime
 import numbers
 import re
+from typing import Any
+from typing import Callable
 from typing import Optional
+from typing import TYPE_CHECKING
 
 from .json import JSON
 from .json import JSONIndexType
@@ -1022,6 +1025,13 @@ from ...types import TEXT  # noqa
 from ...types import TIMESTAMP  # noqa
 from ...types import VARCHAR  # noqa
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import Dialect
+    from ...engine.interfaces import IsolationLevel
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _ResultProcessorType
+
 
 class _SQliteJson(JSON):
     def result_processor(self, dialect, coltype):
@@ -1160,7 +1170,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
                 "%(hour)02d:%(minute)02d:%(second)02d"
             )
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[Any]]:
         datetime_datetime = datetime.datetime
         datetime_date = datetime.date
         format_ = self._storage_format
@@ -1196,7 +1208,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
 
         return process
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if self._reg:
             return processors.str_to_datetime_processor_factory(
                 self._reg, datetime.datetime
@@ -1251,7 +1265,9 @@ class DATE(_DateTimeMixin, sqltypes.Date):
 
     _storage_format = "%(year)04d-%(month)02d-%(day)02d"
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[Any]]:
         datetime_date = datetime.date
         format_ = self._storage_format
 
@@ -1272,7 +1288,9 @@ class DATE(_DateTimeMixin, sqltypes.Date):
 
         return process
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if self._reg:
             return processors.str_to_datetime_processor_factory(
                 self._reg, datetime.date
@@ -2123,11 +2141,11 @@ class SQLiteDialect(default.DefaultDialect):
 
     def __init__(
         self,
-        native_datetime=False,
-        json_serializer=None,
-        json_deserializer=None,
-        **kwargs,
-    ):
+        native_datetime: bool = False,
+        json_serializer: Optional[Callable[..., Any]] = None,
+        json_deserializer: Optional[Callable[..., Any]] = None,
+        **kwargs: Any,
+    ) -> None:
         default.DefaultDialect.__init__(self, **kwargs)
 
         self._json_serializer = json_serializer
@@ -2191,7 +2209,9 @@ class SQLiteDialect(default.DefaultDialect):
     def get_isolation_level_values(self, dbapi_connection):
         return list(self._isolation_lookup)
 
-    def set_isolation_level(self, dbapi_connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         isolation_level = self._isolation_lookup[level]
 
         cursor = dbapi_connection.cursor()
index ea2c6a876578fd75b5c67bd66a71f530b72da58b..116c89e8b6617ff62ee19bf00ae5e6c72dc1ee9e 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
@@ -396,9 +395,13 @@ from __future__ import annotations
 import math
 import os
 import re
+from typing import Any
+from typing import Callable
 from typing import cast
 from typing import Optional
+from typing import Pattern
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from .base import DATE
@@ -408,23 +411,33 @@ from ... import exc
 from ... import pool
 from ... import types as sqltypes
 from ... import util
+from ...util.typing import Self
 
 if TYPE_CHECKING:
+    from ...engine.interfaces import ConnectArgsType
     from ...engine.interfaces import DBAPIConnection
     from ...engine.interfaces import DBAPICursor
     from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import IsolationLevel
+    from ...engine.interfaces import VersionInfoType
     from ...engine.url import URL
     from ...pool.base import PoolProxiedConnection
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _ResultProcessorType
 
 
 class _SQLite_pysqliteTimeStamp(DATETIME):
-    def bind_processor(self, dialect):
+    def bind_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect
+    ) -> Optional[_BindProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
             return DATETIME.bind_processor(self, dialect)
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
@@ -432,13 +445,17 @@ class _SQLite_pysqliteTimeStamp(DATETIME):
 
 
 class _SQLite_pysqliteDate(DATE):
-    def bind_processor(self, dialect):
+    def bind_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect
+    ) -> Optional[_BindProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
             return DATE.bind_processor(self, dialect)
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
@@ -463,13 +480,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
     driver = "pysqlite"
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> DBAPIModule:
         from sqlite3 import dbapi2 as sqlite
 
-        return sqlite
+        return cast("DBAPIModule", sqlite)
 
     @classmethod
-    def _is_url_file_db(cls, url: URL):
+    def _is_url_file_db(cls, url: URL) -> bool:
         if (url.database and url.database != ":memory:") and (
             url.query.get("mode", None) != "memory"
         ):
@@ -478,14 +495,14 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
             return False
 
     @classmethod
-    def get_pool_class(cls, url):
+    def get_pool_class(cls, url: URL) -> type[pool.Pool]:
         if cls._is_url_file_db(url):
             return pool.QueuePool
         else:
             return pool.SingletonThreadPool
 
-    def _get_server_version_info(self, connection):
-        return self.dbapi.sqlite_version_info
+    def _get_server_version_info(self, connection: Any) -> VersionInfoType:
+        return self.dbapi.sqlite_version_info  # type: ignore
 
     _isolation_lookup = SQLiteDialect._isolation_lookup.union(
         {
@@ -493,18 +510,20 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
         }
     )
 
-    def set_isolation_level(self, dbapi_connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         if level == "AUTOCOMMIT":
             dbapi_connection.isolation_level = None
         else:
             dbapi_connection.isolation_level = ""
             return super().set_isolation_level(dbapi_connection, level)
 
-    def detect_autocommit_setting(self, dbapi_connection):
-        return dbapi_connection.isolation_level is None
+    def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
+        return dbapi_conn.isolation_level is None
 
-    def on_connect(self):
-        def regexp(a, b):
+    def on_connect(self) -> Callable[[DBAPIConnection], None]:
+        def regexp(a: str, b: Optional[str]) -> Optional[bool]:
             if b is None:
                 return None
             return re.search(a, b) is not None
@@ -518,12 +537,12 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
         else:
             create_func_kw = {}
 
-        def set_regexp(dbapi_connection):
+        def set_regexp(dbapi_connection: DBAPIConnection) -> None:
             dbapi_connection.create_function(
                 "regexp", 2, regexp, **create_func_kw
             )
 
-        def floor_func(dbapi_connection):
+        def floor_func(dbapi_connection: DBAPIConnection) -> None:
             # NOTE: floor is optionally present in sqlite 3.35+ , however
             # as it is normally non-present we deliver floor() unconditionally
             # for now.
@@ -534,13 +553,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
 
         fns = [set_regexp, floor_func]
 
-        def connect(conn):
+        def connect(conn: DBAPIConnection) -> None:
             for fn in fns:
                 fn(conn)
 
         return connect
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         if url.username or url.password or url.host or url.port:
             raise exc.ArgumentError(
                 "Invalid SQLite URL: %s\n"
@@ -565,7 +584,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
             ("cached_statements", int),
         ]
         opts = url.query
-        pysqlite_opts = {}
+        pysqlite_opts: dict[str, Any] = {}
         for key, type_ in pysqlite_args:
             util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
 
@@ -582,7 +601,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
             # to adjust for that here.
             for key, type_ in pysqlite_args:
                 uri_opts.pop(key, None)
-            filename = url.database
+            filename: str = url.database  # type: ignore[assignment]
             if uri_opts:
                 # sorting of keys is for unit test support
                 filename += "?" + (
@@ -630,34 +649,36 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
     driver = "pysqlite_numeric"
 
     _first_bind = ":1"
-    _not_in_statement_regexp = None
+    _not_in_statement_regexp: Optional[Pattern[str]] = None
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any) -> None:
         kw.setdefault("paramstyle", "numeric")
         super().__init__(*arg, **kw)
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         arg, opts = super().create_connect_args(url)
         opts["factory"] = self._fix_sqlite_issue_99953()
         return arg, opts
 
-    def _fix_sqlite_issue_99953(self):
+    def _fix_sqlite_issue_99953(self) -> Any:
         import sqlite3
 
         first_bind = self._first_bind
         if self._not_in_statement_regexp:
             nis = self._not_in_statement_regexp
 
-            def _test_sql(sql):
+            def _test_sql(sql: str) -> None:
                 m = nis.search(sql)
                 assert not m, f"Found {nis.pattern!r} in {sql!r}"
 
         else:
 
-            def _test_sql(sql):
+            def _test_sql(sql: str) -> None:
                 pass
 
-        def _numeric_param_as_dict(parameters):
+        def _numeric_param_as_dict(
+            parameters: Any,
+        ) -> Union[dict[str, Any], tuple[Any, ...]]:
             if parameters:
                 assert isinstance(parameters, tuple)
                 return {
@@ -667,13 +688,13 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
                 return ()
 
         class SQLiteFix99953Cursor(sqlite3.Cursor):
-            def execute(self, sql, parameters=()):
+            def execute(self, sql: str, parameters: Any = ()) -> Self:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = _numeric_param_as_dict(parameters)
                 return super().execute(sql, parameters)
 
-            def executemany(self, sql, parameters):
+            def executemany(self, sql: str, parameters: Any) -> Self:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = [
@@ -682,18 +703,27 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
                 return super().executemany(sql, parameters)
 
         class SQLiteFix99953Connection(sqlite3.Connection):
-            def cursor(self, factory=None):
+            _CursorT = TypeVar("_CursorT", bound=sqlite3.Cursor)
+
+            def cursor(
+                self,
+                factory: Optional[
+                    Callable[[sqlite3.Connection], _CursorT]
+                ] = None,
+            ) -> _CursorT:
                 if factory is None:
-                    factory = SQLiteFix99953Cursor
-                return super().cursor(factory=factory)
+                    factory = SQLiteFix99953Cursor  # type: ignore[assignment]
+                return super().cursor(factory=factory)  # type: ignore[return-value]  # noqa[E501]
 
-            def execute(self, sql, parameters=()):
+            def execute(
+                self, sql: str, parameters: Any = ()
+            ) -> sqlite3.Cursor:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = _numeric_param_as_dict(parameters)
                 return super().execute(sql, parameters)
 
-            def executemany(self, sql, parameters):
+            def executemany(self, sql: str, parameters: Any) -> sqlite3.Cursor:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = [
@@ -719,6 +749,6 @@ class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric):
     _first_bind = "$1"
     _not_in_statement_regexp = re.compile(r"[^\d]:\d+")
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any) -> None:
         kw.setdefault("paramstyle", "numeric_dollar")
         super().__init__(*arg, **kw)