]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
type pysqlite main
authorPablo Estevez <pablo22estevez@gmail.com>
Sat, 23 Aug 2025 12:33:47 +0000 (08:33 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 1 Oct 2025 20:07:12 +0000 (22:07 +0200)
<!-- Provide a general summary of your proposed changes in the Title field above -->
type pysqlite from dialects.
type some related code on pysqlite.py

related to #6810

### Description
<!-- Describe your changes in detail -->

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #12789
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12789
Pull-request-sha: 2c1ea7283d534dd625c8f0e4270247d2cc5ed40c

Change-Id: I4d691c4bb334957029cd47289463555034ebd866

lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py

index 0f9cef6004a1b8246761f0a1c867464689118e3f..edd973fb512149ad23f663acd5ca41f1abfb167a 100644 (file)
@@ -989,7 +989,10 @@ from __future__ import annotations
 import datetime
 import numbers
 import re
 import datetime
 import numbers
 import re
+from typing import Any
+from typing import Callable
 from typing import Optional
 from typing import Optional
+from typing import TYPE_CHECKING
 
 from .json import JSON
 from .json import JSONIndexType
 
 from .json import JSON
 from .json import JSONIndexType
@@ -1022,6 +1025,13 @@ from ...types import TEXT  # noqa
 from ...types import TIMESTAMP  # noqa
 from ...types import VARCHAR  # noqa
 
 from ...types import TIMESTAMP  # noqa
 from ...types import VARCHAR  # noqa
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import Dialect
+    from ...engine.interfaces import IsolationLevel
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _ResultProcessorType
+
 
 class _SQliteJson(JSON):
     def result_processor(self, dialect, coltype):
 
 class _SQliteJson(JSON):
     def result_processor(self, dialect, coltype):
@@ -1160,7 +1170,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
                 "%(hour)02d:%(minute)02d:%(second)02d"
             )
 
                 "%(hour)02d:%(minute)02d:%(second)02d"
             )
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[Any]]:
         datetime_datetime = datetime.datetime
         datetime_date = datetime.date
         format_ = self._storage_format
         datetime_datetime = datetime.datetime
         datetime_date = datetime.date
         format_ = self._storage_format
@@ -1196,7 +1208,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
 
         return process
 
 
         return process
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if self._reg:
             return processors.str_to_datetime_processor_factory(
                 self._reg, datetime.datetime
         if self._reg:
             return processors.str_to_datetime_processor_factory(
                 self._reg, datetime.datetime
@@ -1251,7 +1265,9 @@ class DATE(_DateTimeMixin, sqltypes.Date):
 
     _storage_format = "%(year)04d-%(month)02d-%(day)02d"
 
 
     _storage_format = "%(year)04d-%(month)02d-%(day)02d"
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[Any]]:
         datetime_date = datetime.date
         format_ = self._storage_format
 
         datetime_date = datetime.date
         format_ = self._storage_format
 
@@ -1272,7 +1288,9 @@ class DATE(_DateTimeMixin, sqltypes.Date):
 
         return process
 
 
         return process
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if self._reg:
             return processors.str_to_datetime_processor_factory(
                 self._reg, datetime.date
         if self._reg:
             return processors.str_to_datetime_processor_factory(
                 self._reg, datetime.date
@@ -2123,11 +2141,11 @@ class SQLiteDialect(default.DefaultDialect):
 
     def __init__(
         self,
 
     def __init__(
         self,
-        native_datetime=False,
-        json_serializer=None,
-        json_deserializer=None,
-        **kwargs,
-    ):
+        native_datetime: bool = False,
+        json_serializer: Optional[Callable[..., Any]] = None,
+        json_deserializer: Optional[Callable[..., Any]] = None,
+        **kwargs: Any,
+    ) -> None:
         default.DefaultDialect.__init__(self, **kwargs)
 
         self._json_serializer = json_serializer
         default.DefaultDialect.__init__(self, **kwargs)
 
         self._json_serializer = json_serializer
@@ -2191,7 +2209,9 @@ class SQLiteDialect(default.DefaultDialect):
     def get_isolation_level_values(self, dbapi_connection):
         return list(self._isolation_lookup)
 
     def get_isolation_level_values(self, dbapi_connection):
         return list(self._isolation_lookup)
 
-    def set_isolation_level(self, dbapi_connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         isolation_level = self._isolation_lookup[level]
 
         cursor = dbapi_connection.cursor()
         isolation_level = self._isolation_lookup[level]
 
         cursor = dbapi_connection.cursor()
index ea2c6a876578fd75b5c67bd66a71f530b72da58b..116c89e8b6617ff62ee19bf00ae5e6c72dc1ee9e 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
 
 
 r"""
@@ -396,9 +395,13 @@ from __future__ import annotations
 import math
 import os
 import re
 import math
 import os
 import re
+from typing import Any
+from typing import Callable
 from typing import cast
 from typing import Optional
 from typing import cast
 from typing import Optional
+from typing import Pattern
 from typing import TYPE_CHECKING
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from .base import DATE
 from typing import Union
 
 from .base import DATE
@@ -408,23 +411,33 @@ from ... import exc
 from ... import pool
 from ... import types as sqltypes
 from ... import util
 from ... import pool
 from ... import types as sqltypes
 from ... import util
+from ...util.typing import Self
 
 if TYPE_CHECKING:
 
 if TYPE_CHECKING:
+    from ...engine.interfaces import ConnectArgsType
     from ...engine.interfaces import DBAPIConnection
     from ...engine.interfaces import DBAPICursor
     from ...engine.interfaces import DBAPIModule
     from ...engine.interfaces import DBAPIConnection
     from ...engine.interfaces import DBAPICursor
     from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import IsolationLevel
+    from ...engine.interfaces import VersionInfoType
     from ...engine.url import URL
     from ...pool.base import PoolProxiedConnection
     from ...engine.url import URL
     from ...pool.base import PoolProxiedConnection
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _ResultProcessorType
 
 
 class _SQLite_pysqliteTimeStamp(DATETIME):
 
 
 class _SQLite_pysqliteTimeStamp(DATETIME):
-    def bind_processor(self, dialect):
+    def bind_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect
+    ) -> Optional[_BindProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
             return DATETIME.bind_processor(self, dialect)
 
         if dialect.native_datetime:
             return None
         else:
             return DATETIME.bind_processor(self, dialect)
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
         if dialect.native_datetime:
             return None
         else:
@@ -432,13 +445,17 @@ class _SQLite_pysqliteTimeStamp(DATETIME):
 
 
 class _SQLite_pysqliteDate(DATE):
 
 
 class _SQLite_pysqliteDate(DATE):
-    def bind_processor(self, dialect):
+    def bind_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect
+    ) -> Optional[_BindProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
             return DATE.bind_processor(self, dialect)
 
         if dialect.native_datetime:
             return None
         else:
             return DATE.bind_processor(self, dialect)
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(  # type: ignore[override]
+        self, dialect: SQLiteDialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if dialect.native_datetime:
             return None
         else:
         if dialect.native_datetime:
             return None
         else:
@@ -463,13 +480,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
     driver = "pysqlite"
 
     @classmethod
     driver = "pysqlite"
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> DBAPIModule:
         from sqlite3 import dbapi2 as sqlite
 
         from sqlite3 import dbapi2 as sqlite
 
-        return sqlite
+        return cast("DBAPIModule", sqlite)
 
     @classmethod
 
     @classmethod
-    def _is_url_file_db(cls, url: URL):
+    def _is_url_file_db(cls, url: URL) -> bool:
         if (url.database and url.database != ":memory:") and (
             url.query.get("mode", None) != "memory"
         ):
         if (url.database and url.database != ":memory:") and (
             url.query.get("mode", None) != "memory"
         ):
@@ -478,14 +495,14 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
             return False
 
     @classmethod
             return False
 
     @classmethod
-    def get_pool_class(cls, url):
+    def get_pool_class(cls, url: URL) -> type[pool.Pool]:
         if cls._is_url_file_db(url):
             return pool.QueuePool
         else:
             return pool.SingletonThreadPool
 
         if cls._is_url_file_db(url):
             return pool.QueuePool
         else:
             return pool.SingletonThreadPool
 
-    def _get_server_version_info(self, connection):
-        return self.dbapi.sqlite_version_info
+    def _get_server_version_info(self, connection: Any) -> VersionInfoType:
+        return self.dbapi.sqlite_version_info  # type: ignore
 
     _isolation_lookup = SQLiteDialect._isolation_lookup.union(
         {
 
     _isolation_lookup = SQLiteDialect._isolation_lookup.union(
         {
@@ -493,18 +510,20 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
         }
     )
 
         }
     )
 
-    def set_isolation_level(self, dbapi_connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         if level == "AUTOCOMMIT":
             dbapi_connection.isolation_level = None
         else:
             dbapi_connection.isolation_level = ""
             return super().set_isolation_level(dbapi_connection, level)
 
         if level == "AUTOCOMMIT":
             dbapi_connection.isolation_level = None
         else:
             dbapi_connection.isolation_level = ""
             return super().set_isolation_level(dbapi_connection, level)
 
-    def detect_autocommit_setting(self, dbapi_connection):
-        return dbapi_connection.isolation_level is None
+    def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
+        return dbapi_conn.isolation_level is None
 
 
-    def on_connect(self):
-        def regexp(a, b):
+    def on_connect(self) -> Callable[[DBAPIConnection], None]:
+        def regexp(a: str, b: Optional[str]) -> Optional[bool]:
             if b is None:
                 return None
             return re.search(a, b) is not None
             if b is None:
                 return None
             return re.search(a, b) is not None
@@ -518,12 +537,12 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
         else:
             create_func_kw = {}
 
         else:
             create_func_kw = {}
 
-        def set_regexp(dbapi_connection):
+        def set_regexp(dbapi_connection: DBAPIConnection) -> None:
             dbapi_connection.create_function(
                 "regexp", 2, regexp, **create_func_kw
             )
 
             dbapi_connection.create_function(
                 "regexp", 2, regexp, **create_func_kw
             )
 
-        def floor_func(dbapi_connection):
+        def floor_func(dbapi_connection: DBAPIConnection) -> None:
             # NOTE: floor is optionally present in sqlite 3.35+ , however
             # as it is normally non-present we deliver floor() unconditionally
             # for now.
             # NOTE: floor is optionally present in sqlite 3.35+ , however
             # as it is normally non-present we deliver floor() unconditionally
             # for now.
@@ -534,13 +553,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
 
         fns = [set_regexp, floor_func]
 
 
         fns = [set_regexp, floor_func]
 
-        def connect(conn):
+        def connect(conn: DBAPIConnection) -> None:
             for fn in fns:
                 fn(conn)
 
         return connect
 
             for fn in fns:
                 fn(conn)
 
         return connect
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         if url.username or url.password or url.host or url.port:
             raise exc.ArgumentError(
                 "Invalid SQLite URL: %s\n"
         if url.username or url.password or url.host or url.port:
             raise exc.ArgumentError(
                 "Invalid SQLite URL: %s\n"
@@ -565,7 +584,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
             ("cached_statements", int),
         ]
         opts = url.query
             ("cached_statements", int),
         ]
         opts = url.query
-        pysqlite_opts = {}
+        pysqlite_opts: dict[str, Any] = {}
         for key, type_ in pysqlite_args:
             util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
 
         for key, type_ in pysqlite_args:
             util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
 
@@ -582,7 +601,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
             # to adjust for that here.
             for key, type_ in pysqlite_args:
                 uri_opts.pop(key, None)
             # to adjust for that here.
             for key, type_ in pysqlite_args:
                 uri_opts.pop(key, None)
-            filename = url.database
+            filename: str = url.database  # type: ignore[assignment]
             if uri_opts:
                 # sorting of keys is for unit test support
                 filename += "?" + (
             if uri_opts:
                 # sorting of keys is for unit test support
                 filename += "?" + (
@@ -630,34 +649,36 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
     driver = "pysqlite_numeric"
 
     _first_bind = ":1"
     driver = "pysqlite_numeric"
 
     _first_bind = ":1"
-    _not_in_statement_regexp = None
+    _not_in_statement_regexp: Optional[Pattern[str]] = None
 
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any) -> None:
         kw.setdefault("paramstyle", "numeric")
         super().__init__(*arg, **kw)
 
         kw.setdefault("paramstyle", "numeric")
         super().__init__(*arg, **kw)
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         arg, opts = super().create_connect_args(url)
         opts["factory"] = self._fix_sqlite_issue_99953()
         return arg, opts
 
         arg, opts = super().create_connect_args(url)
         opts["factory"] = self._fix_sqlite_issue_99953()
         return arg, opts
 
-    def _fix_sqlite_issue_99953(self):
+    def _fix_sqlite_issue_99953(self) -> Any:
         import sqlite3
 
         first_bind = self._first_bind
         if self._not_in_statement_regexp:
             nis = self._not_in_statement_regexp
 
         import sqlite3
 
         first_bind = self._first_bind
         if self._not_in_statement_regexp:
             nis = self._not_in_statement_regexp
 
-            def _test_sql(sql):
+            def _test_sql(sql: str) -> None:
                 m = nis.search(sql)
                 assert not m, f"Found {nis.pattern!r} in {sql!r}"
 
         else:
 
                 m = nis.search(sql)
                 assert not m, f"Found {nis.pattern!r} in {sql!r}"
 
         else:
 
-            def _test_sql(sql):
+            def _test_sql(sql: str) -> None:
                 pass
 
                 pass
 
-        def _numeric_param_as_dict(parameters):
+        def _numeric_param_as_dict(
+            parameters: Any,
+        ) -> Union[dict[str, Any], tuple[Any, ...]]:
             if parameters:
                 assert isinstance(parameters, tuple)
                 return {
             if parameters:
                 assert isinstance(parameters, tuple)
                 return {
@@ -667,13 +688,13 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
                 return ()
 
         class SQLiteFix99953Cursor(sqlite3.Cursor):
                 return ()
 
         class SQLiteFix99953Cursor(sqlite3.Cursor):
-            def execute(self, sql, parameters=()):
+            def execute(self, sql: str, parameters: Any = ()) -> Self:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = _numeric_param_as_dict(parameters)
                 return super().execute(sql, parameters)
 
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = _numeric_param_as_dict(parameters)
                 return super().execute(sql, parameters)
 
-            def executemany(self, sql, parameters):
+            def executemany(self, sql: str, parameters: Any) -> Self:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = [
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = [
@@ -682,18 +703,27 @@ class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
                 return super().executemany(sql, parameters)
 
         class SQLiteFix99953Connection(sqlite3.Connection):
                 return super().executemany(sql, parameters)
 
         class SQLiteFix99953Connection(sqlite3.Connection):
-            def cursor(self, factory=None):
+            _CursorT = TypeVar("_CursorT", bound=sqlite3.Cursor)
+
+            def cursor(
+                self,
+                factory: Optional[
+                    Callable[[sqlite3.Connection], _CursorT]
+                ] = None,
+            ) -> _CursorT:
                 if factory is None:
                 if factory is None:
-                    factory = SQLiteFix99953Cursor
-                return super().cursor(factory=factory)
+                    factory = SQLiteFix99953Cursor  # type: ignore[assignment]
+                return super().cursor(factory=factory)  # type: ignore[return-value]  # noqa[E501]
 
 
-            def execute(self, sql, parameters=()):
+            def execute(
+                self, sql: str, parameters: Any = ()
+            ) -> sqlite3.Cursor:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = _numeric_param_as_dict(parameters)
                 return super().execute(sql, parameters)
 
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = _numeric_param_as_dict(parameters)
                 return super().execute(sql, parameters)
 
-            def executemany(self, sql, parameters):
+            def executemany(self, sql: str, parameters: Any) -> sqlite3.Cursor:
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = [
                 _test_sql(sql)
                 if first_bind in sql:
                     parameters = [
@@ -719,6 +749,6 @@ class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric):
     _first_bind = "$1"
     _not_in_statement_regexp = re.compile(r"[^\d]:\d+")
 
     _first_bind = "$1"
     _not_in_statement_regexp = re.compile(r"[^\d]:\d+")
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any) -> None:
         kw.setdefault("paramstyle", "numeric_dollar")
         super().__init__(*arg, **kw)
         kw.setdefault("paramstyle", "numeric_dollar")
         super().__init__(*arg, **kw)