]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
adapt create_engine from sqlalchemy2-stubs
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Mar 2022 17:01:42 +0000 (12:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Mar 2022 17:10:40 +0000 (12:10 -0500)
this is much simplified, will try to see if _IsolationLevel
can work out, technically some driver can have custom values
here but in practice this might not be a thing

Change-Id: I6085ccb559c377fab03c8ce79f0eecb240c56f7a

lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/pool/__init__.py
lib/sqlalchemy/pool/base.py

index c5f07de077cf1c4a71a6c855434dbdcf0877c312..61b40d935e1bd0c91fe69e0bec8331a3bcea678a 100644 (file)
@@ -9,6 +9,7 @@ from __future__ import annotations
 
 import re
 from types import ModuleType
+import typing
 from typing import Any
 from typing import Dict
 from typing import List
@@ -27,6 +28,9 @@ from ..engine import interfaces
 from ..engine import URL
 from ..sql.type_api import TypeEngine
 
+if typing.TYPE_CHECKING:
+    from ..engine.interfaces import _IsolationLevel
+
 
 class PyODBCConnector(Connector):
     driver = "pyodbc"
@@ -208,13 +212,15 @@ class PyODBCConnector(Connector):
 
     def get_isolation_level_values(
         self, dbapi_connection: interfaces.DBAPIConnection
-    ) -> List[str]:
+    ) -> List[_IsolationLevel]:
         return super().get_isolation_level_values(dbapi_connection) + [  # type: ignore  # noqa E501
             "AUTOCOMMIT"
         ]
 
     def set_isolation_level(
-        self, dbapi_connection: interfaces.DBAPIConnection, level: str
+        self,
+        dbapi_connection: interfaces.DBAPIConnection,
+        level: _IsolationLevel,
     ) -> None:
         # adjust for ConnectionFairy being present
         # allows attribute set e.g. "connection.autocommit = True"
index cb5219396bbc9f05fc7189d1d93803d77692b163..68a6b81e2c7709d1c5c860dbad382e36f6276b30 100644 (file)
@@ -10,9 +10,13 @@ from __future__ import annotations
 import inspect
 import typing
 from typing import Any
+from typing import Callable
 from typing import cast
 from typing import Dict
+from typing import List
 from typing import Optional
+from typing import overload
+from typing import Type
 from typing import Union
 
 from . import base
@@ -28,6 +32,61 @@ from ..sql import compiler
 
 if typing.TYPE_CHECKING:
     from .base import Engine
+    from .interfaces import _ExecuteOptions
+    from .interfaces import _IsolationLevel
+    from .interfaces import _ParamStyle
+    from .url import URL
+    from ..log import _EchoFlagType
+    from ..pool import _CreatorFnType
+    from ..pool import _CreatorWRecFnType
+    from ..pool import _ResetStyleArgType
+    from ..pool import Pool
+    from ..util.typing import Literal
+
+
+@overload
+def create_engine(
+    url: Union[str, URL],
+    *,
+    connect_args: Dict[Any, Any] = ...,
+    convert_unicode: bool = ...,
+    creator: Union[_CreatorFnType, _CreatorWRecFnType] = ...,
+    echo: _EchoFlagType = ...,
+    echo_pool: _EchoFlagType = ...,
+    enable_from_linting: bool = ...,
+    execution_options: _ExecuteOptions = ...,
+    future: Literal[True],
+    hide_parameters: bool = ...,
+    implicit_returning: bool = ...,
+    isolation_level: _IsolationLevel = ...,
+    json_deserializer: Callable[..., Any] = ...,
+    json_serializer: Callable[..., Any] = ...,
+    label_length: Optional[int] = ...,
+    listeners: Any = ...,
+    logging_name: str = ...,
+    max_identifier_length: Optional[int] = ...,
+    max_overflow: int = ...,
+    module: Optional[Any] = ...,
+    paramstyle: Optional[_ParamStyle] = ...,
+    pool: Optional[Pool] = ...,
+    poolclass: Optional[Type[Pool]] = ...,
+    pool_logging_name: str = ...,
+    pool_pre_ping: bool = ...,
+    pool_size: int = ...,
+    pool_recycle: int = ...,
+    pool_reset_on_return: Optional[_ResetStyleArgType] = ...,
+    pool_timeout: float = ...,
+    pool_use_lifo: bool = ...,
+    plugins: List[str] = ...,
+    query_cache_size: int = ...,
+    **kwargs: Any,
+) -> Engine:
+    ...
+
+
+@overload
+def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine:
+    ...
 
 
 @util.deprecated_params(
index 545dd0ddcdeef9a34b690ca892b4db570bdbb893..5aefcf5b565a946812f69353c0febfa5dddc4d35 100644 (file)
@@ -37,6 +37,7 @@ from ..sql.compiler import TypeCompiler  # noqa
 from ..util import immutabledict
 from ..util.concurrency import await_only
 from ..util.typing import _TypeToInstance
+from ..util.typing import Literal
 from ..util.typing import NotRequired
 from ..util.typing import Protocol
 from ..util.typing import TypedDict
@@ -221,6 +222,16 @@ _SchemaTranslateMapType = Mapping[str, str]
 
 _ImmutableExecuteOptions = immutabledict[str, Any]
 
+_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
+
+_IsolationLevel = Literal[
+    "SERIALIZABLE",
+    "REPEATABLE READ",
+    "READ COMMITTED",
+    "READ UNCOMMITTED",
+    "AUTOCOMMIT",
+]
+
 
 class ReflectedIdentity(TypedDict):
     """represent the reflected IDENTITY structure of a column, corresponding
@@ -622,7 +633,7 @@ class Dialect(EventTarget):
 
     """
 
-    default_isolation_level: str
+    default_isolation_level: _IsolationLevel
     """the isolation that is implicitly present on new connections"""
 
     execution_ctx_cls: Type["ExecutionContext"]
@@ -1647,7 +1658,7 @@ class Dialect(EventTarget):
         raise NotImplementedError()
 
     def set_isolation_level(
-        self, dbapi_connection: DBAPIConnection, level: str
+        self, dbapi_connection: DBAPIConnection, level: _IsolationLevel
     ) -> None:
         """Given a DBAPI connection, set its isolation level.
 
@@ -1680,7 +1691,9 @@ class Dialect(EventTarget):
 
         raise NotImplementedError()
 
-    def get_isolation_level(self, dbapi_connection: DBAPIConnection) -> str:
+    def get_isolation_level(
+        self, dbapi_connection: DBAPIConnection
+    ) -> _IsolationLevel:
         """Given a DBAPI connection, return its isolation level.
 
         When working with a :class:`_engine.Connection` object,
@@ -1713,7 +1726,9 @@ class Dialect(EventTarget):
 
         raise NotImplementedError()
 
-    def get_default_isolation_level(self, dbapi_conn: DBAPIConnection) -> str:
+    def get_default_isolation_level(
+        self, dbapi_conn: DBAPIConnection
+    ) -> _IsolationLevel:
         """Given a DBAPI connection, return its isolation level, or
         a default isolation level if one cannot be retrieved.
 
@@ -1735,7 +1750,7 @@ class Dialect(EventTarget):
 
     def get_isolation_level_values(
         self, dbapi_conn: DBAPIConnection
-    ) -> List[str]:
+    ) -> List[_IsolationLevel]:
         """return a sequence of string isolation level names that are accepted
         by this dialect.
 
@@ -1778,7 +1793,7 @@ class Dialect(EventTarget):
         raise NotImplementedError()
 
     def _assert_and_set_isolation_level(
-        self, dbapi_conn: DBAPIConnection, level: str
+        self, dbapi_conn: DBAPIConnection, level: _IsolationLevel
     ) -> None:
         raise NotImplementedError()
 
index 1fc77243a9ef3594215437a6ff93964133174f24..c86d3ddedaf7480eeb1a10bcdd4f86352549e6c6 100644 (file)
@@ -21,7 +21,10 @@ from . import events
 from .base import _AdhocProxiedConnection as _AdhocProxiedConnection
 from .base import _ConnectionFairy as _ConnectionFairy
 from .base import _ConnectionRecord
+from .base import _CreatorFnType as _CreatorFnType
+from .base import _CreatorWRecFnType as _CreatorWRecFnType
 from .base import _finalize_fairy
+from .base import _ResetStyleArgType as _ResetStyleArgType
 from .base import ConnectionPoolEntry as ConnectionPoolEntry
 from .base import ManagesConnection as ManagesConnection
 from .base import Pool as Pool
index c1008de5f522eb5aa8323d8e588355ff26bdb7b0..c9fb8cb341d5185adac48385b9362bd71d29a7d9 100644 (file)
@@ -56,11 +56,7 @@ class ResetStyle(Enum):
 
 _ResetStyleArgType = Union[
     ResetStyle,
-    Literal[True],
-    Literal[None],
-    Literal[False],
-    Literal["commit"],
-    Literal["rollback"],
+    Literal[True, None, False, "commit", "rollback"],
 ]
 reset_rollback, reset_commit, reset_none = list(ResetStyle)