From: Mike Bayer Date: Wed, 16 Feb 2022 04:43:51 +0000 (-0500) Subject: pep-484 for pool X-Git-Tag: rel_2_0_0b1~480^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5157e0aa542f390242dd7a6d27a6ce1663230e46;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484 for pool also extends into some areas of utils, events and others as needed. Formalizes a public hierarchy for pool API, with ManagesConnection -> PoolProxiedConnection / ConnectionPoolEntry for connectionfairy / connectionrecord, which are now what's exposed in the event API and other APIs. all public API docs moved to the new objects. Corrects the mypy plugin's check for sqlalchemy-stubs not being insatlled, which has to be imported using the dash in the name to be effective. Change-Id: I16c2cb43b2e840d28e70a015f370a768e70f3581 --- diff --git a/doc/build/core/pooling.rst b/doc/build/core/pooling.rst index bb5e2826a7..008c4e1a15 100644 --- a/doc/build/core/pooling.rst +++ b/doc/build/core/pooling.rst @@ -558,14 +558,18 @@ API Documentation - Available Pool Implementations .. autoclass:: StaticPool -.. autoclass:: PoolProxiedConnection +.. autoclass:: ManagesConnection :members: -.. autoclass:: _ConnectionFairy +.. autoclass:: ConnectionPoolEntry :members: + :inherited-members: - .. autoattribute:: _connection_record +.. autoclass:: PoolProxiedConnection + :members: + :inherited-members: + +.. autoclass:: _ConnectionFairy .. autoclass:: _ConnectionRecord - :members: diff --git a/doc/build/faq/connections.rst b/doc/build/faq/connections.rst index 02d088384c..d592ccf6dc 100644 --- a/doc/build/faq/connections.rst +++ b/doc/build/faq/connections.rst @@ -414,14 +414,14 @@ How do I get at the raw DBAPI connection when using an Engine? With a regular SA engine-level Connection, you can get at a pool-proxied version of the DBAPI connection via the :attr:`_engine.Connection.connection` attribute on :class:`_engine.Connection`, and for the really-real DBAPI connection you can call the -:attr:`._ConnectionFairy.dbapi_connection` attribute on that. On regular sync drivers +:attr:`.PoolProxiedConnection.dbapi_connection` attribute on that. On regular sync drivers there is usually no need to access the non-pool-proxied DBAPI connection, as all methods are proxied through:: engine = create_engine(...) conn = engine.connect() - # pep-249 style ConnectionFairy connection pool proxy object + # pep-249 style PoolProxiedConnection (historically called a "connection fairy") connection_fairy = conn.connection # typically to run statements one would get a cursor() from this @@ -438,11 +438,11 @@ as all methods are proxied through:: also_raw_dbapi_connection = connection_fairy.driver_connection .. versionchanged:: 1.4.24 Added the - :attr:`._ConnectionFairy.dbapi_connection` attribute, + :attr:`.PoolProxiedConnection.dbapi_connection` attribute, which supersedes the previous - :attr:`._ConnectionFairy.connection` attribute which still remains + :attr:`.PoolProxiedConnection.connection` attribute which still remains available; this attribute always provides a pep-249 synchronous style - connection object. The :attr:`._ConnectionFairy.driver_connection` + connection object. The :attr:`.PoolProxiedConnection.driver_connection` attribute is also added which will always refer to the real driver-level connection regardless of what API it presents. @@ -451,15 +451,15 @@ Accessing the underlying connection for an asyncio driver When an asyncio driver is in use, there are two changes to the above scheme. The first is that when using an :class:`_asyncio.AsyncConnection`, -the :class:`._ConnectionFairy` must be accessed using the awaitable method +the :class:`.PoolProxiedConnection` must be accessed using the awaitable method :meth:`_asyncio.AsyncConnection.get_raw_connection`. The -returned :class:`._ConnectionFairy` in this case retains a sync-style -pep-249 usage pattern, and the :attr:`._ConnectionFairy.dbapi_connection` +returned :class:`.PoolProxiedConnection` in this case retains a sync-style +pep-249 usage pattern, and the :attr:`.PoolProxiedConnection.dbapi_connection` attribute refers to a a SQLAlchemy-adapted connection object which adapts the asyncio connection to a sync style pep-249 API, in other words there are *two* levels of proxying going on when using an asyncio driver. The actual asyncio connection -is available from the :class:`._ConnectionFairy.driver_connection` attribute. +is available from the :class:`.PoolProxiedConnection.driver_connection` attribute. To restate the previous example in terms of asyncio looks like:: async def main(): @@ -483,8 +483,8 @@ To restate the previous example in terms of asyncio looks like:: result = await raw_asyncio_connection.execute(...) .. versionchanged:: 1.4.24 Added the - :attr:`._ConnectionFairy.dbapi_connection` - and :attr:`._ConnectionFairy.driver_connection` attributes to allow access + :attr:`.PoolProxiedConnection.dbapi_connection` + and :attr:`.PoolProxiedConnection.driver_connection` attributes to allow access to pep-249 connections, pep-249 adaption layers, and underlying driver connections using a consistent interface. @@ -493,10 +493,10 @@ SQLAlchemy-adapted form of connection which presents a synchronous-style pep-249 style API. To access the actual asyncio driver connection, which will present the original asyncio API of the driver in use, this can be accessed via the -:attr:`._ConnectionFairy.driver_connection` attribute of -:class:`._ConnectionFairy`. -For a standard pep-249 driver, :attr:`._ConnectionFairy.dbapi_connection` -and :attr:`._ConnectionFairy.driver_connection` are synonymous. +:attr:`.PoolProxiedConnection.driver_connection` attribute of +:class:`.PoolProxiedConnection`. +For a standard pep-249 driver, :attr:`.PoolProxiedConnection.dbapi_connection` +and :attr:`.PoolProxiedConnection.driver_connection` are synonymous. You must ensure that you revert any isolation level settings or other operation-specific settings on the connection back to normal before returning diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4fd2739484..8c99f63090 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1771,15 +1771,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if not self._is_disconnect: if cursor: self._safe_close_cursor(cursor) - with util.safe_reraise(warn_only=True): - # "autorollback" was mostly relevant in 1.x series. - # It's very unlikely to reach here, as the connection - # does autobegin so when we are here, we are usually - # in an explicit / semi-explicit transaction. - # however we have a test which manufactures this - # scenario in any case using an event handler. - if not self.in_transaction(): - self._rollback_impl() + # "autorollback" was mostly relevant in 1.x series. + # It's very unlikely to reach here, as the connection + # does autobegin so when we are here, we are usually + # in an explicit / semi-explicit transaction. + # however we have a test which manufactures this + # scenario in any case using an event handler. + # test/engine/test_execute.py-> test_actual_autorollback + if not self.in_transaction(): + self._rollback_impl() if newraise: raise newraise.with_traceback(exc_info[2]) from e @@ -2318,11 +2318,15 @@ class Engine( _schema_translate_map = None + dialect: Dialect + pool: Pool + url: URL + def __init__( self, - pool: "Pool", - dialect: "Dialect", - url: "URL", + pool: Pool, + dialect: Dialect, + url: URL, logging_name: Optional[str] = None, echo: Union[None, str, bool] = None, query_cache_size: int = 500, diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index a252b7cfeb..ac3d6a2d89 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -12,11 +12,13 @@ from typing import Union from . import base from . import url as _url +from .interfaces import DBAPIConnection from .mock import create_mock_engine from .. import event from .. import exc -from .. import pool as poollib from .. import util +from ..pool import _AdhocProxiedConnection +from ..pool import ConnectionPoolEntry from ..sql import compiler @@ -603,10 +605,13 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": if builtin_on_connect: event.listen(pool, "connect", builtin_on_connect) - def first_connect(dbapi_connection, connection_record): + def first_connect( + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ): c = base.Connection( engine, - connection=poollib._AdhocProxiedConnection( + connection=_AdhocProxiedConnection( dbapi_connection, connection_record ), _has_events=False, diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index ce884614c0..aab6b2de87 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -59,7 +59,7 @@ class DBAPIConnection(Protocol): def commit(self) -> None: ... - def cursor(self) -> "DBAPICursor": + def cursor(self) -> DBAPICursor: ... def rollback(self) -> None: @@ -657,6 +657,9 @@ class Dialect: """ + is_async: bool + """Whether or not this dialect is intended for asyncio use.""" + def create_connect_args( self, url: "URL" ) -> Tuple[Tuple[str], Mapping[str, Any]]: @@ -1091,7 +1094,7 @@ class Dialect: raise NotImplementedError() - def do_close(self, dbapi_connection: PoolProxiedConnection) -> None: + def do_close(self, dbapi_connection: DBAPIConnection) -> None: """Provide an implementation of ``connection.close()``, given a DBAPI connection. @@ -1104,6 +1107,11 @@ class Dialect: raise NotImplementedError() + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: + """ping the DBAPI connection and return True if the connection is + usable.""" + raise NotImplementedError() + def do_set_input_sizes( self, cursor: DBAPICursor, @@ -1679,7 +1687,7 @@ class Dialect: """ - def get_driver_connection(self, connection: PoolProxiedConnection) -> Any: + def get_driver_connection(self, connection: DBAPIConnection) -> Any: """Returns the connection object as returned by the external driver package. diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index 2d10372ab1..0dfb39e1a0 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -14,6 +14,10 @@ from .api import listens_for as listens_for from .api import NO_RETVAL as NO_RETVAL from .api import remove as remove from .attr import RefCollection as RefCollection +from .base import _Dispatch as _Dispatch from .base import dispatcher as dispatcher from .base import Events as Events -from .legacy import _legacy_signature +from .legacy import _legacy_signature as _legacy_signature +from .registry import _EventKey as _EventKey +from .registry import _ListenerFnType as _ListenerFnType +from .registry import EventTarget as EventTarget diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index d1ae7a8452..9692894fe8 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -68,6 +68,7 @@ _T = TypeVar("_T", bound=Any) if typing.TYPE_CHECKING: from .base import _Dispatch + from .base import _DispatchCommon from .base import _HasEventsDispatch from .base import _JoinedDispatcher @@ -280,6 +281,38 @@ class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]): def __bool__(self) -> bool: raise NotImplementedError() + def exec_once(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def __call__(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + raise NotImplementedError() + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + raise NotImplementedError() + + def remove(self, event_key: _EventKey[_ET]) -> None: + raise NotImplementedError() + + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _InstanceLevelDispatch[_ET]: + """Return an event collection which can be modified. + + For _ClsLevelDispatch at the class level of + a dispatcher, this returns self. + + """ + return self + class _EmptyListener(_InstanceLevelDispatch[_ET]): """Serves as a proxy interface to the events @@ -306,7 +339,9 @@ class _EmptyListener(_InstanceLevelDispatch[_ET]): self.parent_listeners = parent._clslevel[target_cls] self.name = parent.name - def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _EmptyListener at the instance level of @@ -315,6 +350,8 @@ class _EmptyListener(_InstanceLevelDispatch[_ET]): and returns it. """ + obj = cast("_Dispatch[_ET]", obj) + assert obj._instance_cls is not None result = _ListenerCollection(self.parent, obj._instance_cls) if getattr(obj, self.name) is self: @@ -512,7 +549,9 @@ class _ListenerCollection(_CompoundListener[_ET]): self.listeners = collections.deque() self.propagate = set() - def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _ListenerCollection at the instance level of @@ -599,7 +638,7 @@ class _JoinedListener(_CompoundListener[_ET]): ) -> _ListenerFnType: return self.local._adjust_fn_spec(fn, named) - def for_modify(self, obj: _JoinedDispatcher[_ET]) -> _JoinedListener[_ET]: + def for_modify(self, obj: _DispatchCommon[_ET]) -> _JoinedListener[_ET]: self.local = self.parent_listeners = self.local.for_modify(obj) return self diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 0e0647036f..ef3ff9dab3 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -17,6 +17,7 @@ instances of ``_Dispatch``. """ from __future__ import annotations +import typing from typing import Any from typing import cast from typing import Dict @@ -71,7 +72,11 @@ class _UnpickleDispatch: raise AttributeError("No class with a 'dispatch' member present.") -class _Dispatch(Generic[_ET]): +class _DispatchCommon(Generic[_ET]): + __slots__ = () + + +class _Dispatch(_DispatchCommon[_ET]): """Mirror the event listening definitions of an Events class with listener collections. @@ -218,6 +223,11 @@ class _HasEventsDispatch(Generic[_ET]): """ + if typing.TYPE_CHECKING: + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + ... + def __init_subclass__(cls) -> None: """Intercept new Event subclasses and create associated _Dispatch classes.""" @@ -357,7 +367,7 @@ class Events(_HasEventsDispatch[_ET]): cls.dispatch._clear() -class _JoinedDispatcher(Generic[_ET]): +class _JoinedDispatcher(_DispatchCommon[_ET]): """Represent a connection between two _Dispatch objects.""" __slots__ = "local", "parent", "_instance_cls" @@ -402,11 +412,11 @@ class dispatcher(Generic[_ET]): @overload def __get__( self, obj: Literal[None], cls: Type[Any] - ) -> Type[_HasEventsDispatch[_ET]]: + ) -> Type[_Dispatch[_ET]]: ... @overload - def __get__(self, obj: Any, cls: Type[Any]) -> _HasEventsDispatch[_ET]: + def __get__(self, obj: Any, cls: Type[Any]) -> _Dispatch[_ET]: ... def __get__(self, obj: Any, cls: Type[Any]) -> Any: diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index e20d3e0b53..449f391876 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -22,7 +22,6 @@ import typing from typing import Any from typing import Callable from typing import cast -from typing import ClassVar from typing import Deque from typing import Dict from typing import Generic @@ -35,7 +34,6 @@ import weakref from .. import exc from .. import util -from ..util.typing import Protocol if typing.TYPE_CHECKING: from .attr import RefCollection @@ -46,7 +44,10 @@ _ListenerFnKeyType = Union[int, Tuple[int, int]] _EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType] -class _EventTargetType(Protocol): +_ET = TypeVar("_ET", bound="EventTarget") + + +class EventTarget: """represents an event target, that is, something we can listen on either with that target as a class or as an instance. @@ -55,10 +56,10 @@ class _EventTargetType(Protocol): """ - dispatch: ClassVar[dispatcher[Any]] + __slots__ = () + dispatch: dispatcher[Any] -_ET = TypeVar("_ET", bound=_EventTargetType) _RefCollectionToListenerType = Dict[ "weakref.ref[RefCollection[Any]]", @@ -104,7 +105,7 @@ def _collection_gced(ref: weakref.ref[Any]) -> None: if not _collection_to_key or ref not in _collection_to_key: return - ref = cast("weakref.ref[RefCollection[_EventTargetType]]", ref) + ref = cast("weakref.ref[RefCollection[EventTarget]]", ref) listener_to_key = _collection_to_key.pop(ref) for key in listener_to_key.values(): diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index f39f4cd8fa..1383e024a1 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -32,7 +32,11 @@ if typing.TYPE_CHECKING: from .sql.compiler import Compiled from .sql.elements import ClauseElement -_version_token = None +if typing.TYPE_CHECKING: + _version_token: str +else: + # set by __init__.py + _version_token = None class HasDescriptionCode: diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 3a78ab188c..f7e66e3419 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -43,16 +43,16 @@ from . import names from . import util try: - import sqlalchemy_stubs # noqa + __import__("sqlalchemy-stubs") except ImportError: pass else: - import sqlalchemy - raise ImportError( - f"The SQLAlchemy mypy plugin in SQLAlchemy " - f"{sqlalchemy.__version__} does not work with sqlalchemy-stubs or " - "sqlalchemy2-stubs installed" + "The SQLAlchemy mypy plugin in SQLAlchemy " + "2.0 does not work with sqlalchemy-stubs or " + "sqlalchemy2-stubs installed, as well as with any other third party " + "SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs " + "packages." ) diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 2f63b8569d..8da45ed0d7 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -75,12 +75,15 @@ def class_logger(cls: Type[_IT]) -> Type[_IT]: return cls +_IdentifiedLoggerType = Union[logging.Logger, "InstanceLogger"] + + class Identified: __slots__ = () logging_name: Optional[str] = None - logger: Union[logging.Logger, "InstanceLogger"] + logger: _IdentifiedLoggerType _echo: _EchoFlagType diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 33367c0c65..b9c881cfe0 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -662,15 +662,9 @@ class Mapped(Generic[_T], TypingOnly): def _empty_constructor(cls, arg1: Any) -> "Mapped[_T]": ... - @overload - def __set__(self, instance: Any, value: _T) -> None: - ... - - @overload - def __set__(self, instance: Any, value: SQLCoreOperations) -> None: - ... - - def __set__(self, instance, value): + def __set__( + self, instance: Any, value: Union[SQLCoreOperations[_T], _T] + ): ... def __delete__(self, instance: Any): diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index bc2f93d57e..2c52a70650 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -22,6 +22,8 @@ from .base import _AdhocProxiedConnection from .base import _ConnectionFairy from .base import _ConnectionRecord from .base import _finalize_fairy +from .base import ConnectionPoolEntry as ConnectionPoolEntry +from .base import ManagesConnection as ManagesConnection from .base import Pool as Pool from .base import PoolProxiedConnection as PoolProxiedConnection from .base import reset_commit as reset_commit diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 72c56716f1..18d268182d 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -13,24 +13,55 @@ from __future__ import annotations from collections import deque +from enum import Enum +import threading import time +import typing from typing import Any +from typing import Callable +from typing import cast +from typing import Deque from typing import Dict +from typing import List from typing import Optional +from typing import Tuple from typing import TYPE_CHECKING +from typing import Union import weakref from .. import event from .. import exc from .. import log from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol if TYPE_CHECKING: from ..engine.interfaces import DBAPIConnection + from ..engine.interfaces import DBAPICursor + from ..engine.interfaces import Dialect + from ..event import _Dispatch + from ..event import _ListenerFnType + from ..event import dispatcher -reset_rollback = util.symbol("reset_rollback") -reset_commit = util.symbol("reset_commit") -reset_none = util.symbol("reset_none") + +class ResetStyle(Enum): + """Describe options for "reset on return" behaviors.""" + + reset_rollback = 0 + reset_commit = 1 + reset_none = 2 + + +_ResetStyleArgType = Union[ + ResetStyle, + Literal[True], + Literal[None], + Literal[False], + Literal["commit"], + Literal["rollback"], +] +reset_rollback, reset_commit, reset_none = list(ResetStyle) class _ConnDialect: @@ -45,22 +76,22 @@ class _ConnDialect: is_async = False - def do_rollback(self, dbapi_connection): + def do_rollback(self, dbapi_connection: PoolProxiedConnection) -> None: dbapi_connection.rollback() - def do_commit(self, dbapi_connection): + def do_commit(self, dbapi_connection: PoolProxiedConnection) -> None: dbapi_connection.commit() - def do_close(self, dbapi_connection): + def do_close(self, dbapi_connection: DBAPIConnection) -> None: dbapi_connection.close() - def do_ping(self, dbapi_connection): + def do_ping(self, dbapi_connection: DBAPIConnection) -> None: raise NotImplementedError( "The ping feature requires that a dialect is " "passed to the connection pool." ) - def get_driver_connection(self, connection): + def get_driver_connection(self, connection: DBAPIConnection) -> Any: return connection @@ -68,23 +99,40 @@ class _AsyncConnDialect(_ConnDialect): is_async = True -class Pool(log.Identified): +class _CreatorFnType(Protocol): + def __call__(self) -> DBAPIConnection: + ... + + +class _CreatorWRecFnType(Protocol): + def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: + ... + + +class Pool(log.Identified, event.EventTarget): """Abstract base class for connection pools.""" - _dialect = _ConnDialect() + dispatch: dispatcher[Pool] + echo: log._EchoFlagType + + _orig_logging_name: Optional[str] + _dialect: Union[_ConnDialect, Dialect] = _ConnDialect() + _creator_arg: Union[_CreatorFnType, _CreatorWRecFnType] + _invoke_creator: _CreatorWRecFnType + _invalidate_time: float def __init__( self, - creator, - recycle=-1, - echo=None, - logging_name=None, - reset_on_return=True, - events=None, - dialect=None, - pre_ping=False, - _dispatch=None, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + recycle: int = -1, + echo: log._EchoFlagType = None, + logging_name: Optional[str] = None, + reset_on_return: _ResetStyleArgType = True, + events: Optional[List[Tuple[_ListenerFnType, str]]] = None, + dialect: Optional[Union[_ConnDialect, Dialect]] = None, + pre_ping: bool = False, + _dispatch: Optional[_Dispatch[Pool]] = None, ): """ Construct a Pool. @@ -188,15 +236,14 @@ class Pool(log.Identified): self._recycle = recycle self._invalidate_time = 0 self._pre_ping = pre_ping - self._reset_on_return = util.symbol.parse_user_argument( + self._reset_on_return = util.parse_user_argument_for_enum( reset_on_return, { - reset_rollback: ["rollback", True], - reset_none: ["none", None, False], - reset_commit: ["commit"], + ResetStyle.reset_rollback: ["rollback", True], + ResetStyle.reset_none: ["none", None, False], + ResetStyle.reset_commit: ["commit"], }, "reset_on_return", - resolve_symbol_names=False, ) self.echo = echo @@ -210,19 +257,32 @@ class Pool(log.Identified): event.listen(self, target, fn) @util.hybridproperty - def _is_asyncio(self): + def _is_asyncio(self) -> bool: return self._dialect.is_async @property - def _creator(self): - return self.__dict__["_creator"] + def _creator(self) -> Union[_CreatorFnType, _CreatorWRecFnType]: + return self._creator_arg @_creator.setter - def _creator(self, creator): - self.__dict__["_creator"] = creator - self._invoke_creator = self._should_wrap_creator(creator) + def _creator( + self, creator: Union[_CreatorFnType, _CreatorWRecFnType] + ) -> None: + self._creator_arg = creator + + # mypy seems to get super confused assigning functions to + # attributes + self._invoke_creator = self._should_wrap_creator(creator) # type: ignore # noqa E501 + + @_creator.deleter + def _creator(self) -> None: + # needed for mock testing + del self._creator_arg + del self._invoke_creator # type: ignore[misc] - def _should_wrap_creator(self, creator): + def _should_wrap_creator( + self, creator: Union[_CreatorFnType, _CreatorWRecFnType] + ) -> _CreatorWRecFnType: """Detect if creator accepts a single argument, or is sent as a legacy style no-arg function. @@ -231,26 +291,30 @@ class Pool(log.Identified): try: argspec = util.get_callable_argspec(self._creator, no_self=True) except TypeError: - return lambda crec: creator() + creator_fn = cast(_CreatorFnType, creator) + return lambda rec: creator_fn() - defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if argspec.defaults is not None: + defaulted = len(argspec.defaults) + else: + defaulted = 0 positionals = len(argspec[0]) - defaulted # look for the exact arg signature that DefaultStrategy # sends us if (argspec[0], argspec[3]) == (["connection_record"], (None,)): - return creator + return cast(_CreatorWRecFnType, creator) # or just a single positional elif positionals == 1: - return creator + return cast(_CreatorWRecFnType, creator) # all other cases, just wrap and assume legacy "creator" callable # thing else: - return lambda crec: creator() + creator_fn = cast(_CreatorFnType, creator) + return lambda rec: creator_fn() - def _close_connection(self, connection): + def _close_connection(self, connection: DBAPIConnection) -> None: self.logger.debug("Closing connection %r", connection) - try: self._dialect.do_close(connection) except Exception: @@ -258,12 +322,17 @@ class Pool(log.Identified): "Exception closing connection %r", connection, exc_info=True ) - def _create_connection(self): + def _create_connection(self) -> ConnectionPoolEntry: """Called by subclasses to create a new ConnectionRecord.""" return _ConnectionRecord(self) - def _invalidate(self, connection, exception=None, _checkin=True): + def _invalidate( + self, + connection: PoolProxiedConnection, + exception: Optional[BaseException] = None, + _checkin: bool = True, + ) -> None: """Mark all connections established within the generation of the given connection as invalidated. @@ -280,7 +349,7 @@ class Pool(log.Identified): if _checkin and getattr(connection, "is_valid", False): connection.invalidate(exception) - def recreate(self): + def recreate(self) -> Pool: """Return a new :class:`_pool.Pool`, of the same class as this one and configured with identical creation arguments. @@ -292,7 +361,7 @@ class Pool(log.Identified): raise NotImplementedError() - def dispose(self): + def dispose(self) -> None: """Dispose of this pool. This method leaves the possibility of checked-out connections @@ -307,7 +376,7 @@ class Pool(log.Identified): raise NotImplementedError() - def connect(self): + def connect(self) -> PoolProxiedConnection: """Return a DBAPI connection from the pool. The connection is instrumented such that when its @@ -317,7 +386,7 @@ class Pool(log.Identified): """ return _ConnectionFairy._checkout(self) - def _return_conn(self, record): + def _return_conn(self, record: ConnectionPoolEntry) -> None: """Given a _ConnectionRecord, return it to the :class:`_pool.Pool`. This method is called when an instrumented DBAPI connection @@ -326,100 +395,230 @@ class Pool(log.Identified): """ self._do_return_conn(record) - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: """Implementation for :meth:`get`, supplied by subclasses.""" raise NotImplementedError() - def _do_return_conn(self, conn): + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: """Implementation for :meth:`return_conn`, supplied by subclasses.""" raise NotImplementedError() - def status(self): + def status(self) -> str: raise NotImplementedError() -class _ConnectionRecord: +class ManagesConnection: + """Common base for the two connection-management interfaces + :class:`.PoolProxiedConnection` and :class:`.ConnectionPoolEntry`. - """Internal object which maintains an individual DBAPI connection - referenced by a :class:`_pool.Pool`. + These two objects are typically exposed in the public facing API + via the connection pool event hooks, documented at :class:`.PoolEvents`. - The :class:`._ConnectionRecord` object always exists for any particular - DBAPI connection whether or not that DBAPI connection has been - "checked out". This is in contrast to the :class:`._ConnectionFairy` - which is only a public facade to the DBAPI connection while it is checked - out. + .. versionadded:: 2.0 - A :class:`._ConnectionRecord` may exist for a span longer than that - of a single DBAPI connection. For example, if the - :meth:`._ConnectionRecord.invalidate` - method is called, the DBAPI connection associated with this - :class:`._ConnectionRecord` - will be discarded, but the :class:`._ConnectionRecord` may be used again, - in which case a new DBAPI connection is produced when the - :class:`_pool.Pool` - next uses this record. + """ - The :class:`._ConnectionRecord` is delivered along with connection - pool events, including :meth:`_events.PoolEvents.connect` and - :meth:`_events.PoolEvents.checkout`, however :class:`._ConnectionRecord` - still - remains an internal object whose API and internals may change. + __slots__ = () + + dbapi_connection: Optional[DBAPIConnection] + """A reference to the actual DBAPI connection being tracked. + + This is a :pep:`249`-compliant object that for traditional sync-style + dialects is provided by the third-party + DBAPI implementation in use. For asyncio dialects, the implementation + is typically an adapter object provided by the SQLAlchemy dialect + itself; the underlying asyncio object is available via the + :attr:`.ManagesConnection.driver_connection` attribute. + + SQLAlchemy's interface for the DBAPI connection is based on the + :class:`.DBAPIConnection` protocol object .. seealso:: - :class:`._ConnectionFairy` + :attr:`.ManagesConnection.driver_connection` + + :ref:`faq_dbapi_connection` """ - def __init__(self, pool, connect=True): - self.__pool = pool - if connect: - self.__connect() - self.finalize_callback = deque() + @property + def driver_connection(self) -> Optional[Any]: + """The "driver level" connection object as used by the Python + DBAPI or database driver. + + For traditional :pep:`249` DBAPI implementations, this object will + be the same object as that of + :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database + driver, this will be the ultimate "connection" object used by that + driver, such as the ``asyncpg.Connection`` object which will not have + standard pep-249 methods. + + .. versionadded:: 1.4.24 - fresh = False + .. seealso:: - fairy_ref = None + :attr:`.ManagesConnection.dbapi_connection` - starttime = None + :ref:`faq_dbapi_connection` - dbapi_connection = None - """A reference to the actual DBAPI connection being tracked. + """ + raise NotImplementedError() + + @util.dynamic_property + def info(self) -> Dict[str, Any]: + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ManagesConnection` instance, allowing + user-defined data to be associated with the connection. + + The data in this dictionary is persistent for the lifespan + of the DBAPI connection itself, including across pool checkins + and checkouts. When the connection is invalidated + and replaced with a new one, this dictionary is cleared. + + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns a dictionary that is local to that + :class:`.ConnectionPoolEntry`. Therefore the + :attr:`.ManagesConnection.info` attribute will always provide a Python + dictionary. + + .. seealso:: - May be ``None`` if this :class:`._ConnectionRecord` has been marked - as invalidated; a new DBAPI connection may replace it if the owning - pool calls upon this :class:`._ConnectionRecord` to reconnect. + :attr:`.ManagesConnection.record_info` - For adapted drivers, like the Asyncio implementations, this is a - :class:`.AdaptedConnection` that adapts the driver connection - to the DBAPI protocol. - Use :attr:`._ConnectionRecord.driver_connection` to obtain the - connection objected returned by the driver. - .. versionadded:: 1.4.24 + """ + raise NotImplementedError() + + @util.dynamic_property + def record_info(self) -> Optional[Dict[str, Any]]: + """Persistent info dictionary associated with this + :class:`.ManagesConnection`. + + Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan + of this dictionary is that of the :class:`.ConnectionPoolEntry` + which owns it; therefore this dictionary will persist across + reconnects and connection invalidation for a particular entry + in the connection pool. + + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns None. Contrast to the :attr:`.ManagesConnection.info` + dictionary which is never None. + + + .. seealso:: + + :attr:`.ManagesConnection.info` + + """ + raise NotImplementedError() + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + """Mark the managed connection as invalidated. + + :param e: an exception object indicating a reason for the invalidation. + + :param soft: if True, the connection isn't closed; instead, this + connection will be recycled on next checkout. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + + """ + raise NotImplementedError() + + +class ConnectionPoolEntry(ManagesConnection): + """Interface for the object that maintains an individual database + connection on behalf of a :class:`_pool.Pool` instance. + + The :class:`.ConnectionPoolEntry` object represents the long term + maintainance of a particular connection for a pool, including expiring or + invalidating that connection to have it replaced with a new one, which will + continue to be maintained by that same :class:`.ConnectionPoolEntry` + instance. Compared to :class:`.PoolProxiedConnection`, which is the + short-term, per-checkout connection manager, this object lasts for the + lifespan of a particular "slot" within a connection pool. + + The :class:`.ConnectionPoolEntry` object is mostly visible to public-facing + API code when it is delivered to connection pool event hooks, such as + :meth:`_events.PoolEvents.connect` and :meth:`_events.PoolEvents.checkout`. + + .. versionadded:: 2.0 :class:`.ConnectionPoolEntry` provides the public + facing interface for the :class:`._ConnectionRecord` internal class. """ + __slots__ = () + @property - def driver_connection(self): - """The connection object as returned by the driver after a connect. + def in_use(self) -> bool: + """Return True the connection is currently checked out""" - For normal sync drivers that support the DBAPI protocol, this object - is the same as the one referenced by - :attr:`._ConnectionRecord.dbapi_connection`. + raise NotImplementedError() - For adapted drivers, like the Asyncio ones, this is the actual object - that was returned by the driver ``connect`` call. + def close(self) -> None: + """Close the DBAPI connection managed by this connection pool entry.""" + raise NotImplementedError() - As :attr:`._ConnectionRecord.dbapi_connection` it may be ``None`` - if this :class:`._ConnectionRecord` has been marked as invalidated. - .. versionadded:: 1.4.24 +class _ConnectionRecord(ConnectionPoolEntry): - """ + """Maintains a position in a connection pool which references a pooled + connection. + This is an internal object used by the :class:`_pool.Pool` implementation + to provide context management to a DBAPI connection maintained by + that :class:`_pool.Pool`. The public facing interface for this class + is described by the :class:`.ConnectionPoolEntry` class. See that + class for public API details. + + .. seealso:: + + :class:`.ConnectionPoolEntry` + + :class:`.PoolProxiedConnection` + + """ + + __slots__ = ( + "__pool", + "fairy_ref", + "finalize_callback", + "fresh", + "starttime", + "dbapi_connection", + "__weakref__", + "__dict__", + ) + + finalize_callback: Deque[Callable[[DBAPIConnection], None]] + fresh: bool + fairy_ref: Optional[weakref.ref[_ConnectionFairy]] + starttime: float + + def __init__(self, pool: Pool, connect: bool = True): + self.fresh = False + self.fairy_ref = None + self.starttime = 0 + self.dbapi_connection = None + + self.__pool = pool + if connect: + self.__connect() + self.finalize_callback = deque() + + dbapi_connection: Optional[DBAPIConnection] + + @property + def driver_connection(self) -> Optional[Any]: if self.dbapi_connection is None: return None else: @@ -428,72 +627,41 @@ class _ConnectionRecord: ) @property - def connection(self): - """An alias to :attr:`._ConnectionRecord.dbapi_connection`. - - This alias is deprecated, please use the new name. - - .. deprecated:: 1.4.24 - - """ + def connection(self) -> Optional[DBAPIConnection]: return self.dbapi_connection @connection.setter - def connection(self, value): + def connection(self, value: DBAPIConnection) -> None: self.dbapi_connection = value - _soft_invalidate_time = 0 + _soft_invalidate_time: float = 0 @util.memoized_property - def info(self): - """The ``.info`` dictionary associated with the DBAPI connection. - - This dictionary is shared among the :attr:`._ConnectionFairy.info` - and :attr:`_engine.Connection.info` accessors. - - .. note:: - - The lifespan of this dictionary is linked to the - DBAPI connection itself, meaning that it is **discarded** each time - the DBAPI connection is closed and/or invalidated. The - :attr:`._ConnectionRecord.record_info` dictionary remains - persistent throughout the lifespan of the - :class:`._ConnectionRecord` container. - - """ + def info(self) -> Dict[str, Any]: return {} @util.memoized_property - def record_info(self): - """An "info' dictionary associated with the connection record - itself. - - Unlike the :attr:`._ConnectionRecord.info` dictionary, which is linked - to the lifespan of the DBAPI connection, this dictionary is linked - to the lifespan of the :class:`._ConnectionRecord` container itself - and will remain persistent throughout the life of the - :class:`._ConnectionRecord`. - - .. versionadded:: 1.1 - - """ + def record_info(self) -> Optional[Dict[str, Any]]: return {} @classmethod - def checkout(cls, pool): - rec = pool._do_get() + def checkout(cls, pool: Pool) -> _ConnectionFairy: + rec = cast(_ConnectionRecord, pool._do_get()) try: dbapi_connection = rec.get_connection() except Exception as err: with util.safe_reraise(): rec._checkin_failed(err, _fairy_was_created=False) + raise + echo = pool._should_log_debug() - fairy = _ConnectionFairy(dbapi_connection, rec, echo) + fairy = _ConnectionFairy(pool, dbapi_connection, rec, echo) rec.fairy_ref = ref = weakref.ref( fairy, - lambda ref: _finalize_fairy - and _finalize_fairy(None, rec, pool, ref, echo, True), + lambda ref: _finalize_fairy(None, rec, pool, ref, echo, True) + if _finalize_fairy + else None, ) _strong_ref_connection_records[ref] = rec if echo: @@ -502,13 +670,15 @@ class _ConnectionRecord: ) return fairy - def _checkin_failed(self, err, _fairy_was_created=True): + def _checkin_failed( + self, err: Exception, _fairy_was_created: bool = True + ) -> None: self.invalidate(e=err) self.checkin( _fairy_was_created=_fairy_was_created, ) - def checkin(self, _fairy_was_created=True): + def checkin(self, _fairy_was_created: bool = True) -> None: if self.fairy_ref is None and _fairy_was_created: # _fairy_was_created is False for the initial get connection phase; # meaning there was no _ConnectionFairy and we must unconditionally @@ -524,47 +694,28 @@ class _ConnectionRecord: pool = self.__pool while self.finalize_callback: finalizer = self.finalize_callback.pop() - finalizer(connection) + if connection is not None: + finalizer(connection) if pool.dispatch.checkin: pool.dispatch.checkin(connection, self) pool._return_conn(self) @property - def in_use(self): + def in_use(self) -> bool: return self.fairy_ref is not None @property - def last_connect_time(self): + def last_connect_time(self) -> float: return self.starttime - def close(self): + def close(self) -> None: if self.dbapi_connection is not None: self.__close() - def invalidate(self, e=None, soft=False): - """Invalidate the DBAPI connection held by this - :class:`._ConnectionRecord`. - - This method is called for all connection invalidations, including - when the :meth:`._ConnectionFairy.invalidate` or - :meth:`_engine.Connection.invalidate` methods are called, - as well as when any - so-called "automatic invalidation" condition occurs. - - :param e: an exception object indicating a reason for the - invalidation. - - :param soft: if True, the connection isn't closed; instead, this - connection will be recycled on next checkout. - - .. versionadded:: 1.0.3 - - .. seealso:: - - :ref:`pool_connection_invalidation` - - """ + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: # already invalidated if self.dbapi_connection is None: return @@ -595,7 +746,7 @@ class _ConnectionRecord: self.__close() self.dbapi_connection = None - def get_connection(self): + def get_connection(self) -> DBAPIConnection: recycle = False # NOTE: the various comparisons here are assuming that measurable time @@ -610,8 +761,9 @@ class _ConnectionRecord: # within 16 milliseconds accuracy, so unit tests for connection # invalidation need a sleep of at least this long between initial start # time and invalidation for the logic below to work reliably. + if self.dbapi_connection is None: - self.info.clear() + self.info.clear() # type: ignore # our info is always present self.__connect() elif ( self.__pool._recycle > -1 @@ -639,26 +791,29 @@ class _ConnectionRecord: if recycle: self.__close() - self.info.clear() + self.info.clear() # type: ignore # our info is always present self.__connect() + + assert self.dbapi_connection is not None return self.dbapi_connection - def _is_hard_or_soft_invalidated(self): + def _is_hard_or_soft_invalidated(self) -> bool: return ( self.dbapi_connection is None or self.__pool._invalidate_time > self.starttime or (self._soft_invalidate_time > self.starttime) ) - def __close(self): + def __close(self) -> None: self.finalize_callback.clear() if self.__pool.dispatch.close: self.__pool.dispatch.close(self.dbapi_connection, self) + assert self.dbapi_connection is not None self.__pool._close_connection(self.dbapi_connection) self.dbapi_connection = None - def __connect(self): + def __connect(self) -> None: pool = self.__pool # ensure any existing connection is removed, so that if @@ -688,14 +843,16 @@ class _ConnectionRecord: def _finalize_fairy( - dbapi_connection, - connection_record, - pool, - ref, # this is None when called directly, not by the gc - echo, - reset=True, - fairy=None, -): + dbapi_connection: Optional[DBAPIConnection], + connection_record: Optional[_ConnectionRecord], + pool: Pool, + ref: Optional[ + weakref.ref[_ConnectionFairy] + ], # this is None when called directly, not by the gc + echo: Optional[log._EchoFlagType], + reset: bool = True, + fairy: Optional[_ConnectionFairy] = None, +) -> None: """Cleanup for a :class:`._ConnectionFairy` whether or not it's already been garbage collected. @@ -705,12 +862,16 @@ def _finalize_fairy( will only log a message and raise a warning. """ - if ref: + is_gc_cleanup = ref is not None + + if is_gc_cleanup: + assert ref is not None _strong_ref_connection_records.pop(ref, None) elif fairy: _strong_ref_connection_records.pop(weakref.ref(fairy), None) - if ref is not None: + if is_gc_cleanup: + assert connection_record is not None if connection_record.fairy_ref is not ref: return assert dbapi_connection is None @@ -720,10 +881,10 @@ def _finalize_fairy( dont_restore_gced = pool._dialect.is_async if dont_restore_gced: - detach = not connection_record or ref - can_manipulate_connection = not ref + detach = connection_record is None or is_gc_cleanup + can_manipulate_connection = ref is None else: - detach = not connection_record + detach = connection_record is None can_manipulate_connection = True if dbapi_connection is not None: @@ -737,11 +898,14 @@ def _finalize_fairy( ) try: - fairy = fairy or _ConnectionFairy( - dbapi_connection, - connection_record, - echo, - ) + if not fairy: + assert connection_record is not None + fairy = _ConnectionFairy( + pool, + dbapi_connection, + connection_record, + echo, + ) assert fairy.dbapi_connection is dbapi_connection if reset and can_manipulate_connection: fairy._reset(pool) @@ -786,6 +950,7 @@ def _finalize_fairy( # test/engine/test_pool.py::PoolEventsTest::test_checkin_event_gc[True] # which actually started failing when pytest warnings plugin was # turned on, due to util.warn() above + fairy.dbapi_connection = fairy._connection_record = None # type: ignore del dbapi_connection del connection_record del fairy @@ -795,53 +960,36 @@ def _finalize_fairy( # GC under pypy will call ConnectionFairy finalizers. linked directly to the # weakref that will empty itself when collected so that it should not create # any unmanaged memory references. -_strong_ref_connection_records = {} +_strong_ref_connection_records: Dict[ + weakref.ref[_ConnectionFairy], _ConnectionRecord +] = {} -class PoolProxiedConnection: - """interface for the wrapper connection that is used by the connection - pool. +class PoolProxiedConnection(ManagesConnection): + """A connection-like adapter for a :pep:`249` DBAPI connection, which + includes additional methods specific to the :class:`.Pool` implementation. - :class:`.PoolProxiedConnection` is basically the public-facing interface - for the :class:`._ConnectionFairy` implementation object, users familiar - with :class:`._ConnectionFairy` can consider this object to be - equivalent. + :class:`.PoolProxiedConnection` is the public-facing interface for the + internal :class:`._ConnectionFairy` implementation object; users familiar + with :class:`._ConnectionFairy` can consider this object to be equivalent. - .. versionadded:: 2.0 + .. versionadded:: 2.0 :class:`.PoolProxiedConnection` provides the public- + facing interface for the :class:`._ConnectionFairy` internal class. """ __slots__ = () - @util.memoized_property - def dbapi_connection(self) -> "DBAPIConnection": - """A reference to the actual DBAPI connection being tracked. + if typing.TYPE_CHECKING: - .. seealso:: + def commit(self) -> None: + ... - :attr:`.PoolProxiedConnection.driver_connection` + def cursor(self) -> DBAPICursor: + ... - :attr:`.PoolProxiedConnection.dbapi_connection` - - :ref:`faq_dbapi_connection` - - """ - raise NotImplementedError() - - @property - def driver_connection(self) -> Any: - """The connection object as returned by the driver after a connect. - - .. seealso:: - - :attr:`.PoolProxiedConnection.dbapi_connection` - - :attr:`._ConnectionRecord.driver_connection` - - :ref:`faq_dbapi_connection` - - """ - raise NotImplementedError() + def rollback(self) -> None: + ... @property def is_valid(self) -> bool: @@ -850,62 +998,11 @@ class PoolProxiedConnection: raise NotImplementedError() - @util.memoized_property - def info(self) -> Dict[str, Any]: - """Info dictionary associated with the underlying DBAPI connection - referred to by this :class:`.ConnectionFairy`, allowing user-defined - data to be associated with the connection. - - The data here will follow along with the DBAPI connection including - after it is returned to the connection pool and used again - in subsequent instances of :class:`._ConnectionFairy`. It is shared - with the :attr:`._ConnectionRecord.info` and - :attr:`_engine.Connection.info` - accessors. - - The dictionary associated with a particular DBAPI connection is - discarded when the connection itself is discarded. - - """ - - raise NotImplementedError() - @property - def record_info(self) -> Dict[str, Any]: - """Info dictionary associated with the :class:`._ConnectionRecord - container referred to by this :class:`.PoolProxiedConnection`. - - Unlike the :attr:`.PoolProxiedConnection.info` dictionary, the lifespan - of this dictionary is persistent across connections that are - disconnected and/or invalidated within the lifespan of a - :class:`._ConnectionRecord`. - - """ - - raise NotImplementedError() + def is_detached(self) -> bool: + """Return True if this :class:`.PoolProxiedConnection` is detached + from its pool.""" - def invalidate( - self, e: Optional[Exception] = None, soft: bool = False - ) -> None: - """Mark this connection as invalidated. - - This method can be called directly, and is also called as a result - of the :meth:`_engine.Connection.invalidate` method. When invoked, - the DBAPI connection is immediately closed and discarded from - further use by the pool. The invalidation mechanism proceeds - via the :meth:`._ConnectionRecord.invalidate` internal method. - - :param e: an exception object indicating a reason for the invalidation. - - :param soft: if True, the connection isn't closed; instead, this - connection will be recycled on next checkout. - - .. seealso:: - - :ref:`pool_connection_invalidation` - - - """ raise NotImplementedError() def detach(self) -> None: @@ -913,8 +1010,8 @@ class PoolProxiedConnection: This means that the connection will no longer be returned to the pool when closed, and will instead be literally closed. The - containing ConnectionRecord is separated from the DB-API connection, - and will create a new connection when next used. + associated :class:`.ConnectionPoolEntry` is de-associated from this + DBAPI connection. Note that any overall connection limiting constraints imposed by a Pool implementation may be violated after a detach, as the detached @@ -953,43 +1050,37 @@ class _AdhocProxiedConnection(PoolProxiedConnection): __slots__ = ("dbapi_connection", "_connection_record") - def __init__(self, dbapi_connection, connection_record): + dbapi_connection: DBAPIConnection + _connection_record: ConnectionPoolEntry + + def __init__( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ): self.dbapi_connection = dbapi_connection self._connection_record = connection_record @property - def driver_connection(self): + def driver_connection(self) -> Any: return self._connection_record.driver_connection @property - def connection(self): - """An alias to :attr:`._ConnectionFairy.dbapi_connection`. - - This alias is deprecated, please use the new name. - - .. deprecated:: 1.4.24 - - """ - return self._dbapi_connection + def connection(self) -> DBAPIConnection: + return self.dbapi_connection @property - def is_valid(self): + def is_valid(self) -> bool: raise AttributeError("is_valid not implemented by this proxy") - @property - def record_info(self): + @util.dynamic_property + def record_info(self) -> Optional[Dict[str, Any]]: return self._connection_record.record_info - def cursor(self, *args, **kwargs): - """Return a new DBAPI cursor for the underlying connection. - - This method is a proxy for the ``connection.cursor()`` DBAPI - method. - - """ + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: return self.dbapi_connection.cursor(*args, **kwargs) - def __getattr__(self, key): + def __getattr__(self, key: Any) -> Any: return getattr(self.dbapi_connection, key) @@ -1001,7 +1092,8 @@ class _ConnectionFairy(PoolProxiedConnection): This is an internal object used by the :class:`_pool.Pool` implementation to provide context management to a DBAPI connection delivered by that :class:`_pool.Pool`. The public facing interface for this class - is described by the :class:`.PoolProxiedConnection` class. + is described by the :class:`.PoolProxiedConnection` class. See that + class for public API details. The name "fairy" is inspired by the fact that the :class:`._ConnectionFairy` object's lifespan is transitory, as it lasts @@ -1011,68 +1103,76 @@ class _ConnectionFairy(PoolProxiedConnection): .. seealso:: - :class:`._ConnectionRecord` - - """ + :class:`.PoolProxiedConnection` - def __init__(self, dbapi_connection, connection_record, echo): - self.dbapi_connection = dbapi_connection - self._connection_record = connection_record - self._echo = echo + :class:`.ConnectionPoolEntry` - _connection_record = None - """A reference to the :class:`._ConnectionRecord` object associated - with the DBAPI connection. - - This is currently an internal accessor which is subject to change. """ - @property - def driver_connection(self): - """The connection object as returned by the driver after a connect. + __slots__ = ( + "dbapi_connection", + "_connection_record", + "_echo", + "_pool", + "_counter", + "__weakref__", + "__dict__", + ) - .. versionadded:: 1.4.24 - - .. seealso:: + pool: Pool + dbapi_connection: DBAPIConnection + _echo: log._EchoFlagType - :attr:`._ConnectionFairy.dbapi_connection` - - :attr:`._ConnectionRecord.driver_connection` + def __init__( + self, + pool: Pool, + dbapi_connection: DBAPIConnection, + connection_record: _ConnectionRecord, + echo: log._EchoFlagType, + ): + self._pool = pool + self._counter = 0 + self.dbapi_connection = dbapi_connection + self._connection_record = connection_record + self._echo = echo - :ref:`faq_dbapi_connection` + _connection_record: Optional[_ConnectionRecord] - """ + @property + def driver_connection(self) -> Optional[Any]: + if self._connection_record is None: + return None return self._connection_record.driver_connection @property - def connection(self): - """An alias to :attr:`._ConnectionFairy.dbapi_connection`. - - This alias is deprecated, please use the new name. - - .. deprecated:: 1.4.24 - - """ + def connection(self) -> DBAPIConnection: return self.dbapi_connection @connection.setter - def connection(self, value): + def connection(self, value: DBAPIConnection) -> None: self.dbapi_connection = value @classmethod - def _checkout(cls, pool, threadconns=None, fairy=None): + def _checkout( + cls, + pool: Pool, + threadconns: Optional[threading.local] = None, + fairy: Optional[_ConnectionFairy] = None, + ) -> _ConnectionFairy: if not fairy: fairy = _ConnectionRecord.checkout(pool) - fairy._pool = pool - fairy._counter = 0 - if threadconns is not None: threadconns.current = weakref.ref(fairy) - if fairy.dbapi_connection is None: - raise exc.InvalidRequestError("This connection is closed") + assert ( + fairy._connection_record is not None + ), "can't 'checkout' a detached connection fairy" + assert ( + fairy.dbapi_connection is not None + ), "can't 'checkout' an invalidated connection fairy" + fairy._counter += 1 if ( not pool.dispatch.checkout and not pool._pre_ping @@ -1084,6 +1184,7 @@ class _ConnectionFairy(PoolProxiedConnection): # there are three attempts made here, but note that if the database # is not accessible from a connection standpoint, those won't proceed # here. + attempts = 2 while attempts > 0: connection_is_fresh = fairy._connection_record.fresh @@ -1160,10 +1261,10 @@ class _ConnectionFairy(PoolProxiedConnection): fairy.invalidate() raise exc.InvalidRequestError("This connection is closed") - def _checkout_existing(self): + def _checkout_existing(self) -> _ConnectionFairy: return _ConnectionFairy._checkout(self._pool, fairy=self) - def _checkin(self, reset=True): + def _checkin(self, reset: bool = True) -> None: _finalize_fairy( self.dbapi_connection, self._connection_record, @@ -1173,14 +1274,13 @@ class _ConnectionFairy(PoolProxiedConnection): reset=reset, fairy=self, ) - self.dbapi_connection = None - self._connection_record = None - _close = _checkin + def _close(self) -> None: + self._checkin() - def _reset(self, pool): + def _reset(self, pool: Pool) -> None: if pool.dispatch.reset: - pool.dispatch.reset(self, self._connection_record) + pool.dispatch.reset(self.dbapi_connection, self._connection_record) if pool._reset_on_return is reset_rollback: if self._echo: pool.logger.debug( @@ -1196,50 +1296,34 @@ class _ConnectionFairy(PoolProxiedConnection): pool._dialect.do_commit(self) @property - def _logger(self): + def _logger(self) -> log._IdentifiedLoggerType: return self._pool.logger @property - def is_valid(self): - """Return True if this :class:`._ConnectionFairy` still refers - to an active DBAPI connection.""" - + def is_valid(self) -> bool: return self.dbapi_connection is not None - @util.memoized_property - def info(self): - """Info dictionary associated with the underlying DBAPI connection - referred to by this :class:`.ConnectionFairy`, allowing user-defined - data to be associated with the connection. - - See :attr:`.PoolProxiedConnection.info` for full description. - - """ - return self._connection_record.info - @property - def record_info(self): - """Info dictionary associated with the :class:`._ConnectionRecord - container referred to by this :class:`.ConnectionFairy`. + def is_detached(self) -> bool: + return self._connection_record is not None - See :attr:`.PoolProxiedConnection.record_info` for full description. - - """ - if self._connection_record: - return self._connection_record.record_info + @util.memoized_property + def info(self) -> Dict[str, Any]: + if self._connection_record is None: + return {} else: - return None - - def invalidate(self, e=None, soft=False): - """Mark this connection as invalidated. - - See :attr:`.PoolProxiedConnection.invalidate` for full description. - - .. seealso:: + return self._connection_record.info - :ref:`pool_connection_invalidation` + @util.dynamic_property + def record_info(self) -> Optional[Dict[str, Any]]: + if self._connection_record is None: + return None + else: + return self._connection_record.record_info - """ + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: if self.dbapi_connection is None: util.warn("Can't invalidate an already-closed connection.") @@ -1247,51 +1331,43 @@ class _ConnectionFairy(PoolProxiedConnection): if self._connection_record: self._connection_record.invalidate(e=e, soft=soft) if not soft: - self.dbapi_connection = None - self._checkin() - - def cursor(self, *args, **kwargs): - """Return a new DBAPI cursor for the underlying connection. + # prevent any rollback / reset actions etc. on + # the connection + self.dbapi_connection = None # type: ignore - This method is a proxy for the ``connection.cursor()`` DBAPI - method. + # finalize + self._checkin() - """ + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: + assert self.dbapi_connection is not None return self.dbapi_connection.cursor(*args, **kwargs) - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: return getattr(self.dbapi_connection, key) - def detach(self): - """Separate this connection from its Pool. - - See :meth:`.PoolProxiedConnection.detach` for full description. - - """ - + def detach(self) -> None: if self._connection_record is not None: rec = self._connection_record rec.fairy_ref = None rec.dbapi_connection = None # TODO: should this be _return_conn? self._pool._do_return_conn(self._connection_record) - self.info = self.info.copy() + + # can't get the descriptor assignment to work here + # in pylance. mypy is OK w/ it + self.info = self.info.copy() # type: ignore + self._connection_record = None if self._pool.dispatch.detach: self._pool.dispatch.detach(self.dbapi_connection, rec) - def close(self): - """Release this connection back to the pool. - - See :meth:`.PoolProxiedConnection.close` for full description. - - """ + def close(self) -> None: self._counter -= 1 if self._counter == 0: self._checkin() - def _close_no_reset(self): + def _close_no_reset(self) -> None: self._counter -= 1 if self._counter == 0: self._checkin(reset=False) diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py index e53d614b0e..d0d89291bc 100644 --- a/lib/sqlalchemy/pool/events.py +++ b/lib/sqlalchemy/pool/events.py @@ -4,13 +4,26 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations +import typing +from typing import Any +from typing import Optional +from typing import Type +from typing import Union + +from .base import ConnectionPoolEntry from .base import Pool +from .base import PoolProxiedConnection from .. import event from .. import util +if typing.TYPE_CHECKING: + from ..engine import Engine + from ..engine.interfaces import DBAPIConnection + -class PoolEvents(event.Events): +class PoolEvents(event.Events[Pool]): """Available events for :class:`_pool.Pool`. The methods here define the name of an event as well @@ -37,35 +50,48 @@ class PoolEvents(event.Events): # will associate with engine.pool event.listen(engine, 'checkout', my_on_checkout) - """ # noqa + """ # noqa E501 _target_class_doc = "SomeEngineOrPool" _dispatch_target = Pool @util.preload_module("sqlalchemy.engine") @classmethod - def _accept_with(cls, target): - Engine = util.preloaded.engine.Engine + def _accept_with( + cls, target: Union[Pool, Type[Pool], Engine, Type[Engine]] + ) -> Union[Pool, Type[Pool]]: + if not typing.TYPE_CHECKING: + Engine = util.preloaded.engine.Engine if isinstance(target, type): if issubclass(target, Engine): return Pool - elif issubclass(target, Pool): + else: + assert issubclass(target, Pool) return target elif isinstance(target, Engine): return target.pool else: + assert isinstance(target, Pool) return target @classmethod - def _listen(cls, event_key, **kw): + def _listen( # type: ignore[override] # would rather keep **kw + cls, + event_key: event._EventKey[Pool], + **kw: Any, + ) -> None: target = event_key.dispatch_target kw.setdefault("asyncio", target._is_asyncio) event_key.base_listen(**kw) - def connect(self, dbapi_connection, connection_record): + def connect( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: """Called at the moment a particular DBAPI connection is first created for a given :class:`_pool.Pool`. @@ -74,14 +100,18 @@ class PoolEvents(event.Events): to produce a new DBAPI connection. :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. """ - def first_connect(self, dbapi_connection, connection_record): + def first_connect( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: """Called exactly once for the first time a DBAPI connection is checked out from a particular :class:`_pool.Pool`. @@ -99,24 +129,29 @@ class PoolEvents(event.Events): encoding settings, collation settings, and many others. :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. """ - def checkout(self, dbapi_connection, connection_record, connection_proxy): + def checkout( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + connection_proxy: PoolProxiedConnection, + ) -> None: """Called when a connection is retrieved from the Pool. :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. - :param connection_proxy: the :class:`._ConnectionFairy` object which - will proxy the public interface of the DBAPI connection for the + :param connection_proxy: the :class:`.PoolProxiedConnection` object + which will proxy the public interface of the DBAPI connection for the lifespan of the checkout. If you raise a :class:`~sqlalchemy.exc.DisconnectionError`, the current @@ -130,7 +165,11 @@ class PoolEvents(event.Events): """ - def checkin(self, dbapi_connection, connection_record): + def checkin( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: """Called when a connection returns to the pool. Note that the connection may be closed, and may be None if the @@ -138,14 +177,18 @@ class PoolEvents(event.Events): for detached connections. (They do not return to the pool.) :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. """ - def reset(self, dbapi_connection, connection_record): + def reset( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: """Called before the "reset" action occurs for a pooled connection. This event represents @@ -160,10 +203,10 @@ class PoolEvents(event.Events): cases where the connection is discarded immediately after reset. :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. .. seealso:: @@ -173,21 +216,26 @@ class PoolEvents(event.Events): """ - def invalidate(self, dbapi_connection, connection_record, exception): + def invalidate( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + exception: Optional[BaseException], + ) -> None: """Called when a DBAPI connection is to be "invalidated". - This event is called any time the :meth:`._ConnectionRecord.invalidate` - method is invoked, either from API usage or via "auto-invalidation", - without the ``soft`` flag. + This event is called any time the + :meth:`.ConnectionPoolEntry.invalidate` method is invoked, either from + API usage or via "auto-invalidation", without the ``soft`` flag. The event occurs before a final attempt to call ``.close()`` on the connection occurs. :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. :param exception: the exception object corresponding to the reason for this invalidation, if any. May be ``None``. @@ -201,10 +249,16 @@ class PoolEvents(event.Events): """ - def soft_invalidate(self, dbapi_connection, connection_record, exception): + def soft_invalidate( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + exception: Optional[BaseException], + ) -> None: """Called when a DBAPI connection is to be "soft invalidated". - This event is called any time the :meth:`._ConnectionRecord.invalidate` + This event is called any time the + :meth:`.ConnectionPoolEntry.invalidate` method is invoked with the ``soft`` flag. Soft invalidation refers to when the connection record that tracks @@ -215,17 +269,21 @@ class PoolEvents(event.Events): .. versionadded:: 1.0.3 :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. :param exception: the exception object corresponding to the reason for this invalidation, if any. May be ``None``. """ - def close(self, dbapi_connection, connection_record): + def close( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: """Called when a DBAPI connection is closed. The event is emitted before the close occurs. @@ -241,14 +299,18 @@ class PoolEvents(event.Events): .. versionadded:: 1.1 :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. """ - def detach(self, dbapi_connection, connection_record): + def detach( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: """Called when a DBAPI connection is "detached" from a pool. This event is emitted after the detach occurs. The connection @@ -257,14 +319,14 @@ class PoolEvents(event.Events): .. versionadded:: 1.1 :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. - :param connection_record: the :class:`._ConnectionRecord` managing the - DBAPI connection. + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. """ - def close_detached(self, dbapi_connection): + def close_detached(self, dbapi_connection: DBAPIConnection) -> None: """Called when a detached DBAPI connection is closed. The event is emitted before the close occurs. @@ -276,6 +338,6 @@ class PoolEvents(event.Events): .. versionadded:: 1.1 :param dbapi_connection: a DBAPI connection. - The :attr:`._ConnectionRecord.dbapi_connection` attribute. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. """ diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 7a422cd2ac..d1be3f5419 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -9,19 +9,36 @@ """Pool implementation classes. """ +from __future__ import annotations import threading import traceback +import typing +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import Union import weakref from .base import _AsyncConnDialect from .base import _ConnectionFairy from .base import _ConnectionRecord +from .base import _CreatorFnType +from .base import _CreatorWRecFnType +from .base import ConnectionPoolEntry from .base import Pool +from .base import PoolProxiedConnection from .. import exc from .. import util from ..util import chop_traceback from ..util import queue as sqla_queue +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ..engine.interfaces import DBAPIConnection class QueuePool(Pool): @@ -34,17 +51,22 @@ class QueuePool(Pool): """ - _is_asyncio = False - _queue_class = sqla_queue.Queue + _is_asyncio = False # type: ignore[assignment] + + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.Queue + + _pool: sqla_queue.QueueCommon[ConnectionPoolEntry] def __init__( self, - creator, - pool_size=5, - max_overflow=10, - timeout=30.0, - use_lifo=False, - **kw, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + max_overflow: int = 10, + timeout: float = 30.0, + use_lifo: bool = False, + **kw: Any, ): r""" Construct a QueuePool. @@ -107,20 +129,20 @@ class QueuePool(Pool): self._timeout = timeout self._overflow_lock = threading.Lock() - def _do_return_conn(self, conn): + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: try: - self._pool.put(conn, False) + self._pool.put(record, False) except sqla_queue.Full: try: - conn.close() + record.close() finally: self._dec_overflow() - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: use_overflow = self._max_overflow > -1 + wait = use_overflow and self._overflow >= self._max_overflow try: - wait = use_overflow and self._overflow >= self._max_overflow return self._pool.get(wait, self._timeout) except sqla_queue.Empty: # don't do things inside of "except Empty", because when we say @@ -144,10 +166,11 @@ class QueuePool(Pool): except: with util.safe_reraise(): self._dec_overflow() + raise else: return self._do_get() - def _inc_overflow(self): + def _inc_overflow(self) -> bool: if self._max_overflow == -1: self._overflow += 1 return True @@ -158,7 +181,7 @@ class QueuePool(Pool): else: return False - def _dec_overflow(self): + def _dec_overflow(self) -> Literal[True]: if self._max_overflow == -1: self._overflow -= 1 return True @@ -166,7 +189,7 @@ class QueuePool(Pool): self._overflow -= 1 return True - def recreate(self): + def recreate(self) -> QueuePool: self.logger.info("Pool recreating") return self.__class__( self._creator, @@ -183,7 +206,7 @@ class QueuePool(Pool): dialect=self._dialect, ) - def dispose(self): + def dispose(self) -> None: while True: try: conn = self._pool.get(False) @@ -194,7 +217,7 @@ class QueuePool(Pool): self._overflow = 0 - self.size() self.logger.info("Pool disposed. %s", self.status()) - def status(self): + def status(self) -> str: return ( "Pool size: %d Connections in pool: %d " "Current Overflow: %d Current Checked out " @@ -207,25 +230,28 @@ class QueuePool(Pool): ) ) - def size(self): + def size(self) -> int: return self._pool.maxsize - def timeout(self): + def timeout(self) -> float: return self._timeout - def checkedin(self): + def checkedin(self) -> int: return self._pool.qsize() - def overflow(self): + def overflow(self) -> int: return self._overflow - def checkedout(self): + def checkedout(self) -> int: return self._pool.maxsize - self._pool.qsize() + self._overflow class AsyncAdaptedQueuePool(QueuePool): - _is_asyncio = True - _queue_class = sqla_queue.AsyncAdaptedQueue + _is_asyncio = True # type: ignore[assignment] + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.AsyncAdaptedQueue + _dialect = _AsyncConnDialect() @@ -246,16 +272,16 @@ class NullPool(Pool): """ - def status(self): + def status(self) -> str: return "NullPool" - def _do_return_conn(self, conn): - conn.close() + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + record.close() - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: return self._create_connection() - def recreate(self): + def recreate(self) -> NullPool: self.logger.info("Pool recreating") return self.__class__( @@ -269,7 +295,7 @@ class NullPool(Pool): dialect=self._dialect, ) - def dispose(self): + def dispose(self) -> None: pass @@ -304,16 +330,21 @@ class SingletonThreadPool(Pool): """ - _is_asyncio = False + _is_asyncio = False # type: ignore[assignment] - def __init__(self, creator, pool_size=5, **kw): + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + **kw: Any, + ): Pool.__init__(self, creator, **kw) self._conn = threading.local() self._fairy = threading.local() - self._all_conns = set() + self._all_conns: Set[ConnectionPoolEntry] = set() self.size = pool_size - def recreate(self): + def recreate(self) -> SingletonThreadPool: self.logger.info("Pool recreating") return self.__class__( self._creator, @@ -327,7 +358,7 @@ class SingletonThreadPool(Pool): dialect=self._dialect, ) - def dispose(self): + def dispose(self) -> None: """Dispose of this pool.""" for conn in self._all_conns: @@ -340,23 +371,26 @@ class SingletonThreadPool(Pool): self._all_conns.clear() - def _cleanup(self): + def _cleanup(self) -> None: while len(self._all_conns) >= self.size: c = self._all_conns.pop() c.close() - def status(self): + def status(self) -> str: return "SingletonThreadPool id:%d size: %d" % ( id(self), len(self._all_conns), ) - def _do_return_conn(self, conn): - pass + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + try: + del self._fairy.current # type: ignore + except AttributeError: + pass - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: try: - c = self._conn.current() + c = cast(ConnectionPoolEntry, self._conn.current()) if c: return c except AttributeError: @@ -368,11 +402,11 @@ class SingletonThreadPool(Pool): self._all_conns.add(c) return c - def connect(self): + def connect(self) -> PoolProxiedConnection: # vendored from Pool to include the now removed use_threadlocal # behavior try: - rec = self._fairy.current() + rec = cast(_ConnectionFairy, self._fairy.current()) except AttributeError: pass else: @@ -381,13 +415,6 @@ class SingletonThreadPool(Pool): return _ConnectionFairy._checkout(self, self._fairy) - def _return_conn(self, record): - try: - del self._fairy.current - except AttributeError: - pass - self._do_return_conn(record) - class StaticPool(Pool): @@ -401,13 +428,13 @@ class StaticPool(Pool): """ @util.memoized_property - def connection(self): + def connection(self) -> _ConnectionRecord: return _ConnectionRecord(self) - def status(self): + def status(self) -> str: return "StaticPool" - def dispose(self): + def dispose(self) -> None: if ( "connection" in self.__dict__ and self.connection.dbapi_connection is not None @@ -415,7 +442,7 @@ class StaticPool(Pool): self.connection.close() del self.__dict__["connection"] - def recreate(self): + def recreate(self) -> StaticPool: self.logger.info("Pool recreating") return self.__class__( creator=self._creator, @@ -428,20 +455,23 @@ class StaticPool(Pool): dialect=self._dialect, ) - def _transfer_from(self, other_static_pool): + def _transfer_from(self, other_static_pool: StaticPool) -> None: # used by the test suite to make a new engine / pool without # losing the state of an existing SQLite :memory: connection - self._invoke_creator = ( - lambda crec: other_static_pool.connection.dbapi_connection - ) + def creator(rec: ConnectionPoolEntry) -> DBAPIConnection: + conn = other_static_pool.connection.dbapi_connection + assert conn is not None + return conn - def _create_connection(self): + self._invoke_creator = creator + + def _create_connection(self) -> ConnectionPoolEntry: raise NotImplementedError() - def _do_return_conn(self, conn): + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: pass - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: rec = self.connection if rec._is_hard_or_soft_invalidated(): del self.__dict__["connection"] @@ -461,28 +491,31 @@ class AssertionPool(Pool): """ - def __init__(self, *args, **kw): + _conn: Optional[ConnectionPoolEntry] + _checkout_traceback: Optional[List[str]] + + def __init__(self, *args: Any, **kw: Any): self._conn = None self._checked_out = False self._store_traceback = kw.pop("store_traceback", True) self._checkout_traceback = None Pool.__init__(self, *args, **kw) - def status(self): + def status(self) -> str: return "AssertionPool" - def _do_return_conn(self, conn): + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: if not self._checked_out: raise AssertionError("connection is not checked out") self._checked_out = False - assert conn is self._conn + assert record is self._conn - def dispose(self): + def dispose(self) -> None: self._checked_out = False if self._conn: self._conn.close() - def recreate(self): + def recreate(self) -> AssertionPool: self.logger.info("Pool recreating") return self.__class__( self._creator, @@ -495,7 +528,7 @@ class AssertionPool(Pool): dialect=self._dialect, ) - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: if self._checked_out: if self._checkout_traceback: suffix = " at:\n%s" % "".join( diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 85bbca20f5..a414205045 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -94,6 +94,7 @@ from .langhelpers import decode_slice as decode_slice from .langhelpers import decorator as decorator from .langhelpers import dictlike_iteritems as dictlike_iteritems from .langhelpers import duck_type_collection as duck_type_collection +from .langhelpers import dynamic_property as dynamic_property from .langhelpers import ellipses_string as ellipses_string from .langhelpers import EnsureKWArg as EnsureKWArg from .langhelpers import format_argspec_init as format_argspec_init @@ -122,6 +123,9 @@ from .langhelpers import ( ) from .langhelpers import NoneType as NoneType from .langhelpers import only_once as only_once +from .langhelpers import ( + parse_user_argument_for_enum as parse_user_argument_for_enum, +) from .langhelpers import PluginLoader as PluginLoader from .langhelpers import portable_instancemethod as portable_instancemethod from .langhelpers import quoted_token_parser as quoted_token_parser diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index fa20667027..b17b408dd7 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -10,13 +10,37 @@ from contextvars import copy_context as _copy_context import sys import typing from typing import Any +from typing import Awaitable from typing import Callable from typing import Coroutine - -import greenlet # type: ignore # noqa +from typing import TypeVar from .langhelpers import memoized_property from .. import exc +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + + class greenlet(Protocol): + + dead: bool + + def __init__(self, fn: Callable[..., Any], driver: "greenlet"): + ... + + def throw(self, *arg: Any) -> Any: + ... + + def switch(self, value: Any) -> Any: + ... + + def getcurrent() -> greenlet: + ... + +else: + from greenlet import getcurrent + from greenlet import greenlet + if not typing.TYPE_CHECKING: try: @@ -24,12 +48,14 @@ if not typing.TYPE_CHECKING: # If greenlet.gr_context is present in current version of greenlet, # it will be set with a copy of the current context on creation. # Refs: https://github.com/python-greenlet/greenlet/pull/198 - getattr(greenlet.greenlet, "gr_context") + getattr(greenlet, "gr_context") except (ImportError, AttributeError): _copy_context = None # noqa +_T = TypeVar("_T", bound=Any) -def is_exit_exception(e): + +def is_exit_exception(e: BaseException) -> bool: # note asyncio.CancelledError is already BaseException # so was an exit exception in any case return not isinstance(e, Exception) or isinstance( @@ -42,15 +68,17 @@ def is_exit_exception(e): # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 -class _AsyncIoGreenlet(greenlet.greenlet): # type: ignore - def __init__(self, fn, driver): - greenlet.greenlet.__init__(self, fn, driver) +class _AsyncIoGreenlet(greenlet): # type: ignore + dead: bool + + def __init__(self, fn: Callable[..., Any], driver: greenlet): + greenlet.__init__(self, fn, driver) self.driver = driver if _copy_context is not None: self.gr_context = _copy_context() -def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any: +def await_only(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. @@ -60,7 +88,7 @@ def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any: """ # this is called in the context greenlet while running fn - current = greenlet.getcurrent() + current = getcurrent() if not isinstance(current, _AsyncIoGreenlet): raise exc.MissingGreenlet( "greenlet_spawn has not been called; can't call await_() here. " @@ -71,10 +99,10 @@ def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any: # a coroutine to run. Once the awaitable is done, the driver greenlet # switches back to this greenlet with the result of awaitable that is # then returned to the caller (or raised as error) - return current.driver.switch(awaitable) + return current.driver.switch(awaitable) # type: ignore[no-any-return] -def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any: +def await_fallback(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. @@ -83,8 +111,9 @@ def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any: :param awaitable: The coroutine to call. """ + # this is called in the context greenlet while running fn - current = greenlet.getcurrent() + current = getcurrent() if not isinstance(current, _AsyncIoGreenlet): loop = get_event_loop() if loop.is_running(): @@ -93,9 +122,9 @@ def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any: "loop is already running; can't call await_() here. " "Was IO attempted in an unexpected place?" ) - return loop.run_until_complete(awaitable) + return loop.run_until_complete(awaitable) # type: ignore[no-any-return] # noqa E501 - return current.driver.switch(awaitable) + return current.driver.switch(awaitable) # type: ignore[no-any-return] async def greenlet_spawn( @@ -114,7 +143,7 @@ async def greenlet_spawn( :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable. """ - context = _AsyncIoGreenlet(fn, greenlet.getcurrent()) + context = _AsyncIoGreenlet(fn, getcurrent()) # runs the function synchronously in gl greenlet. If the execution # is interrupted by await_, context is not dead and result is a # coroutine to wait. If the context is dead the function has @@ -149,21 +178,23 @@ async def greenlet_spawn( class AsyncAdaptedLock: @memoized_property - def mutex(self): + def mutex(self) -> asyncio.Lock: # there should not be a race here for coroutines creating the # new lock as we are not using await, so therefore no concurrency return asyncio.Lock() - def __enter__(self): + def __enter__(self) -> bool: # await is used to acquire the lock only after the first calling # coroutine has created the mutex. return await_fallback(self.mutex.acquire()) - def __exit__(self, *arg, **kw): + def __exit__(self, *arg: Any, **kw: Any) -> None: self.mutex.release() -def _util_async_run_coroutine_function(fn, *args, **kwargs): +def _util_async_run_coroutine_function( + fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any +) -> Any: """for test suite/ util only""" loop = get_event_loop() @@ -175,7 +206,10 @@ def _util_async_run_coroutine_function(fn, *args, **kwargs): return loop.run_until_complete(fn(*args, **kwargs)) -def _util_async_run(fn, *args, **kwargs): +def _util_async_run( + fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any +) -> Any: + """for test suite/ util only""" loop = get_event_loop() @@ -183,11 +217,11 @@ def _util_async_run(fn, *args, **kwargs): return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs)) else: # allow for a wrapped test function to call another - assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet) + assert isinstance(getcurrent(), _AsyncIoGreenlet) return fn(*args, **kwargs) -def get_event_loop(): +def get_event_loop() -> asyncio.AbstractEventLoop: """vendor asyncio.get_event_loop() for python 3.7 and above. Python 3.10 deprecates get_event_loop() as a standalone. diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index f91d902dae..7e1d3213ab 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -14,23 +14,32 @@ import re from typing import Any from typing import Callable from typing import cast +from typing import Dict +from typing import Match from typing import Optional +from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TypeVar +from typing import Union from . import compat from .langhelpers import _hash_limit_string from .langhelpers import _warnings_warn from .langhelpers import decorator +from .langhelpers import dynamic_property from .langhelpers import inject_docstring_text from .langhelpers import inject_param_text -from .typing import ReadOnlyInstanceDescriptor from .. import exc _T = TypeVar("_T", bound=Any) +# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators +_F = TypeVar("_F", bound=Callable[..., Any]) + + def _warn_with_version( msg: str, version: str, @@ -52,7 +61,13 @@ def warn_deprecated( ) -def warn_deprecated_limited(msg, args, version, stacklevel=3, code=None): +def warn_deprecated_limited( + msg: str, + args: Sequence[Any], + version: str, + stacklevel: int = 3, + code: Optional[str] = None, +) -> None: """Issue a deprecation warning with a parameterized string, limiting the number of registrations. @@ -64,10 +79,12 @@ def warn_deprecated_limited(msg, args, version, stacklevel=3, code=None): ) -def deprecated_cls(version, message, constructor="__init__"): +def deprecated_cls( + version: str, message: str, constructor: str = "__init__" +) -> Callable[[Type[_T]], Type[_T]]: header = ".. deprecated:: %s %s" % (version, (message or "")) - def decorate(cls): + def decorate(cls: Type[_T]) -> Type[_T]: return _decorate_cls_with_warning( cls, constructor, @@ -84,9 +101,9 @@ def deprecated_property( version: str, message: Optional[str] = None, add_deprecation_to_docstring: bool = True, - warning: Optional[str] = None, + warning: Optional[Type[exc.SADeprecationWarning]] = None, enable_warnings: bool = True, -) -> Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]]: +) -> Callable[[Callable[..., _T]], dynamic_property[_T]]: """the @deprecated decorator with a @property. E.g.:: @@ -113,9 +130,9 @@ def deprecated_property( great! now it is. """ - return cast( - Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]], - lambda fn: property( + + def decorate(fn: Callable[..., _T]) -> dynamic_property[_T]: + return dynamic_property( deprecated( version, message=message, @@ -123,17 +140,18 @@ def deprecated_property( warning=warning, enable_warnings=enable_warnings, )(fn) - ), - ) + ) + + return decorate def deprecated( - version, - message=None, - add_deprecation_to_docstring=True, - warning=None, - enable_warnings=True, -): + version: str, + message: Optional[str] = None, + add_deprecation_to_docstring: bool = True, + warning: Optional[Type[exc.SADeprecationWarning]] = None, + enable_warnings: bool = True, +) -> Callable[[_F], _F]: """Decorates a function and issues a deprecation warning on use. :param version: @@ -166,7 +184,9 @@ def deprecated( message += " (deprecated since: %s)" % version - def decorate(fn): + def decorate(fn: _F) -> _F: + assert message is not None + assert warning is not None return _decorate_with_warning( fn, warning, @@ -179,13 +199,17 @@ def deprecated( return decorate -def moved_20(message, **kw): +def moved_20( + message: str, **kw: Any +) -> Callable[[Callable[..., _T]], Callable[..., _T]]: return deprecated( "2.0", message=message, warning=exc.MovedIn20Warning, **kw ) -def became_legacy_20(api_name, alternative=None, **kw): +def became_legacy_20( + api_name: str, alternative: Optional[str] = None, **kw: Any +) -> Callable[[_F], _F]: type_reg = re.match("^:(attr|func|meth):", api_name) if type_reg: type_ = {"attr": "attribute", "func": "function", "meth": "method"}[ @@ -221,10 +245,7 @@ def became_legacy_20(api_name, alternative=None, **kw): return deprecated("2.0", message=message, warning=warning_cls, **kw) -_C = TypeVar("_C", bound=Callable[..., Any]) - - -def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: +def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]: """Decorates a function to warn on use of certain parameters. e.g. :: @@ -240,18 +261,19 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: """ - messages = {} - versions = {} - version_warnings = {} + messages: Dict[str, str] = {} + versions: Dict[str, str] = {} + version_warnings: Dict[str, Type[exc.SADeprecationWarning]] = {} for param, (version, message) in specs.items(): versions[param] = version messages[param] = _sanitize_restructured_text(message) version_warnings[param] = exc.SADeprecationWarning - def decorate(fn): + def decorate(fn: _F) -> _F: spec = compat.inspect_getfullargspec(fn) + check_defaults: Union[Set[str], Tuple[()]] if spec.defaults is not None: defaults = dict( zip( @@ -268,7 +290,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: check_any_kw = spec.varkw @decorator - def warned(fn, *args, **kwargs): + def warned(fn: _F, *args: Any, **kwargs: Any) -> _F: for m in check_defaults: if (defaults[m] is None and kwargs[m] is not None) or ( defaults[m] is not None and kwargs[m] != defaults[m] @@ -283,7 +305,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: if check_any_kw in messages and set(kwargs).difference( check_defaults ): - + assert check_any_kw is not None _warn_with_version( messages[check_any_kw], versions[check_any_kw], @@ -299,7 +321,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: version_warnings[m], stacklevel=3, ) - return fn(*args, **kwargs) + return fn(*args, **kwargs) # type: ignore[no-any-return] doc = fn.__doc__ is not None and fn.__doc__ or "" if doc: @@ -311,15 +333,15 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: for param, (version, message) in specs.items() }, ) - decorated = warned(fn) + decorated = cast(_F, warned)(fn) decorated.__doc__ = doc - return decorated + return decorated # type: ignore[no-any-return] return decorate -def _sanitize_restructured_text(text): - def repl(m): +def _sanitize_restructured_text(text: str) -> str: + def repl(m: Match[str]) -> str: type_, name = m.group(1, 2) if type_ in ("func", "meth"): name += "()" @@ -330,8 +352,13 @@ def _sanitize_restructured_text(text): def _decorate_cls_with_warning( - cls, constructor, wtype, message, version, docstring_header=None -): + cls: Type[_T], + constructor: str, + wtype: Type[exc.SADeprecationWarning], + message: str, + version: str, + docstring_header: Optional[str] = None, +) -> Type[_T]: doc = cls.__doc__ is not None and cls.__doc__ or "" if docstring_header is not None: @@ -361,6 +388,7 @@ def _decorate_cls_with_warning( if constructor is not None: assert constructor_fn is not None + assert wtype is not None setattr( cls, constructor, @@ -372,8 +400,13 @@ def _decorate_cls_with_warning( def _decorate_with_warning( - func, wtype, message, version, docstring_header=None, enable_warnings=True -): + func: _F, + wtype: Type[exc.SADeprecationWarning], + message: str, + version: str, + docstring_header: Optional[str] = None, + enable_warnings: bool = True, +) -> _F: """Wrap a function with a warnings.warn and augmented docstring.""" message = _sanitize_restructured_text(message) @@ -387,13 +420,13 @@ def _decorate_with_warning( doc_only = "" @decorator - def warned(fn, *args, **kwargs): + def warned(fn: _F, *args: Any, **kwargs: Any) -> _F: skip_warning = not enable_warnings or kwargs.pop( "_sa_skip_warning", False ) if not skip_warning: _warn_with_version(message, version, wtype, stacklevel=3) - return fn(*args, **kwargs) + return fn(*args, **kwargs) # type: ignore[no-any-return] doc = func.__doc__ is not None and func.__doc__ or "" if docstring_header is not None: @@ -403,9 +436,9 @@ def _decorate_with_warning( doc = inject_docstring_text(doc, docstring_header, 1) - decorated = warned(func) + decorated = cast(_F, warned)(func) decorated.__doc__ = doc decorated._sa_warn = lambda: _warn_with_version( message, version, wtype, stacklevel=3 ) - return decorated + return decorated # type: ignore[no-any-return] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 9e024b3c03..43f9d5c73f 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -32,6 +32,7 @@ from typing import Generic from typing import Iterator from typing import List from typing import Mapping +from typing import NoReturn from typing import Optional from typing import overload from typing import Sequence @@ -103,9 +104,14 @@ class safe_reraise: with safe_reraise(): sess.rollback() + TODO: is this context manager getting us anything in Python 3? + Not sure of the coroutine issue stated above; we would assume this was + when using eventlet / gevent. not sure if our own greenlet integration + is impacted. + """ - __slots__ = ("warn_only", "_exc_info") + __slots__ = ("_exc_info",) _exc_info: Union[ None, @@ -117,9 +123,6 @@ class safe_reraise: Tuple[None, None, None], ] - def __init__(self, warn_only: bool = False): - self.warn_only = warn_only - def __enter__(self) -> None: self._exc_info = sys.exc_info() @@ -128,15 +131,14 @@ class safe_reraise: type_: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[types.TracebackType], - ) -> None: + ) -> NoReturn: assert self._exc_info is not None # see #2703 for notes if type_ is None: exc_type, exc_value, exc_tb = self._exc_info assert exc_value is not None self._exc_info = None # remove potential circular references - if not self.warn_only: - raise exc_value.with_traceback(exc_tb) + raise exc_value.with_traceback(exc_tb) else: self._exc_info = None # remove potential circular references assert value is not None @@ -1123,13 +1125,22 @@ def as_interface(obj, cls=None, methods=None, required=None): ) +Selfdynamic_property = TypeVar( + "Selfdynamic_property", bound="dynamic_property[Any]" +) + Selfmemoized_property = TypeVar( "Selfmemoized_property", bound="memoized_property[Any]" ) -class memoized_property(Generic[_T]): - """A read-only @property that is only evaluated once.""" +class dynamic_property(Generic[_T]): + """A read-only @property that is evaluated each time. + + This is mostly the same as @property except we can type it + alongside memoized_property + + """ fget: Callable[..., _T] __doc__: Optional[str] @@ -1140,6 +1151,27 @@ class memoized_property(Generic[_T]): self.__doc__ = doc or fget.__doc__ self.__name__ = fget.__name__ + @overload + def __get__( + self: Selfdynamic_property, obj: None, cls: Any + ) -> Selfdynamic_property: + ... + + @overload + def __get__(self, obj: Any, cls: Any) -> _T: + ... + + def __get__( + self: Selfdynamic_property, obj: Any, cls: Any + ) -> Union[Selfdynamic_property, _T]: + if obj is None: + return self + return self.fget(obj) # type: ignore[no-any-return] + + +class memoized_property(dynamic_property[_T]): + """A read-only @property that is only evaluated once.""" + @overload def __get__( self: Selfmemoized_property, obj: None, cls: Any @@ -1158,7 +1190,16 @@ class memoized_property(Generic[_T]): obj.__dict__[self.__name__] = result = self.fget(obj) return result # type: ignore - def _reset(self, obj): + if typing.TYPE_CHECKING: + # __set__ can't actually be implemented because it would + # cause __get__ to be called in all cases + def __set__(self, instance: Any, value: Any) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + def _reset(self, obj: Any) -> None: memoized_property.reset(obj, self.__name__) @classmethod @@ -1628,6 +1669,39 @@ class symbol: raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) +def parse_user_argument_for_enum( + arg: Any, + choices: Dict[_T, List[Any]], + name: str, +) -> Optional[_T]: + """Given a user parameter, parse the parameter into a chosen value + from a list of choice objects, typically Enum values. + + The user argument can be a string name that matches the name of a + symbol, or the symbol object itself, or any number of alternate choices + such as True/False/ None etc. + + :param arg: the user argument. + :param choices: dictionary of enum values to lists of possible + entries for each. + :param name: name of the argument. Used in an :class:`.ArgumentError` + that is raised if the parameter doesn't match any available argument. + + """ + # TODO: use whatever built in thing Enum provides for this, + # if applicable + for enum_value, choice in choices.items(): + if arg is enum_value: + return enum_value + elif arg in choice: + return enum_value + + if arg is None: + return None + + raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) + + _creation_order = 1 @@ -1644,7 +1718,7 @@ def set_creation_order(instance): _creation_order += 1 -def warn_exception(func, *args, **kwargs): +def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """executes the given function, catches all exceptions and converts to a warning. @@ -1678,7 +1752,9 @@ class _hash_limit_string(str): _hash: int - def __new__(cls, value, num, args): + def __new__( + cls, value: str, num: int, args: Sequence[Any] + ) -> _hash_limit_string: interpolated = (value % args) + ( " (this warning may be suppressed after %d occurrences)" % num ) @@ -1686,14 +1762,14 @@ class _hash_limit_string(str): self._hash = hash("%s_%d" % (value, hash(interpolated) % num)) return self - def __hash__(self): + def __hash__(self) -> int: return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) -def warn(msg, code=None): +def warn(msg: str, code: Optional[str] = None) -> None: """Issue a warning. If msg is a string, :class:`.exc.SAWarning` is used as @@ -1706,7 +1782,7 @@ def warn(msg, code=None): _warnings_warn(msg, exc.SAWarning) -def warn_limited(msg, args): +def warn_limited(msg: str, args: Sequence[Any]) -> None: """Issue a warning with a parameterized string, limiting the number of registrations. @@ -1716,7 +1792,11 @@ def warn_limited(msg, args): _warnings_warn(msg, exc.SAWarning) -def _warnings_warn(message, category=None, stacklevel=2): +def _warnings_warn( + message: Union[str, Warning], + category: Optional[Type[Warning]] = None, + stacklevel: int = 2, +) -> None: # adjust the given stacklevel to be outside of SQLAlchemy try: @@ -1736,7 +1816,7 @@ def _warnings_warn(message, category=None, stacklevel=2): while frame is not None and re.match( r"^(?:sqlalchemy\.|alembic\.)", frame.f_globals.get("__name__", "") ): - frame = frame.f_back + frame = frame.f_back # type: ignore[assignment] stacklevel += 1 if category is not None: @@ -1775,7 +1855,11 @@ _SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py") _UNITTEST_RE = re.compile(r"unit(?:2|test2?/)") -def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): +def chop_traceback( + tb: List[str], + exclude_prefix: re.Pattern[str] = _UNITTEST_RE, + exclude_suffix: re.Pattern[str] = _SQLA_RE, +) -> List[str]: """Chop extraneous lines off beginning and end of a traceback. :param tb: diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 3062d9d8ab..06b60c8bf8 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -17,16 +17,26 @@ producing a ``put()`` inside the ``get()`` and therefore a reentrant condition. """ +from __future__ import annotations + import asyncio from collections import deque import threading from time import time as _time +import typing +from typing import Any +from typing import Awaitable +from typing import Deque +from typing import Generic +from typing import Optional +from typing import TypeVar from .concurrency import await_fallback from .concurrency import await_only from .langhelpers import memoized_property +_T = TypeVar("_T", bound=Any) __all__ = ["Empty", "Full", "Queue"] @@ -42,8 +52,41 @@ class Full(Exception): pass -class Queue: - def __init__(self, maxsize=0, use_lifo=False): +class QueueCommon(Generic[_T]): + maxsize: int + use_lifo: bool + + def __init__(self, maxsize: int = 0, use_lifo: bool = False): + ... + + def empty(self) -> bool: + raise NotImplementedError() + + def full(self) -> bool: + raise NotImplementedError() + + def qsize(self) -> int: + raise NotImplementedError() + + def put_nowait(self, item: _T) -> None: + raise NotImplementedError() + + def put( + self, item: _T, block: bool = True, timeout: Optional[float] = None + ) -> None: + raise NotImplementedError() + + def get_nowait(self) -> _T: + raise NotImplementedError() + + def get(self, block: bool = True, timeout: Optional[float] = None) -> _T: + raise NotImplementedError() + + +class Queue(QueueCommon[_T]): + queue: Deque[_T] + + def __init__(self, maxsize: int = 0, use_lifo: bool = False): """Initialize a queue object with a given maximum size. If `maxsize` is <= 0, the queue size is infinite. @@ -66,27 +109,29 @@ class Queue: # If this queue uses LIFO or FIFO self.use_lifo = use_lifo - def qsize(self): + def qsize(self) -> int: """Return the approximate size of the queue (not reliable!).""" with self.mutex: return self._qsize() - def empty(self): + def empty(self) -> bool: """Return True if the queue is empty, False otherwise (not reliable!).""" with self.mutex: return self._empty() - def full(self): + def full(self) -> bool: """Return True if the queue is full, False otherwise (not reliable!).""" with self.mutex: return self._full() - def put(self, item, block=True, timeout=None): + def put( + self, item: _T, block: bool = True, timeout: Optional[float] = None + ) -> None: """Put an item into the queue. If optional args `block` is True and `timeout` is None (the @@ -118,7 +163,7 @@ class Queue: self._put(item) self.not_empty.notify() - def put_nowait(self, item): + def put_nowait(self, item: _T) -> None: """Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. @@ -126,7 +171,7 @@ class Queue: """ return self.put(item, False) - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: Optional[float] = None) -> _T: """Remove and return an item from the queue. If optional args `block` is True and `timeout` is None (the @@ -158,7 +203,7 @@ class Queue: self.not_full.notify() return item - def get_nowait(self): + def get_nowait(self) -> _T: """Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise @@ -167,32 +212,23 @@ class Queue: return self.get(False) - # Override these methods to implement other queue organizations - # (e.g. stack or priority queue). - # These will only be called with appropriate locks held - - # Initialize the queue representation - def _init(self, maxsize): + def _init(self, maxsize: int) -> None: self.maxsize = maxsize self.queue = deque() - def _qsize(self): + def _qsize(self) -> int: return len(self.queue) - # Check whether the queue is empty - def _empty(self): + def _empty(self) -> bool: return not self.queue - # Check whether the queue is full - def _full(self): + def _full(self) -> bool: return self.maxsize > 0 and len(self.queue) == self.maxsize - # Put a new item in the queue - def _put(self, item): + def _put(self, item: _T) -> None: self.queue.append(item) - # Get an item from the queue - def _get(self): + def _get(self) -> _T: if self.use_lifo: # LIFO return self.queue.pop() @@ -201,14 +237,21 @@ class Queue: return self.queue.popleft() -class AsyncAdaptedQueue: - await_ = staticmethod(await_only) +class AsyncAdaptedQueue(QueueCommon[_T]): + if typing.TYPE_CHECKING: - def __init__(self, maxsize=0, use_lifo=False): + @staticmethod + def await_(coroutine: Awaitable[Any]) -> _T: + ... + + else: + await_ = staticmethod(await_only) + + def __init__(self, maxsize: int = 0, use_lifo: bool = False): self.use_lifo = use_lifo self.maxsize = maxsize - def empty(self): + def empty(self) -> bool: return self._queue.empty() def full(self): @@ -218,7 +261,7 @@ class AsyncAdaptedQueue: return self._queue.qsize() @memoized_property - def _queue(self): + def _queue(self) -> asyncio.Queue[_T]: # Delay creation of the queue until it is first used, to avoid # binding it to a possibly wrong event loop. # By delaying the creation of the pool we accommodate the common @@ -226,39 +269,41 @@ class AsyncAdaptedQueue: # different event loop is in present compared to when the application # is actually run. + queue: asyncio.Queue[_T] + if self.use_lifo: queue = asyncio.LifoQueue(maxsize=self.maxsize) else: queue = asyncio.Queue(maxsize=self.maxsize) return queue - def put_nowait(self, item): + def put_nowait(self, item: _T) -> None: try: - return self._queue.put_nowait(item) + self._queue.put_nowait(item) except asyncio.QueueFull as err: raise Full() from err - def put(self, item, block=True, timeout=None): + def put( + self, item: _T, block: bool = True, timeout: Optional[float] = None + ) -> None: if not block: return self.put_nowait(item) try: if timeout is not None: - return self.await_( - asyncio.wait_for(self._queue.put(item), timeout) - ) + self.await_(asyncio.wait_for(self._queue.put(item), timeout)) else: - return self.await_(self._queue.put(item)) + self.await_(self._queue.put(item)) except (asyncio.QueueFull, asyncio.TimeoutError) as err: raise Full() from err - def get_nowait(self): + def get_nowait(self) -> _T: try: return self._queue.get_nowait() except asyncio.QueueEmpty as err: raise Empty() from err - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: Optional[float] = None) -> _T: if not block: return self.get_nowait() @@ -273,5 +318,6 @@ class AsyncAdaptedQueue: raise Empty() from err -class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue): - await_ = staticmethod(await_fallback) +class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue[_T]): + if not typing.TYPE_CHECKING: + await_ = staticmethod(await_fallback) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 404f239c89..ddda420db1 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -80,25 +80,6 @@ class _TypeToInstance(Generic[_T]): ... -class ReadOnlyInstanceDescriptor(Protocol[_T]): - """protocol representing an instance-only descriptor""" - - @overload - def __get__( - self, instance: None, owner: Any - ) -> "ReadOnlyInstanceDescriptor[_T]": - ... - - @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... - - def __get__( - self, instance: object, owner: Any - ) -> Union["ReadOnlyInstanceDescriptor[_T]", _T]: - ... - - def de_stringify_annotation( cls: Type[Any], annotation: Union[str, Type[Any]] ) -> Union[str, Type[Any]]: diff --git a/pyproject.toml b/pyproject.toml index b6f0952390..f7750b6a6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ markers = [ [tool.pyright] include = [ + "lib/sqlalchemy/pool/", "lib/sqlalchemy/event/", "lib/sqlalchemy/events.py", "lib/sqlalchemy/exc.py", @@ -50,6 +51,9 @@ include = [ "lib/sqlalchemy/util/", ] +reportPrivateUsage = "none" +reportUnusedClass = "none" +reportUnusedFunction = "none" [tool.mypy] @@ -78,6 +82,7 @@ strict = true # strict checking [[tool.mypy.overrides]] module = [ + "sqlalchemy.pool.*", "sqlalchemy.event.*", "sqlalchemy.events", "sqlalchemy.exc", diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 0c89752025..c1613069e1 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -322,7 +322,7 @@ class PoolTest(PoolTestBase): is_(rec.connection, rec.dbapi_connection) is_(rec.driver_connection, rec.dbapi_connection) - fairy = pool._ConnectionFairy(rec.dbapi_connection, rec, False) + fairy = pool._ConnectionFairy(p1, rec.dbapi_connection, rec, False) is_not_none(fairy.dbapi_connection) is_(fairy.connection, fairy.dbapi_connection) @@ -346,12 +346,13 @@ class PoolTest(PoolTestBase): rec = pool._ConnectionRecord(p1) + assert rec.dbapi_connection is not None is_not_none(rec.dbapi_connection) is_(rec.connection, rec.dbapi_connection) is_(rec.driver_connection, mock_dc) - fairy = pool._ConnectionFairy(rec.dbapi_connection, rec, False) + fairy = pool._ConnectionFairy(p1, rec.dbapi_connection, rec, False) is_not_none(fairy.dbapi_connection) is_(fairy.connection, fairy.dbapi_connection)