From: Mike Bayer Date: Mon, 7 Mar 2022 17:01:42 +0000 (-0500) Subject: adapt create_engine from sqlalchemy2-stubs X-Git-Tag: rel_2_0_0b1~440^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e36496808a93870c07f143b0a17d2b8df224c55d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git adapt create_engine from sqlalchemy2-stubs this is much simplified, will try to see if _IsolationLevel can work out, technically some driver can have custom values here but in practice this might not be a thing Change-Id: I6085ccb559c377fab03c8ce79f0eecb240c56f7a --- diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index c5f07de077..61b40d935e 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -9,6 +9,7 @@ from __future__ import annotations import re from types import ModuleType +import typing from typing import Any from typing import Dict from typing import List @@ -27,6 +28,9 @@ from ..engine import interfaces from ..engine import URL from ..sql.type_api import TypeEngine +if typing.TYPE_CHECKING: + from ..engine.interfaces import _IsolationLevel + class PyODBCConnector(Connector): driver = "pyodbc" @@ -208,13 +212,15 @@ class PyODBCConnector(Connector): def get_isolation_level_values( self, dbapi_connection: interfaces.DBAPIConnection - ) -> List[str]: + ) -> List[_IsolationLevel]: return super().get_isolation_level_values(dbapi_connection) + [ # type: ignore # noqa E501 "AUTOCOMMIT" ] def set_isolation_level( - self, dbapi_connection: interfaces.DBAPIConnection, level: str + self, + dbapi_connection: interfaces.DBAPIConnection, + level: _IsolationLevel, ) -> None: # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index cb5219396b..68a6b81e2c 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -10,9 +10,13 @@ from __future__ import annotations import inspect import typing from typing import Any +from typing import Callable from typing import cast from typing import Dict +from typing import List from typing import Optional +from typing import overload +from typing import Type from typing import Union from . import base @@ -28,6 +32,61 @@ from ..sql import compiler if typing.TYPE_CHECKING: from .base import Engine + from .interfaces import _ExecuteOptions + from .interfaces import _IsolationLevel + from .interfaces import _ParamStyle + from .url import URL + from ..log import _EchoFlagType + from ..pool import _CreatorFnType + from ..pool import _CreatorWRecFnType + from ..pool import _ResetStyleArgType + from ..pool import Pool + from ..util.typing import Literal + + +@overload +def create_engine( + url: Union[str, URL], + *, + connect_args: Dict[Any, Any] = ..., + convert_unicode: bool = ..., + creator: Union[_CreatorFnType, _CreatorWRecFnType] = ..., + echo: _EchoFlagType = ..., + echo_pool: _EchoFlagType = ..., + enable_from_linting: bool = ..., + execution_options: _ExecuteOptions = ..., + future: Literal[True], + hide_parameters: bool = ..., + implicit_returning: bool = ..., + isolation_level: _IsolationLevel = ..., + json_deserializer: Callable[..., Any] = ..., + json_serializer: Callable[..., Any] = ..., + label_length: Optional[int] = ..., + listeners: Any = ..., + logging_name: str = ..., + max_identifier_length: Optional[int] = ..., + max_overflow: int = ..., + module: Optional[Any] = ..., + paramstyle: Optional[_ParamStyle] = ..., + pool: Optional[Pool] = ..., + poolclass: Optional[Type[Pool]] = ..., + pool_logging_name: str = ..., + pool_pre_ping: bool = ..., + pool_size: int = ..., + pool_recycle: int = ..., + pool_reset_on_return: Optional[_ResetStyleArgType] = ..., + pool_timeout: float = ..., + pool_use_lifo: bool = ..., + plugins: List[str] = ..., + query_cache_size: int = ..., + **kwargs: Any, +) -> Engine: + ... + + +@overload +def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: + ... @util.deprecated_params( diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 545dd0ddcd..5aefcf5b56 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -37,6 +37,7 @@ from ..sql.compiler import TypeCompiler # noqa from ..util import immutabledict from ..util.concurrency import await_only from ..util.typing import _TypeToInstance +from ..util.typing import Literal from ..util.typing import NotRequired from ..util.typing import Protocol from ..util.typing import TypedDict @@ -221,6 +222,16 @@ _SchemaTranslateMapType = Mapping[str, str] _ImmutableExecuteOptions = immutabledict[str, Any] +_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"] + +_IsolationLevel = Literal[ + "SERIALIZABLE", + "REPEATABLE READ", + "READ COMMITTED", + "READ UNCOMMITTED", + "AUTOCOMMIT", +] + class ReflectedIdentity(TypedDict): """represent the reflected IDENTITY structure of a column, corresponding @@ -622,7 +633,7 @@ class Dialect(EventTarget): """ - default_isolation_level: str + default_isolation_level: _IsolationLevel """the isolation that is implicitly present on new connections""" execution_ctx_cls: Type["ExecutionContext"] @@ -1647,7 +1658,7 @@ class Dialect(EventTarget): raise NotImplementedError() def set_isolation_level( - self, dbapi_connection: DBAPIConnection, level: str + self, dbapi_connection: DBAPIConnection, level: _IsolationLevel ) -> None: """Given a DBAPI connection, set its isolation level. @@ -1680,7 +1691,9 @@ class Dialect(EventTarget): raise NotImplementedError() - def get_isolation_level(self, dbapi_connection: DBAPIConnection) -> str: + def get_isolation_level( + self, dbapi_connection: DBAPIConnection + ) -> _IsolationLevel: """Given a DBAPI connection, return its isolation level. When working with a :class:`_engine.Connection` object, @@ -1713,7 +1726,9 @@ class Dialect(EventTarget): raise NotImplementedError() - def get_default_isolation_level(self, dbapi_conn: DBAPIConnection) -> str: + def get_default_isolation_level( + self, dbapi_conn: DBAPIConnection + ) -> _IsolationLevel: """Given a DBAPI connection, return its isolation level, or a default isolation level if one cannot be retrieved. @@ -1735,7 +1750,7 @@ class Dialect(EventTarget): def get_isolation_level_values( self, dbapi_conn: DBAPIConnection - ) -> List[str]: + ) -> List[_IsolationLevel]: """return a sequence of string isolation level names that are accepted by this dialect. @@ -1778,7 +1793,7 @@ class Dialect(EventTarget): raise NotImplementedError() def _assert_and_set_isolation_level( - self, dbapi_conn: DBAPIConnection, level: str + self, dbapi_conn: DBAPIConnection, level: _IsolationLevel ) -> None: raise NotImplementedError() diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 1fc77243a9..c86d3ddeda 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -21,7 +21,10 @@ from . import events from .base import _AdhocProxiedConnection as _AdhocProxiedConnection from .base import _ConnectionFairy as _ConnectionFairy from .base import _ConnectionRecord +from .base import _CreatorFnType as _CreatorFnType +from .base import _CreatorWRecFnType as _CreatorWRecFnType from .base import _finalize_fairy +from .base import _ResetStyleArgType as _ResetStyleArgType from .base import ConnectionPoolEntry as ConnectionPoolEntry from .base import ManagesConnection as ManagesConnection from .base import Pool as Pool diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index c1008de5f5..c9fb8cb341 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -56,11 +56,7 @@ class ResetStyle(Enum): _ResetStyleArgType = Union[ ResetStyle, - Literal[True], - Literal[None], - Literal[False], - Literal["commit"], - Literal["rollback"], + Literal[True, None, False, "commit", "rollback"], ] reset_rollback, reset_commit, reset_none = list(ResetStyle)