]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep-484 for pool
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Feb 2022 04:43:51 +0000 (23:43 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 Feb 2022 19:45:04 +0000 (14:45 -0500)
also extends into some areas of utils, events and others
as needed.

Formalizes a public hierarchy for pool API,
with ManagesConnection -> PoolProxiedConnection /
ConnectionPoolEntry for connectionfairy / connectionrecord,
which are now what's exposed in the event API and other
APIs.  all public API docs moved to the new objects.

Corrects the mypy plugin's check for sqlalchemy-stubs
not being insatlled, which has to be imported using the
dash in the name to be effective.

Change-Id: I16c2cb43b2e840d28e70a015f370a768e70f3581

25 files changed:
doc/build/core/pooling.rst
doc/build/faq/connections.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/event/__init__.py
lib/sqlalchemy/event/attr.py
lib/sqlalchemy/event/base.py
lib/sqlalchemy/event/registry.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/ext/mypy/plugin.py
lib/sqlalchemy/log.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/pool/__init__.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/pool/events.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_concurrency_py3k.py
lib/sqlalchemy/util/deprecations.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/queue.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/engine/test_pool.py

index bb5e2826a72eaef1f8fce6da07e17cb2b8b4acc9..008c4e1a157753dee1ff0b0f67c4179e51f12022 100644 (file)
@@ -558,14 +558,18 @@ API Documentation - Available Pool Implementations
 
 .. autoclass:: StaticPool
 
-.. autoclass:: PoolProxiedConnection
+.. autoclass:: ManagesConnection
     :members:
 
-.. autoclass:: _ConnectionFairy
+.. autoclass:: ConnectionPoolEntry
     :members:
+    :inherited-members:
 
-    .. autoattribute:: _connection_record
+.. autoclass:: PoolProxiedConnection
+    :members:
+    :inherited-members:
+
+.. autoclass:: _ConnectionFairy
 
 .. autoclass:: _ConnectionRecord
-    :members:
 
index 02d088384c0300fb063a03f5ce090dcbaf01284a..d592ccf6dc88084f4ff5ea11917dbeceb6d3ec5b 100644 (file)
@@ -414,14 +414,14 @@ How do I get at the raw DBAPI connection when using an Engine?
 With a regular SA engine-level Connection, you can get at a pool-proxied
 version of the DBAPI connection via the :attr:`_engine.Connection.connection` attribute on
 :class:`_engine.Connection`, and for the really-real DBAPI connection you can call the
-:attr:`._ConnectionFairy.dbapi_connection` attribute on that.  On regular sync drivers
+:attr:`.PoolProxiedConnection.dbapi_connection` attribute on that.  On regular sync drivers
 there is usually no need to access the non-pool-proxied DBAPI connection,
 as all methods are proxied through::
 
     engine = create_engine(...)
     conn = engine.connect()
 
-    # pep-249 style ConnectionFairy connection pool proxy object
+    # pep-249 style PoolProxiedConnection (historically called a "connection fairy")
     connection_fairy = conn.connection
 
     # typically to run statements one would get a cursor() from this
@@ -438,11 +438,11 @@ as all methods are proxied through::
     also_raw_dbapi_connection = connection_fairy.driver_connection
 
 .. versionchanged:: 1.4.24  Added the
-   :attr:`._ConnectionFairy.dbapi_connection` attribute,
+   :attr:`.PoolProxiedConnection.dbapi_connection` attribute,
    which supersedes the previous
-   :attr:`._ConnectionFairy.connection` attribute which still remains
+   :attr:`.PoolProxiedConnection.connection` attribute which still remains
    available; this attribute always provides a pep-249 synchronous style
-   connection object.  The :attr:`._ConnectionFairy.driver_connection`
+   connection object.  The :attr:`.PoolProxiedConnection.driver_connection`
    attribute is also added which will always refer to the real driver-level
    connection regardless of what API it presents.
 
@@ -451,15 +451,15 @@ Accessing the underlying connection for an asyncio driver
 
 When an asyncio driver is in use, there are two changes to the above
 scheme.  The first is that when using an :class:`_asyncio.AsyncConnection`,
-the :class:`._ConnectionFairy` must be accessed using the awaitable method
+the :class:`.PoolProxiedConnection` must be accessed using the awaitable method
 :meth:`_asyncio.AsyncConnection.get_raw_connection`.   The
-returned :class:`._ConnectionFairy` in this case retains a sync-style
-pep-249 usage pattern, and the :attr:`._ConnectionFairy.dbapi_connection`
+returned :class:`.PoolProxiedConnection` in this case retains a sync-style
+pep-249 usage pattern, and the :attr:`.PoolProxiedConnection.dbapi_connection`
 attribute refers to a
 a SQLAlchemy-adapted connection object which adapts the asyncio
 connection to a sync style pep-249 API, in other words there are *two* levels
 of proxying going on when using an asyncio driver.   The actual asyncio connection
-is available from the :class:`._ConnectionFairy.driver_connection` attribute.
+is available from the :class:`.PoolProxiedConnection.driver_connection` attribute.
 To restate the previous example in terms of asyncio looks like::
 
     async def main():
@@ -483,8 +483,8 @@ To restate the previous example in terms of asyncio looks like::
         result = await raw_asyncio_connection.execute(...)
 
 .. versionchanged:: 1.4.24  Added the
-   :attr:`._ConnectionFairy.dbapi_connection`
-   and :attr:`._ConnectionFairy.driver_connection` attributes to allow access
+   :attr:`.PoolProxiedConnection.dbapi_connection`
+   and :attr:`.PoolProxiedConnection.driver_connection` attributes to allow access
    to pep-249 connections, pep-249 adaption layers, and underlying driver
    connections using a consistent interface.
 
@@ -493,10 +493,10 @@ SQLAlchemy-adapted form of connection which presents a synchronous-style
 pep-249 style API.  To access the actual
 asyncio driver connection, which will present the original asyncio API
 of the driver in use, this can be accessed via the
-:attr:`._ConnectionFairy.driver_connection` attribute of
-:class:`._ConnectionFairy`.
-For a standard pep-249 driver, :attr:`._ConnectionFairy.dbapi_connection`
-and :attr:`._ConnectionFairy.driver_connection` are synonymous.
+:attr:`.PoolProxiedConnection.driver_connection` attribute of
+:class:`.PoolProxiedConnection`.
+For a standard pep-249 driver, :attr:`.PoolProxiedConnection.dbapi_connection`
+and :attr:`.PoolProxiedConnection.driver_connection` are synonymous.
 
 You must ensure that you revert any isolation level settings or other
 operation-specific settings on the connection back to normal before returning
index 4fd273948468e78521e278fd481b7d33684d8999..8c99f63090406547a3cc8bf246758ec7f48d8410 100644 (file)
@@ -1771,15 +1771,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             if not self._is_disconnect:
                 if cursor:
                     self._safe_close_cursor(cursor)
-                with util.safe_reraise(warn_only=True):
-                    # "autorollback" was mostly relevant in 1.x series.
-                    # It's very unlikely to reach here, as the connection
-                    # does autobegin so when we are here, we are usually
-                    # in an explicit / semi-explicit transaction.
-                    # however we have a test which manufactures this
-                    # scenario in any case using an event handler.
-                    if not self.in_transaction():
-                        self._rollback_impl()
+                # "autorollback" was mostly relevant in 1.x series.
+                # It's very unlikely to reach here, as the connection
+                # does autobegin so when we are here, we are usually
+                # in an explicit / semi-explicit transaction.
+                # however we have a test which manufactures this
+                # scenario in any case using an event handler.
+                # test/engine/test_execute.py-> test_actual_autorollback
+                if not self.in_transaction():
+                    self._rollback_impl()
 
             if newraise:
                 raise newraise.with_traceback(exc_info[2]) from e
@@ -2318,11 +2318,15 @@ class Engine(
 
     _schema_translate_map = None
 
+    dialect: Dialect
+    pool: Pool
+    url: URL
+
     def __init__(
         self,
-        pool: "Pool",
-        dialect: "Dialect",
-        url: "URL",
+        pool: Pool,
+        dialect: Dialect,
+        url: URL,
         logging_name: Optional[str] = None,
         echo: Union[None, str, bool] = None,
         query_cache_size: int = 500,
index a252b7cfebfea0ebf181da56c155844588125012..ac3d6a2d89694dc06b2ff8c2af3bc57f84302a31 100644 (file)
@@ -12,11 +12,13 @@ from typing import Union
 
 from . import base
 from . import url as _url
+from .interfaces import DBAPIConnection
 from .mock import create_mock_engine
 from .. import event
 from .. import exc
-from .. import pool as poollib
 from .. import util
+from ..pool import _AdhocProxiedConnection
+from ..pool import ConnectionPoolEntry
 from ..sql import compiler
 
 
@@ -603,10 +605,13 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
         if builtin_on_connect:
             event.listen(pool, "connect", builtin_on_connect)
 
-        def first_connect(dbapi_connection, connection_record):
+        def first_connect(
+            dbapi_connection: DBAPIConnection,
+            connection_record: ConnectionPoolEntry,
+        ):
             c = base.Connection(
                 engine,
-                connection=poollib._AdhocProxiedConnection(
+                connection=_AdhocProxiedConnection(
                     dbapi_connection, connection_record
                 ),
                 _has_events=False,
index ce884614c0678ed0ed2f6a7f7266c5307b270d39..aab6b2de87675f3212660fc3a360bde156a94473 100644 (file)
@@ -59,7 +59,7 @@ class DBAPIConnection(Protocol):
     def commit(self) -> None:
         ...
 
-    def cursor(self) -> "DBAPICursor":
+    def cursor(self) -> DBAPICursor:
         ...
 
     def rollback(self) -> None:
@@ -657,6 +657,9 @@ class Dialect:
 
     """
 
+    is_async: bool
+    """Whether or not this dialect is intended for asyncio use."""
+
     def create_connect_args(
         self, url: "URL"
     ) -> Tuple[Tuple[str], Mapping[str, Any]]:
@@ -1091,7 +1094,7 @@ class Dialect:
 
         raise NotImplementedError()
 
-    def do_close(self, dbapi_connection: PoolProxiedConnection) -> None:
+    def do_close(self, dbapi_connection: DBAPIConnection) -> None:
         """Provide an implementation of ``connection.close()``, given a DBAPI
         connection.
 
@@ -1104,6 +1107,11 @@ class Dialect:
 
         raise NotImplementedError()
 
+    def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
+        """ping the DBAPI connection and return True if the connection is
+        usable."""
+        raise NotImplementedError()
+
     def do_set_input_sizes(
         self,
         cursor: DBAPICursor,
@@ -1679,7 +1687,7 @@ class Dialect:
 
         """
 
-    def get_driver_connection(self, connection: PoolProxiedConnection) -> Any:
+    def get_driver_connection(self, connection: DBAPIConnection) -> Any:
         """Returns the connection object as returned by the external driver
         package.
 
index 2d10372ab11c26cb74e3dd912c38efe0237384ad..0dfb39e1a04752f69306a7bb21a258c33d4fba34 100644 (file)
@@ -14,6 +14,10 @@ from .api import listens_for as listens_for
 from .api import NO_RETVAL as NO_RETVAL
 from .api import remove as remove
 from .attr import RefCollection as RefCollection
+from .base import _Dispatch as _Dispatch
 from .base import dispatcher as dispatcher
 from .base import Events as Events
-from .legacy import _legacy_signature
+from .legacy import _legacy_signature as _legacy_signature
+from .registry import _EventKey as _EventKey
+from .registry import _ListenerFnType as _ListenerFnType
+from .registry import EventTarget as EventTarget
index d1ae7a84527edcc681aff6f3de4dc380aa9b37b9..9692894fe82e444e8fb1bc0de237e83cfe237461 100644 (file)
@@ -68,6 +68,7 @@ _T = TypeVar("_T", bound=Any)
 
 if typing.TYPE_CHECKING:
     from .base import _Dispatch
+    from .base import _DispatchCommon
     from .base import _HasEventsDispatch
     from .base import _JoinedDispatcher
 
@@ -280,6 +281,38 @@ class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]):
     def __bool__(self) -> bool:
         raise NotImplementedError()
 
+    def exec_once(self, *args: Any, **kw: Any) -> None:
+        raise NotImplementedError()
+
+    def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None:
+        raise NotImplementedError()
+
+    def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None:
+        raise NotImplementedError()
+
+    def __call__(self, *args: Any, **kw: Any) -> None:
+        raise NotImplementedError()
+
+    def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+        raise NotImplementedError()
+
+    def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+        raise NotImplementedError()
+
+    def remove(self, event_key: _EventKey[_ET]) -> None:
+        raise NotImplementedError()
+
+    def for_modify(
+        self, obj: _DispatchCommon[_ET]
+    ) -> _InstanceLevelDispatch[_ET]:
+        """Return an event collection which can be modified.
+
+        For _ClsLevelDispatch at the class level of
+        a dispatcher, this returns self.
+
+        """
+        return self
+
 
 class _EmptyListener(_InstanceLevelDispatch[_ET]):
     """Serves as a proxy interface to the events
@@ -306,7 +339,9 @@ class _EmptyListener(_InstanceLevelDispatch[_ET]):
         self.parent_listeners = parent._clslevel[target_cls]
         self.name = parent.name
 
-    def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]:
+    def for_modify(
+        self, obj: _DispatchCommon[_ET]
+    ) -> _ListenerCollection[_ET]:
         """Return an event collection which can be modified.
 
         For _EmptyListener at the instance level of
@@ -315,6 +350,8 @@ class _EmptyListener(_InstanceLevelDispatch[_ET]):
         and returns it.
 
         """
+        obj = cast("_Dispatch[_ET]", obj)
+
         assert obj._instance_cls is not None
         result = _ListenerCollection(self.parent, obj._instance_cls)
         if getattr(obj, self.name) is self:
@@ -512,7 +549,9 @@ class _ListenerCollection(_CompoundListener[_ET]):
         self.listeners = collections.deque()
         self.propagate = set()
 
-    def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]:
+    def for_modify(
+        self, obj: _DispatchCommon[_ET]
+    ) -> _ListenerCollection[_ET]:
         """Return an event collection which can be modified.
 
         For _ListenerCollection at the instance level of
@@ -599,7 +638,7 @@ class _JoinedListener(_CompoundListener[_ET]):
     ) -> _ListenerFnType:
         return self.local._adjust_fn_spec(fn, named)
 
-    def for_modify(self, obj: _JoinedDispatcher[_ET]) -> _JoinedListener[_ET]:
+    def for_modify(self, obj: _DispatchCommon[_ET]) -> _JoinedListener[_ET]:
         self.local = self.parent_listeners = self.local.for_modify(obj)
         return self
 
index 0e0647036f5d2cff058b26a0ac0c6d8ec7bc46d2..ef3ff9dab3a9b749fbff08c49369cb8a6709b95b 100644 (file)
@@ -17,6 +17,7 @@ instances of ``_Dispatch``.
 """
 from __future__ import annotations
 
+import typing
 from typing import Any
 from typing import cast
 from typing import Dict
@@ -71,7 +72,11 @@ class _UnpickleDispatch:
             raise AttributeError("No class with a 'dispatch' member present.")
 
 
-class _Dispatch(Generic[_ET]):
+class _DispatchCommon(Generic[_ET]):
+    __slots__ = ()
+
+
+class _Dispatch(_DispatchCommon[_ET]):
     """Mirror the event listening definitions of an Events class with
     listener collections.
 
@@ -218,6 +223,11 @@ class _HasEventsDispatch(Generic[_ET]):
 
     """
 
+    if typing.TYPE_CHECKING:
+
+        def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]:
+            ...
+
     def __init_subclass__(cls) -> None:
         """Intercept new Event subclasses and create associated _Dispatch
         classes."""
@@ -357,7 +367,7 @@ class Events(_HasEventsDispatch[_ET]):
         cls.dispatch._clear()
 
 
-class _JoinedDispatcher(Generic[_ET]):
+class _JoinedDispatcher(_DispatchCommon[_ET]):
     """Represent a connection between two _Dispatch objects."""
 
     __slots__ = "local", "parent", "_instance_cls"
@@ -402,11 +412,11 @@ class dispatcher(Generic[_ET]):
     @overload
     def __get__(
         self, obj: Literal[None], cls: Type[Any]
-    ) -> Type[_HasEventsDispatch[_ET]]:
+    ) -> Type[_Dispatch[_ET]]:
         ...
 
     @overload
-    def __get__(self, obj: Any, cls: Type[Any]) -> _HasEventsDispatch[_ET]:
+    def __get__(self, obj: Any, cls: Type[Any]) -> _Dispatch[_ET]:
         ...
 
     def __get__(self, obj: Any, cls: Type[Any]) -> Any:
index e20d3e0b53649d07fbae1659f8f6ed328693f857..449f391876443405343a0775af59814a029539a0 100644 (file)
@@ -22,7 +22,6 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import cast
-from typing import ClassVar
 from typing import Deque
 from typing import Dict
 from typing import Generic
@@ -35,7 +34,6 @@ import weakref
 
 from .. import exc
 from .. import util
-from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
     from .attr import RefCollection
@@ -46,7 +44,10 @@ _ListenerFnKeyType = Union[int, Tuple[int, int]]
 _EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType]
 
 
-class _EventTargetType(Protocol):
+_ET = TypeVar("_ET", bound="EventTarget")
+
+
+class EventTarget:
     """represents an event target, that is, something we can listen on
     either with that target as a class or as an instance.
 
@@ -55,10 +56,10 @@ class _EventTargetType(Protocol):
 
     """
 
-    dispatch: ClassVar[dispatcher[Any]]
+    __slots__ = ()
 
+    dispatch: dispatcher[Any]
 
-_ET = TypeVar("_ET", bound=_EventTargetType)
 
 _RefCollectionToListenerType = Dict[
     "weakref.ref[RefCollection[Any]]",
@@ -104,7 +105,7 @@ def _collection_gced(ref: weakref.ref[Any]) -> None:
     if not _collection_to_key or ref not in _collection_to_key:
         return
 
-    ref = cast("weakref.ref[RefCollection[_EventTargetType]]", ref)
+    ref = cast("weakref.ref[RefCollection[EventTarget]]", ref)
 
     listener_to_key = _collection_to_key.pop(ref)
     for key in listener_to_key.values():
index f39f4cd8fa08d3e44dc8e38a154491f8dd2d4a17..1383e024a1cf93ec57ed020e1a258e8546fd2b8d 100644 (file)
@@ -32,7 +32,11 @@ if typing.TYPE_CHECKING:
     from .sql.compiler import Compiled
     from .sql.elements import ClauseElement
 
-_version_token = None
+if typing.TYPE_CHECKING:
+    _version_token: str
+else:
+    # set by __init__.py
+    _version_token = None
 
 
 class HasDescriptionCode:
index 3a78ab188cb61e775acdeb0809adb96e8b81174b..f7e66e341997b53ab7e66be58c88040d0c3cdd37 100644 (file)
@@ -43,16 +43,16 @@ from . import names
 from . import util
 
 try:
-    import sqlalchemy_stubs  # noqa
+    __import__("sqlalchemy-stubs")
 except ImportError:
     pass
 else:
-    import sqlalchemy
-
     raise ImportError(
-        f"The SQLAlchemy mypy plugin in SQLAlchemy "
-        f"{sqlalchemy.__version__} does not work with sqlalchemy-stubs or "
-        "sqlalchemy2-stubs installed"
+        "The SQLAlchemy mypy plugin in SQLAlchemy "
+        "2.0 does not work with sqlalchemy-stubs or "
+        "sqlalchemy2-stubs installed, as well as with any other third party "
+        "SQLAlchemy stubs.  Please uninstall all SQLAlchemy stubs "
+        "packages."
     )
 
 
index 2f63b8569d781236df3ac0961bb77ca99d687d8e..8da45ed0d72171af04358bb65aab99026d421afd 100644 (file)
@@ -75,12 +75,15 @@ def class_logger(cls: Type[_IT]) -> Type[_IT]:
     return cls
 
 
+_IdentifiedLoggerType = Union[logging.Logger, "InstanceLogger"]
+
+
 class Identified:
     __slots__ = ()
 
     logging_name: Optional[str] = None
 
-    logger: Union[logging.Logger, "InstanceLogger"]
+    logger: _IdentifiedLoggerType
 
     _echo: _EchoFlagType
 
index 33367c0c654f5ebb825009a3056acff88b90bcb5..b9c881cfe01f1064639b3a34201107716eec4292 100644 (file)
@@ -662,15 +662,9 @@ class Mapped(Generic[_T], TypingOnly):
         def _empty_constructor(cls, arg1: Any) -> "Mapped[_T]":
             ...
 
-        @overload
-        def __set__(self, instance: Any, value: _T) -> None:
-            ...
-
-        @overload
-        def __set__(self, instance: Any, value: SQLCoreOperations) -> None:
-            ...
-
-        def __set__(self, instance, value):
+        def __set__(
+            self, instance: Any, value: Union[SQLCoreOperations[_T], _T]
+        ):
             ...
 
         def __delete__(self, instance: Any):
index bc2f93d57ea88c40008839c20065482ce5b48022..2c52a7065029273e2c9d13dc14a40fa3514448d3 100644 (file)
@@ -22,6 +22,8 @@ from .base import _AdhocProxiedConnection
 from .base import _ConnectionFairy
 from .base import _ConnectionRecord
 from .base import _finalize_fairy
+from .base import ConnectionPoolEntry as ConnectionPoolEntry
+from .base import ManagesConnection as ManagesConnection
 from .base import Pool as Pool
 from .base import PoolProxiedConnection as PoolProxiedConnection
 from .base import reset_commit as reset_commit
index 72c56716f118d892f675171fedd92e4d1f6d6656..18d268182d8ff3da2cbe8320d0ed19d329c43b8d 100644 (file)
 from __future__ import annotations
 
 from collections import deque
+from enum import Enum
+import threading
 import time
+import typing
 from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Deque
 from typing import Dict
+from typing import List
 from typing import Optional
+from typing import Tuple
 from typing import TYPE_CHECKING
+from typing import Union
 import weakref
 
 from .. import event
 from .. import exc
 from .. import log
 from .. import util
+from ..util.typing import Literal
+from ..util.typing import Protocol
 
 if TYPE_CHECKING:
     from ..engine.interfaces import DBAPIConnection
+    from ..engine.interfaces import DBAPICursor
+    from ..engine.interfaces import Dialect
+    from ..event import _Dispatch
+    from ..event import _ListenerFnType
+    from ..event import dispatcher
 
-reset_rollback = util.symbol("reset_rollback")
-reset_commit = util.symbol("reset_commit")
-reset_none = util.symbol("reset_none")
+
+class ResetStyle(Enum):
+    """Describe options for "reset on return" behaviors."""
+
+    reset_rollback = 0
+    reset_commit = 1
+    reset_none = 2
+
+
+_ResetStyleArgType = Union[
+    ResetStyle,
+    Literal[True],
+    Literal[None],
+    Literal[False],
+    Literal["commit"],
+    Literal["rollback"],
+]
+reset_rollback, reset_commit, reset_none = list(ResetStyle)
 
 
 class _ConnDialect:
@@ -45,22 +76,22 @@ class _ConnDialect:
 
     is_async = False
 
-    def do_rollback(self, dbapi_connection):
+    def do_rollback(self, dbapi_connection: PoolProxiedConnection) -> None:
         dbapi_connection.rollback()
 
-    def do_commit(self, dbapi_connection):
+    def do_commit(self, dbapi_connection: PoolProxiedConnection) -> None:
         dbapi_connection.commit()
 
-    def do_close(self, dbapi_connection):
+    def do_close(self, dbapi_connection: DBAPIConnection) -> None:
         dbapi_connection.close()
 
-    def do_ping(self, dbapi_connection):
+    def do_ping(self, dbapi_connection: DBAPIConnection) -> None:
         raise NotImplementedError(
             "The ping feature requires that a dialect is "
             "passed to the connection pool."
         )
 
-    def get_driver_connection(self, connection):
+    def get_driver_connection(self, connection: DBAPIConnection) -> Any:
         return connection
 
 
@@ -68,23 +99,40 @@ class _AsyncConnDialect(_ConnDialect):
     is_async = True
 
 
-class Pool(log.Identified):
+class _CreatorFnType(Protocol):
+    def __call__(self) -> DBAPIConnection:
+        ...
+
+
+class _CreatorWRecFnType(Protocol):
+    def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection:
+        ...
+
+
+class Pool(log.Identified, event.EventTarget):
 
     """Abstract base class for connection pools."""
 
-    _dialect = _ConnDialect()
+    dispatch: dispatcher[Pool]
+    echo: log._EchoFlagType
+
+    _orig_logging_name: Optional[str]
+    _dialect: Union[_ConnDialect, Dialect] = _ConnDialect()
+    _creator_arg: Union[_CreatorFnType, _CreatorWRecFnType]
+    _invoke_creator: _CreatorWRecFnType
+    _invalidate_time: float
 
     def __init__(
         self,
-        creator,
-        recycle=-1,
-        echo=None,
-        logging_name=None,
-        reset_on_return=True,
-        events=None,
-        dialect=None,
-        pre_ping=False,
-        _dispatch=None,
+        creator: Union[_CreatorFnType, _CreatorWRecFnType],
+        recycle: int = -1,
+        echo: log._EchoFlagType = None,
+        logging_name: Optional[str] = None,
+        reset_on_return: _ResetStyleArgType = True,
+        events: Optional[List[Tuple[_ListenerFnType, str]]] = None,
+        dialect: Optional[Union[_ConnDialect, Dialect]] = None,
+        pre_ping: bool = False,
+        _dispatch: Optional[_Dispatch[Pool]] = None,
     ):
         """
         Construct a Pool.
@@ -188,15 +236,14 @@ class Pool(log.Identified):
         self._recycle = recycle
         self._invalidate_time = 0
         self._pre_ping = pre_ping
-        self._reset_on_return = util.symbol.parse_user_argument(
+        self._reset_on_return = util.parse_user_argument_for_enum(
             reset_on_return,
             {
-                reset_rollback: ["rollback", True],
-                reset_none: ["none", None, False],
-                reset_commit: ["commit"],
+                ResetStyle.reset_rollback: ["rollback", True],
+                ResetStyle.reset_none: ["none", None, False],
+                ResetStyle.reset_commit: ["commit"],
             },
             "reset_on_return",
-            resolve_symbol_names=False,
         )
 
         self.echo = echo
@@ -210,19 +257,32 @@ class Pool(log.Identified):
                 event.listen(self, target, fn)
 
     @util.hybridproperty
-    def _is_asyncio(self):
+    def _is_asyncio(self) -> bool:
         return self._dialect.is_async
 
     @property
-    def _creator(self):
-        return self.__dict__["_creator"]
+    def _creator(self) -> Union[_CreatorFnType, _CreatorWRecFnType]:
+        return self._creator_arg
 
     @_creator.setter
-    def _creator(self, creator):
-        self.__dict__["_creator"] = creator
-        self._invoke_creator = self._should_wrap_creator(creator)
+    def _creator(
+        self, creator: Union[_CreatorFnType, _CreatorWRecFnType]
+    ) -> None:
+        self._creator_arg = creator
+
+        # mypy seems to get super confused assigning functions to
+        # attributes
+        self._invoke_creator = self._should_wrap_creator(creator)  # type: ignore  # noqa E501
+
+    @_creator.deleter
+    def _creator(self) -> None:
+        # needed for mock testing
+        del self._creator_arg
+        del self._invoke_creator  # type: ignore[misc]
 
-    def _should_wrap_creator(self, creator):
+    def _should_wrap_creator(
+        self, creator: Union[_CreatorFnType, _CreatorWRecFnType]
+    ) -> _CreatorWRecFnType:
         """Detect if creator accepts a single argument, or is sent
         as a legacy style no-arg function.
 
@@ -231,26 +291,30 @@ class Pool(log.Identified):
         try:
             argspec = util.get_callable_argspec(self._creator, no_self=True)
         except TypeError:
-            return lambda crec: creator()
+            creator_fn = cast(_CreatorFnType, creator)
+            return lambda rec: creator_fn()
 
-        defaulted = argspec[3] is not None and len(argspec[3]) or 0
+        if argspec.defaults is not None:
+            defaulted = len(argspec.defaults)
+        else:
+            defaulted = 0
         positionals = len(argspec[0]) - defaulted
 
         # look for the exact arg signature that DefaultStrategy
         # sends us
         if (argspec[0], argspec[3]) == (["connection_record"], (None,)):
-            return creator
+            return cast(_CreatorWRecFnType, creator)
         # or just a single positional
         elif positionals == 1:
-            return creator
+            return cast(_CreatorWRecFnType, creator)
         # all other cases, just wrap and assume legacy "creator" callable
         # thing
         else:
-            return lambda crec: creator()
+            creator_fn = cast(_CreatorFnType, creator)
+            return lambda rec: creator_fn()
 
-    def _close_connection(self, connection):
+    def _close_connection(self, connection: DBAPIConnection) -> None:
         self.logger.debug("Closing connection %r", connection)
-
         try:
             self._dialect.do_close(connection)
         except Exception:
@@ -258,12 +322,17 @@ class Pool(log.Identified):
                 "Exception closing connection %r", connection, exc_info=True
             )
 
-    def _create_connection(self):
+    def _create_connection(self) -> ConnectionPoolEntry:
         """Called by subclasses to create a new ConnectionRecord."""
 
         return _ConnectionRecord(self)
 
-    def _invalidate(self, connection, exception=None, _checkin=True):
+    def _invalidate(
+        self,
+        connection: PoolProxiedConnection,
+        exception: Optional[BaseException] = None,
+        _checkin: bool = True,
+    ) -> None:
         """Mark all connections established within the generation
         of the given connection as invalidated.
 
@@ -280,7 +349,7 @@ class Pool(log.Identified):
         if _checkin and getattr(connection, "is_valid", False):
             connection.invalidate(exception)
 
-    def recreate(self):
+    def recreate(self) -> Pool:
         """Return a new :class:`_pool.Pool`, of the same class as this one
         and configured with identical creation arguments.
 
@@ -292,7 +361,7 @@ class Pool(log.Identified):
 
         raise NotImplementedError()
 
-    def dispose(self):
+    def dispose(self) -> None:
         """Dispose of this pool.
 
         This method leaves the possibility of checked-out connections
@@ -307,7 +376,7 @@ class Pool(log.Identified):
 
         raise NotImplementedError()
 
-    def connect(self):
+    def connect(self) -> PoolProxiedConnection:
         """Return a DBAPI connection from the pool.
 
         The connection is instrumented such that when its
@@ -317,7 +386,7 @@ class Pool(log.Identified):
         """
         return _ConnectionFairy._checkout(self)
 
-    def _return_conn(self, record):
+    def _return_conn(self, record: ConnectionPoolEntry) -> None:
         """Given a _ConnectionRecord, return it to the :class:`_pool.Pool`.
 
         This method is called when an instrumented DBAPI connection
@@ -326,100 +395,230 @@ class Pool(log.Identified):
         """
         self._do_return_conn(record)
 
-    def _do_get(self):
+    def _do_get(self) -> ConnectionPoolEntry:
         """Implementation for :meth:`get`, supplied by subclasses."""
 
         raise NotImplementedError()
 
-    def _do_return_conn(self, conn):
+    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
         """Implementation for :meth:`return_conn`, supplied by subclasses."""
 
         raise NotImplementedError()
 
-    def status(self):
+    def status(self) -> str:
         raise NotImplementedError()
 
 
-class _ConnectionRecord:
+class ManagesConnection:
+    """Common base for the two connection-management interfaces
+    :class:`.PoolProxiedConnection` and :class:`.ConnectionPoolEntry`.
 
-    """Internal object which maintains an individual DBAPI connection
-    referenced by a :class:`_pool.Pool`.
+    These two objects are typically exposed in the public facing API
+    via the connection pool event hooks, documented at :class:`.PoolEvents`.
 
-    The :class:`._ConnectionRecord` object always exists for any particular
-    DBAPI connection whether or not that DBAPI connection has been
-    "checked out".  This is in contrast to the :class:`._ConnectionFairy`
-    which is only a public facade to the DBAPI connection while it is checked
-    out.
+    .. versionadded:: 2.0
 
-    A :class:`._ConnectionRecord` may exist for a span longer than that
-    of a single DBAPI connection.  For example, if the
-    :meth:`._ConnectionRecord.invalidate`
-    method is called, the DBAPI connection associated with this
-    :class:`._ConnectionRecord`
-    will be discarded, but the :class:`._ConnectionRecord` may be used again,
-    in which case a new DBAPI connection is produced when the
-    :class:`_pool.Pool`
-    next uses this record.
+    """
 
-    The :class:`._ConnectionRecord` is delivered along with connection
-    pool events, including :meth:`_events.PoolEvents.connect` and
-    :meth:`_events.PoolEvents.checkout`, however :class:`._ConnectionRecord`
-    still
-    remains an internal object whose API and internals may change.
+    __slots__ = ()
+
+    dbapi_connection: Optional[DBAPIConnection]
+    """A reference to the actual DBAPI connection being tracked.
+
+    This is a :pep:`249`-compliant object that for traditional sync-style
+    dialects is provided by the third-party
+    DBAPI implementation in use.  For asyncio dialects, the implementation
+    is typically an adapter object provided by the SQLAlchemy dialect
+    itself; the underlying asyncio object is available via the
+    :attr:`.ManagesConnection.driver_connection` attribute.
+
+    SQLAlchemy's interface for the DBAPI connection is based on the
+    :class:`.DBAPIConnection` protocol object
 
     .. seealso::
 
-        :class:`._ConnectionFairy`
+        :attr:`.ManagesConnection.driver_connection`
+
+        :ref:`faq_dbapi_connection`
 
     """
 
-    def __init__(self, pool, connect=True):
-        self.__pool = pool
-        if connect:
-            self.__connect()
-        self.finalize_callback = deque()
+    @property
+    def driver_connection(self) -> Optional[Any]:
+        """The "driver level" connection object as used by the Python
+        DBAPI or database driver.
+
+        For traditional :pep:`249` DBAPI implementations, this object will
+        be the same object as that of
+        :attr:`.ManagesConnection.dbapi_connection`.   For an asyncio database
+        driver, this will be the ultimate "connection" object used by that
+        driver, such as the ``asyncpg.Connection`` object which will not have
+        standard pep-249 methods.
+
+        .. versionadded:: 1.4.24
 
-    fresh = False
+        .. seealso::
 
-    fairy_ref = None
+            :attr:`.ManagesConnection.dbapi_connection`
 
-    starttime = None
+            :ref:`faq_dbapi_connection`
 
-    dbapi_connection = None
-    """A reference to the actual DBAPI connection being tracked.
+        """
+        raise NotImplementedError()
+
+    @util.dynamic_property
+    def info(self) -> Dict[str, Any]:
+        """Info dictionary associated with the underlying DBAPI connection
+        referred to by this :class:`.ManagesConnection` instance, allowing
+        user-defined data to be associated with the connection.
+
+        The data in this dictionary is persistent for the lifespan
+        of the DBAPI connection itself, including across pool checkins
+        and checkouts.  When the connection is invalidated
+        and replaced with a new one, this dictionary is cleared.
+
+        For a :class:`.PoolProxiedConnection` instance that's not associated
+        with a :class:`.ConnectionPoolEntry`, such as if it were detached, the
+        attribute returns a dictionary that is local to that
+        :class:`.ConnectionPoolEntry`. Therefore the
+        :attr:`.ManagesConnection.info` attribute will always provide a Python
+        dictionary.
+
+        .. seealso::
 
-    May be ``None`` if this :class:`._ConnectionRecord` has been marked
-    as invalidated; a new DBAPI connection may replace it if the owning
-    pool calls upon this :class:`._ConnectionRecord` to reconnect.
+            :attr:`.ManagesConnection.record_info`
 
-    For adapted drivers, like the Asyncio implementations, this is a
-    :class:`.AdaptedConnection` that adapts the driver connection
-    to the DBAPI protocol.
-    Use :attr:`._ConnectionRecord.driver_connection` to obtain the
-    connection objected returned by the driver.
 
-    .. versionadded:: 1.4.24
+        """
+        raise NotImplementedError()
+
+    @util.dynamic_property
+    def record_info(self) -> Optional[Dict[str, Any]]:
+        """Persistent info dictionary associated with this
+        :class:`.ManagesConnection`.
+
+        Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan
+        of this dictionary is that of the :class:`.ConnectionPoolEntry`
+        which owns it; therefore this dictionary will persist across
+        reconnects and connection invalidation for a particular entry
+        in the connection pool.
+
+        For a :class:`.PoolProxiedConnection` instance that's not associated
+        with a :class:`.ConnectionPoolEntry`, such as if it were detached, the
+        attribute returns None. Contrast to the :attr:`.ManagesConnection.info`
+        dictionary which is never None.
+
+
+        .. seealso::
+
+            :attr:`.ManagesConnection.info`
+
+        """
+        raise NotImplementedError()
+
+    def invalidate(
+        self, e: Optional[BaseException] = None, soft: bool = False
+    ) -> None:
+        """Mark the managed connection as invalidated.
+
+        :param e: an exception object indicating a reason for the invalidation.
+
+        :param soft: if True, the connection isn't closed; instead, this
+         connection will be recycled on next checkout.
+
+        .. seealso::
+
+            :ref:`pool_connection_invalidation`
+
+
+        """
+        raise NotImplementedError()
+
+
+class ConnectionPoolEntry(ManagesConnection):
+    """Interface for the object that maintains an individual database
+    connection on behalf of a :class:`_pool.Pool` instance.
+
+    The :class:`.ConnectionPoolEntry` object represents the long term
+    maintainance of a particular connection for a pool, including expiring or
+    invalidating that connection to have it replaced with a new one, which will
+    continue to be maintained by that same :class:`.ConnectionPoolEntry`
+    instance. Compared to :class:`.PoolProxiedConnection`, which is the
+    short-term, per-checkout connection manager, this object lasts for the
+    lifespan of a particular "slot" within a connection pool.
+
+    The :class:`.ConnectionPoolEntry` object is mostly visible to public-facing
+    API code when it is delivered to connection pool event hooks, such as
+    :meth:`_events.PoolEvents.connect` and :meth:`_events.PoolEvents.checkout`.
+
+    .. versionadded:: 2.0  :class:`.ConnectionPoolEntry` provides the public
+       facing interface for the :class:`._ConnectionRecord` internal class.
 
     """
 
+    __slots__ = ()
+
     @property
-    def driver_connection(self):
-        """The connection object as returned by the driver after a connect.
+    def in_use(self) -> bool:
+        """Return True the connection is currently checked out"""
 
-        For normal sync drivers that support the DBAPI protocol, this object
-        is the same as the one referenced by
-        :attr:`._ConnectionRecord.dbapi_connection`.
+        raise NotImplementedError()
 
-        For adapted drivers, like the Asyncio ones, this is the actual object
-        that was returned by the driver ``connect`` call.
+    def close(self) -> None:
+        """Close the DBAPI connection managed by this connection pool entry."""
+        raise NotImplementedError()
 
-        As :attr:`._ConnectionRecord.dbapi_connection` it may be ``None``
-        if this :class:`._ConnectionRecord` has been marked as invalidated.
 
-        .. versionadded:: 1.4.24
+class _ConnectionRecord(ConnectionPoolEntry):
 
-        """
+    """Maintains a position in a connection pool which references a pooled
+    connection.
 
+    This is an internal object used by the :class:`_pool.Pool` implementation
+    to provide context management to a DBAPI connection maintained by
+    that :class:`_pool.Pool`.   The public facing interface for this class
+    is described by the :class:`.ConnectionPoolEntry` class.  See that
+    class for public API details.
+
+    .. seealso::
+
+        :class:`.ConnectionPoolEntry`
+
+        :class:`.PoolProxiedConnection`
+
+    """
+
+    __slots__ = (
+        "__pool",
+        "fairy_ref",
+        "finalize_callback",
+        "fresh",
+        "starttime",
+        "dbapi_connection",
+        "__weakref__",
+        "__dict__",
+    )
+
+    finalize_callback: Deque[Callable[[DBAPIConnection], None]]
+    fresh: bool
+    fairy_ref: Optional[weakref.ref[_ConnectionFairy]]
+    starttime: float
+
+    def __init__(self, pool: Pool, connect: bool = True):
+        self.fresh = False
+        self.fairy_ref = None
+        self.starttime = 0
+        self.dbapi_connection = None
+
+        self.__pool = pool
+        if connect:
+            self.__connect()
+        self.finalize_callback = deque()
+
+    dbapi_connection: Optional[DBAPIConnection]
+
+    @property
+    def driver_connection(self) -> Optional[Any]:
         if self.dbapi_connection is None:
             return None
         else:
@@ -428,72 +627,41 @@ class _ConnectionRecord:
             )
 
     @property
-    def connection(self):
-        """An alias to :attr:`._ConnectionRecord.dbapi_connection`.
-
-        This alias is deprecated, please use the new name.
-
-        .. deprecated:: 1.4.24
-
-        """
+    def connection(self) -> Optional[DBAPIConnection]:
         return self.dbapi_connection
 
     @connection.setter
-    def connection(self, value):
+    def connection(self, value: DBAPIConnection) -> None:
         self.dbapi_connection = value
 
-    _soft_invalidate_time = 0
+    _soft_invalidate_time: float = 0
 
     @util.memoized_property
-    def info(self):
-        """The ``.info`` dictionary associated with the DBAPI connection.
-
-        This dictionary is shared among the :attr:`._ConnectionFairy.info`
-        and :attr:`_engine.Connection.info` accessors.
-
-        .. note::
-
-            The lifespan of this dictionary is linked to the
-            DBAPI connection itself, meaning that it is **discarded** each time
-            the DBAPI connection is closed and/or invalidated.   The
-            :attr:`._ConnectionRecord.record_info` dictionary remains
-            persistent throughout the lifespan of the
-            :class:`._ConnectionRecord` container.
-
-        """
+    def info(self) -> Dict[str, Any]:
         return {}
 
     @util.memoized_property
-    def record_info(self):
-        """An "info' dictionary associated with the connection record
-        itself.
-
-        Unlike the :attr:`._ConnectionRecord.info` dictionary, which is linked
-        to the lifespan of the DBAPI connection, this dictionary is linked
-        to the lifespan of the :class:`._ConnectionRecord` container itself
-        and will remain persistent throughout the life of the
-        :class:`._ConnectionRecord`.
-
-        .. versionadded:: 1.1
-
-        """
+    def record_info(self) -> Optional[Dict[str, Any]]:
         return {}
 
     @classmethod
-    def checkout(cls, pool):
-        rec = pool._do_get()
+    def checkout(cls, pool: Pool) -> _ConnectionFairy:
+        rec = cast(_ConnectionRecord, pool._do_get())
         try:
             dbapi_connection = rec.get_connection()
         except Exception as err:
             with util.safe_reraise():
                 rec._checkin_failed(err, _fairy_was_created=False)
+            raise
+
         echo = pool._should_log_debug()
-        fairy = _ConnectionFairy(dbapi_connection, rec, echo)
+        fairy = _ConnectionFairy(pool, dbapi_connection, rec, echo)
 
         rec.fairy_ref = ref = weakref.ref(
             fairy,
-            lambda ref: _finalize_fairy
-            and _finalize_fairy(None, rec, pool, ref, echo, True),
+            lambda ref: _finalize_fairy(None, rec, pool, ref, echo, True)
+            if _finalize_fairy
+            else None,
         )
         _strong_ref_connection_records[ref] = rec
         if echo:
@@ -502,13 +670,15 @@ class _ConnectionRecord:
             )
         return fairy
 
-    def _checkin_failed(self, err, _fairy_was_created=True):
+    def _checkin_failed(
+        self, err: Exception, _fairy_was_created: bool = True
+    ) -> None:
         self.invalidate(e=err)
         self.checkin(
             _fairy_was_created=_fairy_was_created,
         )
 
-    def checkin(self, _fairy_was_created=True):
+    def checkin(self, _fairy_was_created: bool = True) -> None:
         if self.fairy_ref is None and _fairy_was_created:
             # _fairy_was_created is False for the initial get connection phase;
             # meaning there was no _ConnectionFairy and we must unconditionally
@@ -524,47 +694,28 @@ class _ConnectionRecord:
         pool = self.__pool
         while self.finalize_callback:
             finalizer = self.finalize_callback.pop()
-            finalizer(connection)
+            if connection is not None:
+                finalizer(connection)
         if pool.dispatch.checkin:
             pool.dispatch.checkin(connection, self)
 
         pool._return_conn(self)
 
     @property
-    def in_use(self):
+    def in_use(self) -> bool:
         return self.fairy_ref is not None
 
     @property
-    def last_connect_time(self):
+    def last_connect_time(self) -> float:
         return self.starttime
 
-    def close(self):
+    def close(self) -> None:
         if self.dbapi_connection is not None:
             self.__close()
 
-    def invalidate(self, e=None, soft=False):
-        """Invalidate the DBAPI connection held by this
-        :class:`._ConnectionRecord`.
-
-        This method is called for all connection invalidations, including
-        when the :meth:`._ConnectionFairy.invalidate` or
-        :meth:`_engine.Connection.invalidate` methods are called,
-        as well as when any
-        so-called "automatic invalidation" condition occurs.
-
-        :param e: an exception object indicating a reason for the
-          invalidation.
-
-        :param soft: if True, the connection isn't closed; instead, this
-          connection will be recycled on next checkout.
-
-         .. versionadded:: 1.0.3
-
-        .. seealso::
-
-            :ref:`pool_connection_invalidation`
-
-        """
+    def invalidate(
+        self, e: Optional[BaseException] = None, soft: bool = False
+    ) -> None:
         # already invalidated
         if self.dbapi_connection is None:
             return
@@ -595,7 +746,7 @@ class _ConnectionRecord:
             self.__close()
             self.dbapi_connection = None
 
-    def get_connection(self):
+    def get_connection(self) -> DBAPIConnection:
         recycle = False
 
         # NOTE: the various comparisons here are assuming that measurable time
@@ -610,8 +761,9 @@ class _ConnectionRecord:
         # within 16 milliseconds accuracy, so unit tests for connection
         # invalidation need a sleep of at least this long between initial start
         # time and invalidation for the logic below to work reliably.
+
         if self.dbapi_connection is None:
-            self.info.clear()
+            self.info.clear()  # type: ignore  # our info is always present
             self.__connect()
         elif (
             self.__pool._recycle > -1
@@ -639,26 +791,29 @@ class _ConnectionRecord:
 
         if recycle:
             self.__close()
-            self.info.clear()
+            self.info.clear()  # type: ignore  # our info is always present
 
             self.__connect()
+
+        assert self.dbapi_connection is not None
         return self.dbapi_connection
 
-    def _is_hard_or_soft_invalidated(self):
+    def _is_hard_or_soft_invalidated(self) -> bool:
         return (
             self.dbapi_connection is None
             or self.__pool._invalidate_time > self.starttime
             or (self._soft_invalidate_time > self.starttime)
         )
 
-    def __close(self):
+    def __close(self) -> None:
         self.finalize_callback.clear()
         if self.__pool.dispatch.close:
             self.__pool.dispatch.close(self.dbapi_connection, self)
+        assert self.dbapi_connection is not None
         self.__pool._close_connection(self.dbapi_connection)
         self.dbapi_connection = None
 
-    def __connect(self):
+    def __connect(self) -> None:
         pool = self.__pool
 
         # ensure any existing connection is removed, so that if
@@ -688,14 +843,16 @@ class _ConnectionRecord:
 
 
 def _finalize_fairy(
-    dbapi_connection,
-    connection_record,
-    pool,
-    ref,  # this is None when called directly, not by the gc
-    echo,
-    reset=True,
-    fairy=None,
-):
+    dbapi_connection: Optional[DBAPIConnection],
+    connection_record: Optional[_ConnectionRecord],
+    pool: Pool,
+    ref: Optional[
+        weakref.ref[_ConnectionFairy]
+    ],  # this is None when called directly, not by the gc
+    echo: Optional[log._EchoFlagType],
+    reset: bool = True,
+    fairy: Optional[_ConnectionFairy] = None,
+) -> None:
     """Cleanup for a :class:`._ConnectionFairy` whether or not it's already
     been garbage collected.
 
@@ -705,12 +862,16 @@ def _finalize_fairy(
     will only log a message and raise a warning.
     """
 
-    if ref:
+    is_gc_cleanup = ref is not None
+
+    if is_gc_cleanup:
+        assert ref is not None
         _strong_ref_connection_records.pop(ref, None)
     elif fairy:
         _strong_ref_connection_records.pop(weakref.ref(fairy), None)
 
-    if ref is not None:
+    if is_gc_cleanup:
+        assert connection_record is not None
         if connection_record.fairy_ref is not ref:
             return
         assert dbapi_connection is None
@@ -720,10 +881,10 @@ def _finalize_fairy(
     dont_restore_gced = pool._dialect.is_async
 
     if dont_restore_gced:
-        detach = not connection_record or ref
-        can_manipulate_connection = not ref
+        detach = connection_record is None or is_gc_cleanup
+        can_manipulate_connection = ref is None
     else:
-        detach = not connection_record
+        detach = connection_record is None
         can_manipulate_connection = True
 
     if dbapi_connection is not None:
@@ -737,11 +898,14 @@ def _finalize_fairy(
             )
 
         try:
-            fairy = fairy or _ConnectionFairy(
-                dbapi_connection,
-                connection_record,
-                echo,
-            )
+            if not fairy:
+                assert connection_record is not None
+                fairy = _ConnectionFairy(
+                    pool,
+                    dbapi_connection,
+                    connection_record,
+                    echo,
+                )
             assert fairy.dbapi_connection is dbapi_connection
             if reset and can_manipulate_connection:
                 fairy._reset(pool)
@@ -786,6 +950,7 @@ def _finalize_fairy(
     # test/engine/test_pool.py::PoolEventsTest::test_checkin_event_gc[True]
     # which actually started failing when pytest warnings plugin was
     # turned on, due to util.warn() above
+    fairy.dbapi_connection = fairy._connection_record = None  # type: ignore
     del dbapi_connection
     del connection_record
     del fairy
@@ -795,53 +960,36 @@ def _finalize_fairy(
 # GC under pypy will call ConnectionFairy finalizers.  linked directly to the
 # weakref that will empty itself when collected so that it should not create
 # any unmanaged memory references.
-_strong_ref_connection_records = {}
+_strong_ref_connection_records: Dict[
+    weakref.ref[_ConnectionFairy], _ConnectionRecord
+] = {}
 
 
-class PoolProxiedConnection:
-    """interface for the wrapper connection that is used by the connection
-    pool.
+class PoolProxiedConnection(ManagesConnection):
+    """A connection-like adapter for a :pep:`249` DBAPI connection, which
+    includes additional methods specific to the :class:`.Pool` implementation.
 
-    :class:`.PoolProxiedConnection` is basically the public-facing interface
-    for the :class:`._ConnectionFairy` implementation object, users familiar
-    with :class:`._ConnectionFairy` can consider this object to be
-    equivalent.
+    :class:`.PoolProxiedConnection` is the public-facing interface for the
+    internal :class:`._ConnectionFairy` implementation object; users familiar
+    with :class:`._ConnectionFairy` can consider this object to be equivalent.
 
-    .. versionadded:: 2.0
+    .. versionadded:: 2.0  :class:`.PoolProxiedConnection` provides the public-
+       facing interface for the :class:`._ConnectionFairy` internal class.
 
     """
 
     __slots__ = ()
 
-    @util.memoized_property
-    def dbapi_connection(self) -> "DBAPIConnection":
-        """A reference to the actual DBAPI connection being tracked.
+    if typing.TYPE_CHECKING:
 
-        .. seealso::
+        def commit(self) -> None:
+            ...
 
-            :attr:`.PoolProxiedConnection.driver_connection`
+        def cursor(self) -> DBAPICursor:
+            ...
 
-            :attr:`.PoolProxiedConnection.dbapi_connection`
-
-            :ref:`faq_dbapi_connection`
-
-        """
-        raise NotImplementedError()
-
-    @property
-    def driver_connection(self) -> Any:
-        """The connection object as returned by the driver after a connect.
-
-        .. seealso::
-
-            :attr:`.PoolProxiedConnection.dbapi_connection`
-
-            :attr:`._ConnectionRecord.driver_connection`
-
-            :ref:`faq_dbapi_connection`
-
-        """
-        raise NotImplementedError()
+        def rollback(self) -> None:
+            ...
 
     @property
     def is_valid(self) -> bool:
@@ -850,62 +998,11 @@ class PoolProxiedConnection:
 
         raise NotImplementedError()
 
-    @util.memoized_property
-    def info(self) -> Dict[str, Any]:
-        """Info dictionary associated with the underlying DBAPI connection
-        referred to by this :class:`.ConnectionFairy`, allowing user-defined
-        data to be associated with the connection.
-
-        The data here will follow along with the DBAPI connection including
-        after it is returned to the connection pool and used again
-        in subsequent instances of :class:`._ConnectionFairy`.  It is shared
-        with the :attr:`._ConnectionRecord.info` and
-        :attr:`_engine.Connection.info`
-        accessors.
-
-        The dictionary associated with a particular DBAPI connection is
-        discarded when the connection itself is discarded.
-
-        """
-
-        raise NotImplementedError()
-
     @property
-    def record_info(self) -> Dict[str, Any]:
-        """Info dictionary associated with the :class:`._ConnectionRecord
-        container referred to by this :class:`.PoolProxiedConnection`.
-
-        Unlike the :attr:`.PoolProxiedConnection.info` dictionary, the lifespan
-        of this dictionary is persistent across connections that are
-        disconnected and/or invalidated within the lifespan of a
-        :class:`._ConnectionRecord`.
-
-        """
-
-        raise NotImplementedError()
+    def is_detached(self) -> bool:
+        """Return True if this :class:`.PoolProxiedConnection` is detached
+        from its pool."""
 
-    def invalidate(
-        self, e: Optional[Exception] = None, soft: bool = False
-    ) -> None:
-        """Mark this connection as invalidated.
-
-        This method can be called directly, and is also called as a result
-        of the :meth:`_engine.Connection.invalidate` method.   When invoked,
-        the DBAPI connection is immediately closed and discarded from
-        further use by the pool.  The invalidation mechanism proceeds
-        via the :meth:`._ConnectionRecord.invalidate` internal method.
-
-        :param e: an exception object indicating a reason for the invalidation.
-
-        :param soft: if True, the connection isn't closed; instead, this
-         connection will be recycled on next checkout.
-
-        .. seealso::
-
-            :ref:`pool_connection_invalidation`
-
-
-        """
         raise NotImplementedError()
 
     def detach(self) -> None:
@@ -913,8 +1010,8 @@ class PoolProxiedConnection:
 
         This means that the connection will no longer be returned to the
         pool when closed, and will instead be literally closed.  The
-        containing ConnectionRecord is separated from the DB-API connection,
-        and will create a new connection when next used.
+        associated :class:`.ConnectionPoolEntry` is de-associated from this
+        DBAPI connection.
 
         Note that any overall connection limiting constraints imposed by a
         Pool implementation may be violated after a detach, as the detached
@@ -953,43 +1050,37 @@ class _AdhocProxiedConnection(PoolProxiedConnection):
 
     __slots__ = ("dbapi_connection", "_connection_record")
 
-    def __init__(self, dbapi_connection, connection_record):
+    dbapi_connection: DBAPIConnection
+    _connection_record: ConnectionPoolEntry
+
+    def __init__(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+    ):
         self.dbapi_connection = dbapi_connection
         self._connection_record = connection_record
 
     @property
-    def driver_connection(self):
+    def driver_connection(self) -> Any:
         return self._connection_record.driver_connection
 
     @property
-    def connection(self):
-        """An alias to :attr:`._ConnectionFairy.dbapi_connection`.
-
-        This alias is deprecated, please use the new name.
-
-        .. deprecated:: 1.4.24
-
-        """
-        return self._dbapi_connection
+    def connection(self) -> DBAPIConnection:
+        return self.dbapi_connection
 
     @property
-    def is_valid(self):
+    def is_valid(self) -> bool:
         raise AttributeError("is_valid not implemented by this proxy")
 
-    @property
-    def record_info(self):
+    @util.dynamic_property
+    def record_info(self) -> Optional[Dict[str, Any]]:
         return self._connection_record.record_info
 
-    def cursor(self, *args, **kwargs):
-        """Return a new DBAPI cursor for the underlying connection.
-
-        This method is a proxy for the ``connection.cursor()`` DBAPI
-        method.
-
-        """
+    def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor:
         return self.dbapi_connection.cursor(*args, **kwargs)
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: Any) -> Any:
         return getattr(self.dbapi_connection, key)
 
 
@@ -1001,7 +1092,8 @@ class _ConnectionFairy(PoolProxiedConnection):
     This is an internal object used by the :class:`_pool.Pool` implementation
     to provide context management to a DBAPI connection delivered by
     that :class:`_pool.Pool`.   The public facing interface for this class
-    is described by the :class:`.PoolProxiedConnection` class.
+    is described by the :class:`.PoolProxiedConnection` class.  See that
+    class for public API details.
 
     The name "fairy" is inspired by the fact that the
     :class:`._ConnectionFairy` object's lifespan is transitory, as it lasts
@@ -1011,68 +1103,76 @@ class _ConnectionFairy(PoolProxiedConnection):
 
     .. seealso::
 
-        :class:`._ConnectionRecord`
-
-    """
+        :class:`.PoolProxiedConnection`
 
-    def __init__(self, dbapi_connection, connection_record, echo):
-        self.dbapi_connection = dbapi_connection
-        self._connection_record = connection_record
-        self._echo = echo
+        :class:`.ConnectionPoolEntry`
 
-    _connection_record = None
-    """A reference to the :class:`._ConnectionRecord` object associated
-    with the DBAPI connection.
-
-    This is currently an internal accessor which is subject to change.
 
     """
 
-    @property
-    def driver_connection(self):
-        """The connection object as returned by the driver after a connect.
+    __slots__ = (
+        "dbapi_connection",
+        "_connection_record",
+        "_echo",
+        "_pool",
+        "_counter",
+        "__weakref__",
+        "__dict__",
+    )
 
-        .. versionadded:: 1.4.24
-
-        .. seealso::
+    pool: Pool
+    dbapi_connection: DBAPIConnection
+    _echo: log._EchoFlagType
 
-            :attr:`._ConnectionFairy.dbapi_connection`
-
-            :attr:`._ConnectionRecord.driver_connection`
+    def __init__(
+        self,
+        pool: Pool,
+        dbapi_connection: DBAPIConnection,
+        connection_record: _ConnectionRecord,
+        echo: log._EchoFlagType,
+    ):
+        self._pool = pool
+        self._counter = 0
+        self.dbapi_connection = dbapi_connection
+        self._connection_record = connection_record
+        self._echo = echo
 
-            :ref:`faq_dbapi_connection`
+    _connection_record: Optional[_ConnectionRecord]
 
-        """
+    @property
+    def driver_connection(self) -> Optional[Any]:
+        if self._connection_record is None:
+            return None
         return self._connection_record.driver_connection
 
     @property
-    def connection(self):
-        """An alias to :attr:`._ConnectionFairy.dbapi_connection`.
-
-        This alias is deprecated, please use the new name.
-
-        .. deprecated:: 1.4.24
-
-        """
+    def connection(self) -> DBAPIConnection:
         return self.dbapi_connection
 
     @connection.setter
-    def connection(self, value):
+    def connection(self, value: DBAPIConnection) -> None:
         self.dbapi_connection = value
 
     @classmethod
-    def _checkout(cls, pool, threadconns=None, fairy=None):
+    def _checkout(
+        cls,
+        pool: Pool,
+        threadconns: Optional[threading.local] = None,
+        fairy: Optional[_ConnectionFairy] = None,
+    ) -> _ConnectionFairy:
         if not fairy:
             fairy = _ConnectionRecord.checkout(pool)
 
-            fairy._pool = pool
-            fairy._counter = 0
-
             if threadconns is not None:
                 threadconns.current = weakref.ref(fairy)
 
-        if fairy.dbapi_connection is None:
-            raise exc.InvalidRequestError("This connection is closed")
+        assert (
+            fairy._connection_record is not None
+        ), "can't 'checkout' a detached connection fairy"
+        assert (
+            fairy.dbapi_connection is not None
+        ), "can't 'checkout' an invalidated connection fairy"
+
         fairy._counter += 1
         if (
             not pool.dispatch.checkout and not pool._pre_ping
@@ -1084,6 +1184,7 @@ class _ConnectionFairy(PoolProxiedConnection):
         # there are three attempts made here, but note that if the database
         # is not accessible from a connection standpoint, those won't proceed
         # here.
+
         attempts = 2
         while attempts > 0:
             connection_is_fresh = fairy._connection_record.fresh
@@ -1160,10 +1261,10 @@ class _ConnectionFairy(PoolProxiedConnection):
         fairy.invalidate()
         raise exc.InvalidRequestError("This connection is closed")
 
-    def _checkout_existing(self):
+    def _checkout_existing(self) -> _ConnectionFairy:
         return _ConnectionFairy._checkout(self._pool, fairy=self)
 
-    def _checkin(self, reset=True):
+    def _checkin(self, reset: bool = True) -> None:
         _finalize_fairy(
             self.dbapi_connection,
             self._connection_record,
@@ -1173,14 +1274,13 @@ class _ConnectionFairy(PoolProxiedConnection):
             reset=reset,
             fairy=self,
         )
-        self.dbapi_connection = None
-        self._connection_record = None
 
-    _close = _checkin
+    def _close(self) -> None:
+        self._checkin()
 
-    def _reset(self, pool):
+    def _reset(self, pool: Pool) -> None:
         if pool.dispatch.reset:
-            pool.dispatch.reset(self, self._connection_record)
+            pool.dispatch.reset(self.dbapi_connection, self._connection_record)
         if pool._reset_on_return is reset_rollback:
             if self._echo:
                 pool.logger.debug(
@@ -1196,50 +1296,34 @@ class _ConnectionFairy(PoolProxiedConnection):
             pool._dialect.do_commit(self)
 
     @property
-    def _logger(self):
+    def _logger(self) -> log._IdentifiedLoggerType:
         return self._pool.logger
 
     @property
-    def is_valid(self):
-        """Return True if this :class:`._ConnectionFairy` still refers
-        to an active DBAPI connection."""
-
+    def is_valid(self) -> bool:
         return self.dbapi_connection is not None
 
-    @util.memoized_property
-    def info(self):
-        """Info dictionary associated with the underlying DBAPI connection
-        referred to by this :class:`.ConnectionFairy`, allowing user-defined
-        data to be associated with the connection.
-
-        See :attr:`.PoolProxiedConnection.info` for full description.
-
-        """
-        return self._connection_record.info
-
     @property
-    def record_info(self):
-        """Info dictionary associated with the :class:`._ConnectionRecord
-        container referred to by this :class:`.ConnectionFairy`.
+    def is_detached(self) -> bool:
+        return self._connection_record is not None
 
-        See :attr:`.PoolProxiedConnection.record_info` for full description.
-
-        """
-        if self._connection_record:
-            return self._connection_record.record_info
+    @util.memoized_property
+    def info(self) -> Dict[str, Any]:
+        if self._connection_record is None:
+            return {}
         else:
-            return None
-
-    def invalidate(self, e=None, soft=False):
-        """Mark this connection as invalidated.
-
-        See :attr:`.PoolProxiedConnection.invalidate` for full description.
-
-        .. seealso::
+            return self._connection_record.info
 
-            :ref:`pool_connection_invalidation`
+    @util.dynamic_property
+    def record_info(self) -> Optional[Dict[str, Any]]:
+        if self._connection_record is None:
+            return None
+        else:
+            return self._connection_record.record_info
 
-        """
+    def invalidate(
+        self, e: Optional[BaseException] = None, soft: bool = False
+    ) -> None:
 
         if self.dbapi_connection is None:
             util.warn("Can't invalidate an already-closed connection.")
@@ -1247,51 +1331,43 @@ class _ConnectionFairy(PoolProxiedConnection):
         if self._connection_record:
             self._connection_record.invalidate(e=e, soft=soft)
         if not soft:
-            self.dbapi_connection = None
-            self._checkin()
-
-    def cursor(self, *args, **kwargs):
-        """Return a new DBAPI cursor for the underlying connection.
+            # prevent any rollback / reset actions etc. on
+            # the connection
+            self.dbapi_connection = None  # type: ignore
 
-        This method is a proxy for the ``connection.cursor()`` DBAPI
-        method.
+            # finalize
+            self._checkin()
 
-        """
+    def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor:
+        assert self.dbapi_connection is not None
         return self.dbapi_connection.cursor(*args, **kwargs)
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         return getattr(self.dbapi_connection, key)
 
-    def detach(self):
-        """Separate this connection from its Pool.
-
-        See :meth:`.PoolProxiedConnection.detach` for full description.
-
-        """
-
+    def detach(self) -> None:
         if self._connection_record is not None:
             rec = self._connection_record
             rec.fairy_ref = None
             rec.dbapi_connection = None
             # TODO: should this be _return_conn?
             self._pool._do_return_conn(self._connection_record)
-            self.info = self.info.copy()
+
+            # can't get the descriptor assignment to work here
+            # in pylance.  mypy is OK w/ it
+            self.info = self.info.copy()  # type: ignore
+
             self._connection_record = None
 
             if self._pool.dispatch.detach:
                 self._pool.dispatch.detach(self.dbapi_connection, rec)
 
-    def close(self):
-        """Release this connection back to the pool.
-
-        See :meth:`.PoolProxiedConnection.close` for full description.
-
-        """
+    def close(self) -> None:
         self._counter -= 1
         if self._counter == 0:
             self._checkin()
 
-    def _close_no_reset(self):
+    def _close_no_reset(self) -> None:
         self._counter -= 1
         if self._counter == 0:
             self._checkin(reset=False)
index e53d614b0ebd25453f43ea78b0041fa6b9e55650..d0d89291bc1ed76e95c96342d19cef93c8e54659 100644 (file)
@@ -4,13 +4,26 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
 
+import typing
+from typing import Any
+from typing import Optional
+from typing import Type
+from typing import Union
+
+from .base import ConnectionPoolEntry
 from .base import Pool
+from .base import PoolProxiedConnection
 from .. import event
 from .. import util
 
+if typing.TYPE_CHECKING:
+    from ..engine import Engine
+    from ..engine.interfaces import DBAPIConnection
+
 
-class PoolEvents(event.Events):
+class PoolEvents(event.Events[Pool]):
     """Available events for :class:`_pool.Pool`.
 
     The methods here define the name of an event as well
@@ -37,35 +50,48 @@ class PoolEvents(event.Events):
         # will associate with engine.pool
         event.listen(engine, 'checkout', my_on_checkout)
 
-    """  # noqa
+    """  # noqa E501
 
     _target_class_doc = "SomeEngineOrPool"
     _dispatch_target = Pool
 
     @util.preload_module("sqlalchemy.engine")
     @classmethod
-    def _accept_with(cls, target):
-        Engine = util.preloaded.engine.Engine
+    def _accept_with(
+        cls, target: Union[Pool, Type[Pool], Engine, Type[Engine]]
+    ) -> Union[Pool, Type[Pool]]:
+        if not typing.TYPE_CHECKING:
+            Engine = util.preloaded.engine.Engine
 
         if isinstance(target, type):
             if issubclass(target, Engine):
                 return Pool
-            elif issubclass(target, Pool):
+            else:
+                assert issubclass(target, Pool)
                 return target
         elif isinstance(target, Engine):
             return target.pool
         else:
+            assert isinstance(target, Pool)
             return target
 
     @classmethod
-    def _listen(cls, event_key, **kw):
+    def _listen(  # type: ignore[override]   # would rather keep **kw
+        cls,
+        event_key: event._EventKey[Pool],
+        **kw: Any,
+    ) -> None:
         target = event_key.dispatch_target
 
         kw.setdefault("asyncio", target._is_asyncio)
 
         event_key.base_listen(**kw)
 
-    def connect(self, dbapi_connection, connection_record):
+    def connect(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+    ) -> None:
         """Called at the moment a particular DBAPI connection is first
         created for a given :class:`_pool.Pool`.
 
@@ -74,14 +100,18 @@ class PoolEvents(event.Events):
         to produce a new DBAPI connection.
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         """
 
-    def first_connect(self, dbapi_connection, connection_record):
+    def first_connect(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+    ) -> None:
         """Called exactly once for the first time a DBAPI connection is
         checked out from a particular :class:`_pool.Pool`.
 
@@ -99,24 +129,29 @@ class PoolEvents(event.Events):
         encoding settings, collation settings, and many others.
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         """
 
-    def checkout(self, dbapi_connection, connection_record, connection_proxy):
+    def checkout(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+        connection_proxy: PoolProxiedConnection,
+    ) -> None:
         """Called when a connection is retrieved from the Pool.
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
-        :param connection_proxy: the :class:`._ConnectionFairy` object which
-          will proxy the public interface of the DBAPI connection for the
+        :param connection_proxy: the :class:`.PoolProxiedConnection` object
+          which will proxy the public interface of the DBAPI connection for the
           lifespan of the checkout.
 
         If you raise a :class:`~sqlalchemy.exc.DisconnectionError`, the current
@@ -130,7 +165,11 @@ class PoolEvents(event.Events):
 
         """
 
-    def checkin(self, dbapi_connection, connection_record):
+    def checkin(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+    ) -> None:
         """Called when a connection returns to the pool.
 
         Note that the connection may be closed, and may be None if the
@@ -138,14 +177,18 @@ class PoolEvents(event.Events):
         for detached connections.  (They do not return to the pool.)
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         """
 
-    def reset(self, dbapi_connection, connection_record):
+    def reset(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+    ) -> None:
         """Called before the "reset" action occurs for a pooled connection.
 
         This event represents
@@ -160,10 +203,10 @@ class PoolEvents(event.Events):
         cases where the connection is discarded immediately after reset.
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         .. seealso::
 
@@ -173,21 +216,26 @@ class PoolEvents(event.Events):
 
         """
 
-    def invalidate(self, dbapi_connection, connection_record, exception):
+    def invalidate(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+        exception: Optional[BaseException],
+    ) -> None:
         """Called when a DBAPI connection is to be "invalidated".
 
-        This event is called any time the :meth:`._ConnectionRecord.invalidate`
-        method is invoked, either from API usage or via "auto-invalidation",
-        without the ``soft`` flag.
+        This event is called any time the
+        :meth:`.ConnectionPoolEntry.invalidate` method is invoked, either from
+        API usage or via "auto-invalidation", without the ``soft`` flag.
 
         The event occurs before a final attempt to call ``.close()`` on the
         connection occurs.
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         :param exception: the exception object corresponding to the reason
          for this invalidation, if any.  May be ``None``.
@@ -201,10 +249,16 @@ class PoolEvents(event.Events):
 
         """
 
-    def soft_invalidate(self, dbapi_connection, connection_record, exception):
+    def soft_invalidate(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+        exception: Optional[BaseException],
+    ) -> None:
         """Called when a DBAPI connection is to be "soft invalidated".
 
-        This event is called any time the :meth:`._ConnectionRecord.invalidate`
+        This event is called any time the
+        :meth:`.ConnectionPoolEntry.invalidate`
         method is invoked with the ``soft`` flag.
 
         Soft invalidation refers to when the connection record that tracks
@@ -215,17 +269,21 @@ class PoolEvents(event.Events):
         .. versionadded:: 1.0.3
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         :param exception: the exception object corresponding to the reason
          for this invalidation, if any.  May be ``None``.
 
         """
 
-    def close(self, dbapi_connection, connection_record):
+    def close(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+    ) -> None:
         """Called when a DBAPI connection is closed.
 
         The event is emitted before the close occurs.
@@ -241,14 +299,18 @@ class PoolEvents(event.Events):
         .. versionadded:: 1.1
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         """
 
-    def detach(self, dbapi_connection, connection_record):
+    def detach(
+        self,
+        dbapi_connection: DBAPIConnection,
+        connection_record: ConnectionPoolEntry,
+    ) -> None:
         """Called when a DBAPI connection is "detached" from a pool.
 
         This event is emitted after the detach occurs.  The connection
@@ -257,14 +319,14 @@ class PoolEvents(event.Events):
         .. versionadded:: 1.1
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
-        :param connection_record: the :class:`._ConnectionRecord` managing the
-         DBAPI connection.
+        :param connection_record: the :class:`.ConnectionPoolEntry` managing
+         the DBAPI connection.
 
         """
 
-    def close_detached(self, dbapi_connection):
+    def close_detached(self, dbapi_connection: DBAPIConnection) -> None:
         """Called when a detached DBAPI connection is closed.
 
         The event is emitted before the close occurs.
@@ -276,6 +338,6 @@ class PoolEvents(event.Events):
         .. versionadded:: 1.1
 
         :param dbapi_connection: a DBAPI connection.
-         The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+         The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute.
 
         """
index 7a422cd2ac97a0dedf80e028f55a0487b5cc28b2..d1be3f54195ddf4169017065c3bb17439852c3a0 100644 (file)
@@ -9,19 +9,36 @@
 """Pool implementation classes.
 
 """
+from __future__ import annotations
 
 import threading
 import traceback
+import typing
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Type
+from typing import Union
 import weakref
 
 from .base import _AsyncConnDialect
 from .base import _ConnectionFairy
 from .base import _ConnectionRecord
+from .base import _CreatorFnType
+from .base import _CreatorWRecFnType
+from .base import ConnectionPoolEntry
 from .base import Pool
+from .base import PoolProxiedConnection
 from .. import exc
 from .. import util
 from ..util import chop_traceback
 from ..util import queue as sqla_queue
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+    from ..engine.interfaces import DBAPIConnection
 
 
 class QueuePool(Pool):
@@ -34,17 +51,22 @@ class QueuePool(Pool):
 
     """
 
-    _is_asyncio = False
-    _queue_class = sqla_queue.Queue
+    _is_asyncio = False  # type: ignore[assignment]
+
+    _queue_class: Type[
+        sqla_queue.QueueCommon[ConnectionPoolEntry]
+    ] = sqla_queue.Queue
+
+    _pool: sqla_queue.QueueCommon[ConnectionPoolEntry]
 
     def __init__(
         self,
-        creator,
-        pool_size=5,
-        max_overflow=10,
-        timeout=30.0,
-        use_lifo=False,
-        **kw,
+        creator: Union[_CreatorFnType, _CreatorWRecFnType],
+        pool_size: int = 5,
+        max_overflow: int = 10,
+        timeout: float = 30.0,
+        use_lifo: bool = False,
+        **kw: Any,
     ):
         r"""
         Construct a QueuePool.
@@ -107,20 +129,20 @@ class QueuePool(Pool):
         self._timeout = timeout
         self._overflow_lock = threading.Lock()
 
-    def _do_return_conn(self, conn):
+    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
         try:
-            self._pool.put(conn, False)
+            self._pool.put(record, False)
         except sqla_queue.Full:
             try:
-                conn.close()
+                record.close()
             finally:
                 self._dec_overflow()
 
-    def _do_get(self):
+    def _do_get(self) -> ConnectionPoolEntry:
         use_overflow = self._max_overflow > -1
 
+        wait = use_overflow and self._overflow >= self._max_overflow
         try:
-            wait = use_overflow and self._overflow >= self._max_overflow
             return self._pool.get(wait, self._timeout)
         except sqla_queue.Empty:
             # don't do things inside of "except Empty", because when we say
@@ -144,10 +166,11 @@ class QueuePool(Pool):
             except:
                 with util.safe_reraise():
                     self._dec_overflow()
+                raise
         else:
             return self._do_get()
 
-    def _inc_overflow(self):
+    def _inc_overflow(self) -> bool:
         if self._max_overflow == -1:
             self._overflow += 1
             return True
@@ -158,7 +181,7 @@ class QueuePool(Pool):
             else:
                 return False
 
-    def _dec_overflow(self):
+    def _dec_overflow(self) -> Literal[True]:
         if self._max_overflow == -1:
             self._overflow -= 1
             return True
@@ -166,7 +189,7 @@ class QueuePool(Pool):
             self._overflow -= 1
             return True
 
-    def recreate(self):
+    def recreate(self) -> QueuePool:
         self.logger.info("Pool recreating")
         return self.__class__(
             self._creator,
@@ -183,7 +206,7 @@ class QueuePool(Pool):
             dialect=self._dialect,
         )
 
-    def dispose(self):
+    def dispose(self) -> None:
         while True:
             try:
                 conn = self._pool.get(False)
@@ -194,7 +217,7 @@ class QueuePool(Pool):
         self._overflow = 0 - self.size()
         self.logger.info("Pool disposed. %s", self.status())
 
-    def status(self):
+    def status(self) -> str:
         return (
             "Pool size: %d  Connections in pool: %d "
             "Current Overflow: %d Current Checked out "
@@ -207,25 +230,28 @@ class QueuePool(Pool):
             )
         )
 
-    def size(self):
+    def size(self) -> int:
         return self._pool.maxsize
 
-    def timeout(self):
+    def timeout(self) -> float:
         return self._timeout
 
-    def checkedin(self):
+    def checkedin(self) -> int:
         return self._pool.qsize()
 
-    def overflow(self):
+    def overflow(self) -> int:
         return self._overflow
 
-    def checkedout(self):
+    def checkedout(self) -> int:
         return self._pool.maxsize - self._pool.qsize() + self._overflow
 
 
 class AsyncAdaptedQueuePool(QueuePool):
-    _is_asyncio = True
-    _queue_class = sqla_queue.AsyncAdaptedQueue
+    _is_asyncio = True  # type: ignore[assignment]
+    _queue_class: Type[
+        sqla_queue.QueueCommon[ConnectionPoolEntry]
+    ] = sqla_queue.AsyncAdaptedQueue
+
     _dialect = _AsyncConnDialect()
 
 
@@ -246,16 +272,16 @@ class NullPool(Pool):
 
     """
 
-    def status(self):
+    def status(self) -> str:
         return "NullPool"
 
-    def _do_return_conn(self, conn):
-        conn.close()
+    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
+        record.close()
 
-    def _do_get(self):
+    def _do_get(self) -> ConnectionPoolEntry:
         return self._create_connection()
 
-    def recreate(self):
+    def recreate(self) -> NullPool:
         self.logger.info("Pool recreating")
 
         return self.__class__(
@@ -269,7 +295,7 @@ class NullPool(Pool):
             dialect=self._dialect,
         )
 
-    def dispose(self):
+    def dispose(self) -> None:
         pass
 
 
@@ -304,16 +330,21 @@ class SingletonThreadPool(Pool):
 
     """
 
-    _is_asyncio = False
+    _is_asyncio = False  # type: ignore[assignment]
 
-    def __init__(self, creator, pool_size=5, **kw):
+    def __init__(
+        self,
+        creator: Union[_CreatorFnType, _CreatorWRecFnType],
+        pool_size: int = 5,
+        **kw: Any,
+    ):
         Pool.__init__(self, creator, **kw)
         self._conn = threading.local()
         self._fairy = threading.local()
-        self._all_conns = set()
+        self._all_conns: Set[ConnectionPoolEntry] = set()
         self.size = pool_size
 
-    def recreate(self):
+    def recreate(self) -> SingletonThreadPool:
         self.logger.info("Pool recreating")
         return self.__class__(
             self._creator,
@@ -327,7 +358,7 @@ class SingletonThreadPool(Pool):
             dialect=self._dialect,
         )
 
-    def dispose(self):
+    def dispose(self) -> None:
         """Dispose of this pool."""
 
         for conn in self._all_conns:
@@ -340,23 +371,26 @@ class SingletonThreadPool(Pool):
 
         self._all_conns.clear()
 
-    def _cleanup(self):
+    def _cleanup(self) -> None:
         while len(self._all_conns) >= self.size:
             c = self._all_conns.pop()
             c.close()
 
-    def status(self):
+    def status(self) -> str:
         return "SingletonThreadPool id:%d size: %d" % (
             id(self),
             len(self._all_conns),
         )
 
-    def _do_return_conn(self, conn):
-        pass
+    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
+        try:
+            del self._fairy.current  # type: ignore
+        except AttributeError:
+            pass
 
-    def _do_get(self):
+    def _do_get(self) -> ConnectionPoolEntry:
         try:
-            c = self._conn.current()
+            c = cast(ConnectionPoolEntry, self._conn.current())
             if c:
                 return c
         except AttributeError:
@@ -368,11 +402,11 @@ class SingletonThreadPool(Pool):
         self._all_conns.add(c)
         return c
 
-    def connect(self):
+    def connect(self) -> PoolProxiedConnection:
         # vendored from Pool to include the now removed use_threadlocal
         # behavior
         try:
-            rec = self._fairy.current()
+            rec = cast(_ConnectionFairy, self._fairy.current())
         except AttributeError:
             pass
         else:
@@ -381,13 +415,6 @@ class SingletonThreadPool(Pool):
 
         return _ConnectionFairy._checkout(self, self._fairy)
 
-    def _return_conn(self, record):
-        try:
-            del self._fairy.current
-        except AttributeError:
-            pass
-        self._do_return_conn(record)
-
 
 class StaticPool(Pool):
 
@@ -401,13 +428,13 @@ class StaticPool(Pool):
     """
 
     @util.memoized_property
-    def connection(self):
+    def connection(self) -> _ConnectionRecord:
         return _ConnectionRecord(self)
 
-    def status(self):
+    def status(self) -> str:
         return "StaticPool"
 
-    def dispose(self):
+    def dispose(self) -> None:
         if (
             "connection" in self.__dict__
             and self.connection.dbapi_connection is not None
@@ -415,7 +442,7 @@ class StaticPool(Pool):
             self.connection.close()
             del self.__dict__["connection"]
 
-    def recreate(self):
+    def recreate(self) -> StaticPool:
         self.logger.info("Pool recreating")
         return self.__class__(
             creator=self._creator,
@@ -428,20 +455,23 @@ class StaticPool(Pool):
             dialect=self._dialect,
         )
 
-    def _transfer_from(self, other_static_pool):
+    def _transfer_from(self, other_static_pool: StaticPool) -> None:
         # used by the test suite to make a new engine / pool without
         # losing the state of an existing SQLite :memory: connection
-        self._invoke_creator = (
-            lambda crec: other_static_pool.connection.dbapi_connection
-        )
+        def creator(rec: ConnectionPoolEntry) -> DBAPIConnection:
+            conn = other_static_pool.connection.dbapi_connection
+            assert conn is not None
+            return conn
 
-    def _create_connection(self):
+        self._invoke_creator = creator
+
+    def _create_connection(self) -> ConnectionPoolEntry:
         raise NotImplementedError()
 
-    def _do_return_conn(self, conn):
+    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
         pass
 
-    def _do_get(self):
+    def _do_get(self) -> ConnectionPoolEntry:
         rec = self.connection
         if rec._is_hard_or_soft_invalidated():
             del self.__dict__["connection"]
@@ -461,28 +491,31 @@ class AssertionPool(Pool):
 
     """
 
-    def __init__(self, *args, **kw):
+    _conn: Optional[ConnectionPoolEntry]
+    _checkout_traceback: Optional[List[str]]
+
+    def __init__(self, *args: Any, **kw: Any):
         self._conn = None
         self._checked_out = False
         self._store_traceback = kw.pop("store_traceback", True)
         self._checkout_traceback = None
         Pool.__init__(self, *args, **kw)
 
-    def status(self):
+    def status(self) -> str:
         return "AssertionPool"
 
-    def _do_return_conn(self, conn):
+    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
         if not self._checked_out:
             raise AssertionError("connection is not checked out")
         self._checked_out = False
-        assert conn is self._conn
+        assert record is self._conn
 
-    def dispose(self):
+    def dispose(self) -> None:
         self._checked_out = False
         if self._conn:
             self._conn.close()
 
-    def recreate(self):
+    def recreate(self) -> AssertionPool:
         self.logger.info("Pool recreating")
         return self.__class__(
             self._creator,
@@ -495,7 +528,7 @@ class AssertionPool(Pool):
             dialect=self._dialect,
         )
 
-    def _do_get(self):
+    def _do_get(self) -> ConnectionPoolEntry:
         if self._checked_out:
             if self._checkout_traceback:
                 suffix = " at:\n%s" % "".join(
index 85bbca20f5bb43925d47aa6c194023a323ee0f1f..a41420504567f4ef38813fd844eb93ce9c8e5c71 100644 (file)
@@ -94,6 +94,7 @@ from .langhelpers import decode_slice as decode_slice
 from .langhelpers import decorator as decorator
 from .langhelpers import dictlike_iteritems as dictlike_iteritems
 from .langhelpers import duck_type_collection as duck_type_collection
+from .langhelpers import dynamic_property as dynamic_property
 from .langhelpers import ellipses_string as ellipses_string
 from .langhelpers import EnsureKWArg as EnsureKWArg
 from .langhelpers import format_argspec_init as format_argspec_init
@@ -122,6 +123,9 @@ from .langhelpers import (
 )
 from .langhelpers import NoneType as NoneType
 from .langhelpers import only_once as only_once
+from .langhelpers import (
+    parse_user_argument_for_enum as parse_user_argument_for_enum,
+)
 from .langhelpers import PluginLoader as PluginLoader
 from .langhelpers import portable_instancemethod as portable_instancemethod
 from .langhelpers import quoted_token_parser as quoted_token_parser
index fa206670271f93491622655698053bdbdceb298a..b17b408dd7bf972cd3296ce44847291583587f7e 100644 (file)
@@ -10,13 +10,37 @@ from contextvars import copy_context as _copy_context
 import sys
 import typing
 from typing import Any
+from typing import Awaitable
 from typing import Callable
 from typing import Coroutine
-
-import greenlet  # type: ignore # noqa
+from typing import TypeVar
 
 from .langhelpers import memoized_property
 from .. import exc
+from ..util.typing import Protocol
+
+if typing.TYPE_CHECKING:
+
+    class greenlet(Protocol):
+
+        dead: bool
+
+        def __init__(self, fn: Callable[..., Any], driver: "greenlet"):
+            ...
+
+        def throw(self, *arg: Any) -> Any:
+            ...
+
+        def switch(self, value: Any) -> Any:
+            ...
+
+    def getcurrent() -> greenlet:
+        ...
+
+else:
+    from greenlet import getcurrent
+    from greenlet import greenlet
+
 
 if not typing.TYPE_CHECKING:
     try:
@@ -24,12 +48,14 @@ if not typing.TYPE_CHECKING:
         # If greenlet.gr_context is present in current version of greenlet,
         # it will be set with a copy of the current context on creation.
         # Refs: https://github.com/python-greenlet/greenlet/pull/198
-        getattr(greenlet.greenlet, "gr_context")
+        getattr(greenlet, "gr_context")
     except (ImportError, AttributeError):
         _copy_context = None  # noqa
 
+_T = TypeVar("_T", bound=Any)
 
-def is_exit_exception(e):
+
+def is_exit_exception(e: BaseException) -> bool:
     # note asyncio.CancelledError is already BaseException
     # so was an exit exception in any case
     return not isinstance(e, Exception) or isinstance(
@@ -42,15 +68,17 @@ def is_exit_exception(e):
 # Issue for context: https://github.com/python-greenlet/greenlet/issues/173
 
 
-class _AsyncIoGreenlet(greenlet.greenlet):  # type: ignore
-    def __init__(self, fn, driver):
-        greenlet.greenlet.__init__(self, fn, driver)
+class _AsyncIoGreenlet(greenlet):  # type: ignore
+    dead: bool
+
+    def __init__(self, fn: Callable[..., Any], driver: greenlet):
+        greenlet.__init__(self, fn, driver)
         self.driver = driver
         if _copy_context is not None:
             self.gr_context = _copy_context()
 
 
-def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any:
+def await_only(awaitable: Awaitable[_T]) -> _T:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
@@ -60,7 +88,7 @@ def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any:
 
     """
     # this is called in the context greenlet while running fn
-    current = greenlet.getcurrent()
+    current = getcurrent()
     if not isinstance(current, _AsyncIoGreenlet):
         raise exc.MissingGreenlet(
             "greenlet_spawn has not been called; can't call await_() here. "
@@ -71,10 +99,10 @@ def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any:
     # a coroutine to run. Once the awaitable is done, the driver greenlet
     # switches back to this greenlet with the result of awaitable that is
     # then returned to the caller (or raised as error)
-    return current.driver.switch(awaitable)
+    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
 
 
-def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any:
+def await_fallback(awaitable: Awaitable[_T]) -> _T:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
@@ -83,8 +111,9 @@ def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any:
     :param awaitable: The coroutine to call.
 
     """
+
     # this is called in the context greenlet while running fn
-    current = greenlet.getcurrent()
+    current = getcurrent()
     if not isinstance(current, _AsyncIoGreenlet):
         loop = get_event_loop()
         if loop.is_running():
@@ -93,9 +122,9 @@ def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any:
                 "loop is already running; can't call await_() here. "
                 "Was IO attempted in an unexpected place?"
             )
-        return loop.run_until_complete(awaitable)
+        return loop.run_until_complete(awaitable)  # type: ignore[no-any-return]  # noqa E501
 
-    return current.driver.switch(awaitable)
+    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
 
 
 async def greenlet_spawn(
@@ -114,7 +143,7 @@ async def greenlet_spawn(
     :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
     """
 
-    context = _AsyncIoGreenlet(fn, greenlet.getcurrent())
+    context = _AsyncIoGreenlet(fn, getcurrent())
     # runs the function synchronously in gl greenlet. If the execution
     # is interrupted by await_, context is not dead and result is a
     # coroutine to wait. If the context is dead the function has
@@ -149,21 +178,23 @@ async def greenlet_spawn(
 
 class AsyncAdaptedLock:
     @memoized_property
-    def mutex(self):
+    def mutex(self) -> asyncio.Lock:
         # there should not be a race here for coroutines creating the
         # new lock as we are not using await, so therefore no concurrency
         return asyncio.Lock()
 
-    def __enter__(self):
+    def __enter__(self) -> bool:
         # await is used to acquire the lock only after the first calling
         # coroutine has created the mutex.
         return await_fallback(self.mutex.acquire())
 
-    def __exit__(self, *arg, **kw):
+    def __exit__(self, *arg: Any, **kw: Any) -> None:
         self.mutex.release()
 
 
-def _util_async_run_coroutine_function(fn, *args, **kwargs):
+def _util_async_run_coroutine_function(
+    fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
+) -> Any:
     """for test suite/ util only"""
 
     loop = get_event_loop()
@@ -175,7 +206,10 @@ def _util_async_run_coroutine_function(fn, *args, **kwargs):
     return loop.run_until_complete(fn(*args, **kwargs))
 
 
-def _util_async_run(fn, *args, **kwargs):
+def _util_async_run(
+    fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
+) -> Any:
+
     """for test suite/ util only"""
 
     loop = get_event_loop()
@@ -183,11 +217,11 @@ def _util_async_run(fn, *args, **kwargs):
         return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
     else:
         # allow for a wrapped test function to call another
-        assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet)
+        assert isinstance(getcurrent(), _AsyncIoGreenlet)
         return fn(*args, **kwargs)
 
 
-def get_event_loop():
+def get_event_loop() -> asyncio.AbstractEventLoop:
     """vendor asyncio.get_event_loop() for python 3.7 and above.
 
     Python 3.10 deprecates get_event_loop() as a standalone.
index f91d902dae6722f31f2836be529d8a3ee75ba3e2..7e1d3213ab018b11ad9d8b38974a9104b98294df 100644 (file)
@@ -14,23 +14,32 @@ import re
 from typing import Any
 from typing import Callable
 from typing import cast
+from typing import Dict
+from typing import Match
 from typing import Optional
+from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TypeVar
+from typing import Union
 
 from . import compat
 from .langhelpers import _hash_limit_string
 from .langhelpers import _warnings_warn
 from .langhelpers import decorator
+from .langhelpers import dynamic_property
 from .langhelpers import inject_docstring_text
 from .langhelpers import inject_param_text
-from .typing import ReadOnlyInstanceDescriptor
 from .. import exc
 
 _T = TypeVar("_T", bound=Any)
 
 
+# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 def _warn_with_version(
     msg: str,
     version: str,
@@ -52,7 +61,13 @@ def warn_deprecated(
     )
 
 
-def warn_deprecated_limited(msg, args, version, stacklevel=3, code=None):
+def warn_deprecated_limited(
+    msg: str,
+    args: Sequence[Any],
+    version: str,
+    stacklevel: int = 3,
+    code: Optional[str] = None,
+) -> None:
     """Issue a deprecation warning with a parameterized string,
     limiting the number of registrations.
 
@@ -64,10 +79,12 @@ def warn_deprecated_limited(msg, args, version, stacklevel=3, code=None):
     )
 
 
-def deprecated_cls(version, message, constructor="__init__"):
+def deprecated_cls(
+    version: str, message: str, constructor: str = "__init__"
+) -> Callable[[Type[_T]], Type[_T]]:
     header = ".. deprecated:: %s %s" % (version, (message or ""))
 
-    def decorate(cls):
+    def decorate(cls: Type[_T]) -> Type[_T]:
         return _decorate_cls_with_warning(
             cls,
             constructor,
@@ -84,9 +101,9 @@ def deprecated_property(
     version: str,
     message: Optional[str] = None,
     add_deprecation_to_docstring: bool = True,
-    warning: Optional[str] = None,
+    warning: Optional[Type[exc.SADeprecationWarning]] = None,
     enable_warnings: bool = True,
-) -> Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]]:
+) -> Callable[[Callable[..., _T]], dynamic_property[_T]]:
     """the @deprecated decorator with a @property.
 
     E.g.::
@@ -113,9 +130,9 @@ def deprecated_property(
     great!   now it is.
 
     """
-    return cast(
-        Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]],
-        lambda fn: property(
+
+    def decorate(fn: Callable[..., _T]) -> dynamic_property[_T]:
+        return dynamic_property(
             deprecated(
                 version,
                 message=message,
@@ -123,17 +140,18 @@ def deprecated_property(
                 warning=warning,
                 enable_warnings=enable_warnings,
             )(fn)
-        ),
-    )
+        )
+
+    return decorate
 
 
 def deprecated(
-    version,
-    message=None,
-    add_deprecation_to_docstring=True,
-    warning=None,
-    enable_warnings=True,
-):
+    version: str,
+    message: Optional[str] = None,
+    add_deprecation_to_docstring: bool = True,
+    warning: Optional[Type[exc.SADeprecationWarning]] = None,
+    enable_warnings: bool = True,
+) -> Callable[[_F], _F]:
     """Decorates a function and issues a deprecation warning on use.
 
     :param version:
@@ -166,7 +184,9 @@ def deprecated(
 
     message += " (deprecated since: %s)" % version
 
-    def decorate(fn):
+    def decorate(fn: _F) -> _F:
+        assert message is not None
+        assert warning is not None
         return _decorate_with_warning(
             fn,
             warning,
@@ -179,13 +199,17 @@ def deprecated(
     return decorate
 
 
-def moved_20(message, **kw):
+def moved_20(
+    message: str, **kw: Any
+) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
     return deprecated(
         "2.0", message=message, warning=exc.MovedIn20Warning, **kw
     )
 
 
-def became_legacy_20(api_name, alternative=None, **kw):
+def became_legacy_20(
+    api_name: str, alternative: Optional[str] = None, **kw: Any
+) -> Callable[[_F], _F]:
     type_reg = re.match("^:(attr|func|meth):", api_name)
     if type_reg:
         type_ = {"attr": "attribute", "func": "function", "meth": "method"}[
@@ -221,10 +245,7 @@ def became_legacy_20(api_name, alternative=None, **kw):
     return deprecated("2.0", message=message, warning=warning_cls, **kw)
 
 
-_C = TypeVar("_C", bound=Callable[..., Any])
-
-
-def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]:
+def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]:
     """Decorates a function to warn on use of certain parameters.
 
     e.g. ::
@@ -240,18 +261,19 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]:
 
     """
 
-    messages = {}
-    versions = {}
-    version_warnings = {}
+    messages: Dict[str, str] = {}
+    versions: Dict[str, str] = {}
+    version_warnings: Dict[str, Type[exc.SADeprecationWarning]] = {}
 
     for param, (version, message) in specs.items():
         versions[param] = version
         messages[param] = _sanitize_restructured_text(message)
         version_warnings[param] = exc.SADeprecationWarning
 
-    def decorate(fn):
+    def decorate(fn: _F) -> _F:
         spec = compat.inspect_getfullargspec(fn)
 
+        check_defaults: Union[Set[str], Tuple[()]]
         if spec.defaults is not None:
             defaults = dict(
                 zip(
@@ -268,7 +290,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]:
         check_any_kw = spec.varkw
 
         @decorator
-        def warned(fn, *args, **kwargs):
+        def warned(fn: _F, *args: Any, **kwargs: Any) -> _F:
             for m in check_defaults:
                 if (defaults[m] is None and kwargs[m] is not None) or (
                     defaults[m] is not None and kwargs[m] != defaults[m]
@@ -283,7 +305,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]:
             if check_any_kw in messages and set(kwargs).difference(
                 check_defaults
             ):
-
+                assert check_any_kw is not None
                 _warn_with_version(
                     messages[check_any_kw],
                     versions[check_any_kw],
@@ -299,7 +321,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]:
                         version_warnings[m],
                         stacklevel=3,
                     )
-            return fn(*args, **kwargs)
+            return fn(*args, **kwargs)  # type: ignore[no-any-return]
 
         doc = fn.__doc__ is not None and fn.__doc__ or ""
         if doc:
@@ -311,15 +333,15 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]:
                     for param, (version, message) in specs.items()
                 },
             )
-        decorated = warned(fn)
+        decorated = cast(_F, warned)(fn)
         decorated.__doc__ = doc
-        return decorated
+        return decorated  # type: ignore[no-any-return]
 
     return decorate
 
 
-def _sanitize_restructured_text(text):
-    def repl(m):
+def _sanitize_restructured_text(text: str) -> str:
+    def repl(m: Match[str]) -> str:
         type_, name = m.group(1, 2)
         if type_ in ("func", "meth"):
             name += "()"
@@ -330,8 +352,13 @@ def _sanitize_restructured_text(text):
 
 
 def _decorate_cls_with_warning(
-    cls, constructor, wtype, message, version, docstring_header=None
-):
+    cls: Type[_T],
+    constructor: str,
+    wtype: Type[exc.SADeprecationWarning],
+    message: str,
+    version: str,
+    docstring_header: Optional[str] = None,
+) -> Type[_T]:
     doc = cls.__doc__ is not None and cls.__doc__ or ""
     if docstring_header is not None:
 
@@ -361,6 +388,7 @@ def _decorate_cls_with_warning(
 
         if constructor is not None:
             assert constructor_fn is not None
+            assert wtype is not None
             setattr(
                 cls,
                 constructor,
@@ -372,8 +400,13 @@ def _decorate_cls_with_warning(
 
 
 def _decorate_with_warning(
-    func, wtype, message, version, docstring_header=None, enable_warnings=True
-):
+    func: _F,
+    wtype: Type[exc.SADeprecationWarning],
+    message: str,
+    version: str,
+    docstring_header: Optional[str] = None,
+    enable_warnings: bool = True,
+) -> _F:
     """Wrap a function with a warnings.warn and augmented docstring."""
 
     message = _sanitize_restructured_text(message)
@@ -387,13 +420,13 @@ def _decorate_with_warning(
         doc_only = ""
 
     @decorator
-    def warned(fn, *args, **kwargs):
+    def warned(fn: _F, *args: Any, **kwargs: Any) -> _F:
         skip_warning = not enable_warnings or kwargs.pop(
             "_sa_skip_warning", False
         )
         if not skip_warning:
             _warn_with_version(message, version, wtype, stacklevel=3)
-        return fn(*args, **kwargs)
+        return fn(*args, **kwargs)  # type: ignore[no-any-return]
 
     doc = func.__doc__ is not None and func.__doc__ or ""
     if docstring_header is not None:
@@ -403,9 +436,9 @@ def _decorate_with_warning(
 
         doc = inject_docstring_text(doc, docstring_header, 1)
 
-    decorated = warned(func)
+    decorated = cast(_F, warned)(func)
     decorated.__doc__ = doc
     decorated._sa_warn = lambda: _warn_with_version(
         message, version, wtype, stacklevel=3
     )
-    return decorated
+    return decorated  # type: ignore[no-any-return]
index 9e024b3c039a752cedb483af8e1a2b2bcf507347..43f9d5c73f50a8b26c0796ae8130f7f6fb2715f1 100644 (file)
@@ -32,6 +32,7 @@ from typing import Generic
 from typing import Iterator
 from typing import List
 from typing import Mapping
+from typing import NoReturn
 from typing import Optional
 from typing import overload
 from typing import Sequence
@@ -103,9 +104,14 @@ class safe_reraise:
             with safe_reraise():
                 sess.rollback()
 
+    TODO: is this context manager getting us anything in Python 3?
+    Not sure of the coroutine issue stated above; we would assume this was
+    when using eventlet / gevent.  not sure if our own greenlet integration
+    is impacted.
+
     """
 
-    __slots__ = ("warn_only", "_exc_info")
+    __slots__ = ("_exc_info",)
 
     _exc_info: Union[
         None,
@@ -117,9 +123,6 @@ class safe_reraise:
         Tuple[None, None, None],
     ]
 
-    def __init__(self, warn_only: bool = False):
-        self.warn_only = warn_only
-
     def __enter__(self) -> None:
         self._exc_info = sys.exc_info()
 
@@ -128,15 +131,14 @@ class safe_reraise:
         type_: Optional[Type[BaseException]],
         value: Optional[BaseException],
         traceback: Optional[types.TracebackType],
-    ) -> None:
+    ) -> NoReturn:
         assert self._exc_info is not None
         # see #2703 for notes
         if type_ is None:
             exc_type, exc_value, exc_tb = self._exc_info
             assert exc_value is not None
             self._exc_info = None  # remove potential circular references
-            if not self.warn_only:
-                raise exc_value.with_traceback(exc_tb)
+            raise exc_value.with_traceback(exc_tb)
         else:
             self._exc_info = None  # remove potential circular references
             assert value is not None
@@ -1123,13 +1125,22 @@ def as_interface(obj, cls=None, methods=None, required=None):
     )
 
 
+Selfdynamic_property = TypeVar(
+    "Selfdynamic_property", bound="dynamic_property[Any]"
+)
+
 Selfmemoized_property = TypeVar(
     "Selfmemoized_property", bound="memoized_property[Any]"
 )
 
 
-class memoized_property(Generic[_T]):
-    """A read-only @property that is only evaluated once."""
+class dynamic_property(Generic[_T]):
+    """A read-only @property that is evaluated each time.
+
+    This is mostly the same as @property except we can type it
+    alongside memoized_property
+
+    """
 
     fget: Callable[..., _T]
     __doc__: Optional[str]
@@ -1140,6 +1151,27 @@ class memoized_property(Generic[_T]):
         self.__doc__ = doc or fget.__doc__
         self.__name__ = fget.__name__
 
+    @overload
+    def __get__(
+        self: Selfdynamic_property, obj: None, cls: Any
+    ) -> Selfdynamic_property:
+        ...
+
+    @overload
+    def __get__(self, obj: Any, cls: Any) -> _T:
+        ...
+
+    def __get__(
+        self: Selfdynamic_property, obj: Any, cls: Any
+    ) -> Union[Selfdynamic_property, _T]:
+        if obj is None:
+            return self
+        return self.fget(obj)  # type: ignore[no-any-return]
+
+
+class memoized_property(dynamic_property[_T]):
+    """A read-only @property that is only evaluated once."""
+
     @overload
     def __get__(
         self: Selfmemoized_property, obj: None, cls: Any
@@ -1158,7 +1190,16 @@ class memoized_property(Generic[_T]):
         obj.__dict__[self.__name__] = result = self.fget(obj)
         return result  # type: ignore
 
-    def _reset(self, obj):
+    if typing.TYPE_CHECKING:
+        # __set__ can't actually be implemented because it would
+        # cause __get__ to be called in all cases
+        def __set__(self, instance: Any, value: Any) -> None:
+            ...
+
+        def __delete__(self, instance: Any) -> None:
+            ...
+
+    def _reset(self, obj: Any) -> None:
         memoized_property.reset(obj, self.__name__)
 
     @classmethod
@@ -1628,6 +1669,39 @@ class symbol:
         raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
 
 
+def parse_user_argument_for_enum(
+    arg: Any,
+    choices: Dict[_T, List[Any]],
+    name: str,
+) -> Optional[_T]:
+    """Given a user parameter, parse the parameter into a chosen value
+    from a list of choice objects, typically Enum values.
+
+    The user argument can be a string name that matches the name of a
+    symbol, or the symbol object itself, or any number of alternate choices
+    such as True/False/ None etc.
+
+    :param arg: the user argument.
+    :param choices: dictionary of enum values to lists of possible
+        entries for each.
+    :param name: name of the argument.   Used in an :class:`.ArgumentError`
+        that is raised if the parameter doesn't match any available argument.
+
+    """
+    # TODO: use whatever built in thing Enum provides for this,
+    # if applicable
+    for enum_value, choice in choices.items():
+        if arg is enum_value:
+            return enum_value
+        elif arg in choice:
+            return enum_value
+
+    if arg is None:
+        return None
+
+    raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+
+
 _creation_order = 1
 
 
@@ -1644,7 +1718,7 @@ def set_creation_order(instance):
     _creation_order += 1
 
 
-def warn_exception(func, *args, **kwargs):
+def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
     """executes the given function, catches all exceptions and converts to
     a warning.
 
@@ -1678,7 +1752,9 @@ class _hash_limit_string(str):
 
     _hash: int
 
-    def __new__(cls, value, num, args):
+    def __new__(
+        cls, value: str, num: int, args: Sequence[Any]
+    ) -> _hash_limit_string:
         interpolated = (value % args) + (
             " (this warning may be suppressed after %d occurrences)" % num
         )
@@ -1686,14 +1762,14 @@ class _hash_limit_string(str):
         self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
         return self
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return self._hash
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return hash(self) == hash(other)
 
 
-def warn(msg, code=None):
+def warn(msg: str, code: Optional[str] = None) -> None:
     """Issue a warning.
 
     If msg is a string, :class:`.exc.SAWarning` is used as
@@ -1706,7 +1782,7 @@ def warn(msg, code=None):
         _warnings_warn(msg, exc.SAWarning)
 
 
-def warn_limited(msg, args):
+def warn_limited(msg: str, args: Sequence[Any]) -> None:
     """Issue a warning with a parameterized string, limiting the number
     of registrations.
 
@@ -1716,7 +1792,11 @@ def warn_limited(msg, args):
     _warnings_warn(msg, exc.SAWarning)
 
 
-def _warnings_warn(message, category=None, stacklevel=2):
+def _warnings_warn(
+    message: Union[str, Warning],
+    category: Optional[Type[Warning]] = None,
+    stacklevel: int = 2,
+) -> None:
 
     # adjust the given stacklevel to be outside of SQLAlchemy
     try:
@@ -1736,7 +1816,7 @@ def _warnings_warn(message, category=None, stacklevel=2):
         while frame is not None and re.match(
             r"^(?:sqlalchemy\.|alembic\.)", frame.f_globals.get("__name__", "")
         ):
-            frame = frame.f_back
+            frame = frame.f_back  # type: ignore[assignment]
             stacklevel += 1
 
     if category is not None:
@@ -1775,7 +1855,11 @@ _SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
 _UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")
 
 
-def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
+def chop_traceback(
+    tb: List[str],
+    exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
+    exclude_suffix: re.Pattern[str] = _SQLA_RE,
+) -> List[str]:
     """Chop extraneous lines off beginning and end of a traceback.
 
     :param tb:
index 3062d9d8ab00081b279a5a2ebce9c5724e229488..06b60c8bf8fd0fea05ecc9ff547926ff63c58e68 100644 (file)
@@ -17,16 +17,26 @@ producing a ``put()`` inside the ``get()`` and therefore a reentrant
 condition.
 
 """
+from __future__ import annotations
+
 import asyncio
 from collections import deque
 import threading
 from time import time as _time
+import typing
+from typing import Any
+from typing import Awaitable
+from typing import Deque
+from typing import Generic
+from typing import Optional
+from typing import TypeVar
 
 from .concurrency import await_fallback
 from .concurrency import await_only
 from .langhelpers import memoized_property
 
 
+_T = TypeVar("_T", bound=Any)
 __all__ = ["Empty", "Full", "Queue"]
 
 
@@ -42,8 +52,41 @@ class Full(Exception):
     pass
 
 
-class Queue:
-    def __init__(self, maxsize=0, use_lifo=False):
+class QueueCommon(Generic[_T]):
+    maxsize: int
+    use_lifo: bool
+
+    def __init__(self, maxsize: int = 0, use_lifo: bool = False):
+        ...
+
+    def empty(self) -> bool:
+        raise NotImplementedError()
+
+    def full(self) -> bool:
+        raise NotImplementedError()
+
+    def qsize(self) -> int:
+        raise NotImplementedError()
+
+    def put_nowait(self, item: _T) -> None:
+        raise NotImplementedError()
+
+    def put(
+        self, item: _T, block: bool = True, timeout: Optional[float] = None
+    ) -> None:
+        raise NotImplementedError()
+
+    def get_nowait(self) -> _T:
+        raise NotImplementedError()
+
+    def get(self, block: bool = True, timeout: Optional[float] = None) -> _T:
+        raise NotImplementedError()
+
+
+class Queue(QueueCommon[_T]):
+    queue: Deque[_T]
+
+    def __init__(self, maxsize: int = 0, use_lifo: bool = False):
         """Initialize a queue object with a given maximum size.
 
         If `maxsize` is <= 0, the queue size is infinite.
@@ -66,27 +109,29 @@ class Queue:
         # If this queue uses LIFO or FIFO
         self.use_lifo = use_lifo
 
-    def qsize(self):
+    def qsize(self) -> int:
         """Return the approximate size of the queue (not reliable!)."""
 
         with self.mutex:
             return self._qsize()
 
-    def empty(self):
+    def empty(self) -> bool:
         """Return True if the queue is empty, False otherwise (not
         reliable!)."""
 
         with self.mutex:
             return self._empty()
 
-    def full(self):
+    def full(self) -> bool:
         """Return True if the queue is full, False otherwise (not
         reliable!)."""
 
         with self.mutex:
             return self._full()
 
-    def put(self, item, block=True, timeout=None):
+    def put(
+        self, item: _T, block: bool = True, timeout: Optional[float] = None
+    ) -> None:
         """Put an item into the queue.
 
         If optional args `block` is True and `timeout` is None (the
@@ -118,7 +163,7 @@ class Queue:
             self._put(item)
             self.not_empty.notify()
 
-    def put_nowait(self, item):
+    def put_nowait(self, item: _T) -> None:
         """Put an item into the queue without blocking.
 
         Only enqueue the item if a free slot is immediately available.
@@ -126,7 +171,7 @@ class Queue:
         """
         return self.put(item, False)
 
-    def get(self, block=True, timeout=None):
+    def get(self, block: bool = True, timeout: Optional[float] = None) -> _T:
         """Remove and return an item from the queue.
 
         If optional args `block` is True and `timeout` is None (the
@@ -158,7 +203,7 @@ class Queue:
             self.not_full.notify()
             return item
 
-    def get_nowait(self):
+    def get_nowait(self) -> _T:
         """Remove and return an item from the queue without blocking.
 
         Only get an item if one is immediately available. Otherwise
@@ -167,32 +212,23 @@ class Queue:
 
         return self.get(False)
 
-    # Override these methods to implement other queue organizations
-    # (e.g. stack or priority queue).
-    # These will only be called with appropriate locks held
-
-    # Initialize the queue representation
-    def _init(self, maxsize):
+    def _init(self, maxsize: int) -> None:
         self.maxsize = maxsize
         self.queue = deque()
 
-    def _qsize(self):
+    def _qsize(self) -> int:
         return len(self.queue)
 
-    # Check whether the queue is empty
-    def _empty(self):
+    def _empty(self) -> bool:
         return not self.queue
 
-    # Check whether the queue is full
-    def _full(self):
+    def _full(self) -> bool:
         return self.maxsize > 0 and len(self.queue) == self.maxsize
 
-    # Put a new item in the queue
-    def _put(self, item):
+    def _put(self, item: _T) -> None:
         self.queue.append(item)
 
-    # Get an item from the queue
-    def _get(self):
+    def _get(self) -> _T:
         if self.use_lifo:
             # LIFO
             return self.queue.pop()
@@ -201,14 +237,21 @@ class Queue:
             return self.queue.popleft()
 
 
-class AsyncAdaptedQueue:
-    await_ = staticmethod(await_only)
+class AsyncAdaptedQueue(QueueCommon[_T]):
+    if typing.TYPE_CHECKING:
 
-    def __init__(self, maxsize=0, use_lifo=False):
+        @staticmethod
+        def await_(coroutine: Awaitable[Any]) -> _T:
+            ...
+
+    else:
+        await_ = staticmethod(await_only)
+
+    def __init__(self, maxsize: int = 0, use_lifo: bool = False):
         self.use_lifo = use_lifo
         self.maxsize = maxsize
 
-    def empty(self):
+    def empty(self) -> bool:
         return self._queue.empty()
 
     def full(self):
@@ -218,7 +261,7 @@ class AsyncAdaptedQueue:
         return self._queue.qsize()
 
     @memoized_property
-    def _queue(self):
+    def _queue(self) -> asyncio.Queue[_T]:
         # Delay creation of the queue until it is first used, to avoid
         # binding it to a possibly wrong event loop.
         # By delaying the creation of the pool we accommodate the common
@@ -226,39 +269,41 @@ class AsyncAdaptedQueue:
         # different event loop is in present compared to when the application
         # is actually run.
 
+        queue: asyncio.Queue[_T]
+
         if self.use_lifo:
             queue = asyncio.LifoQueue(maxsize=self.maxsize)
         else:
             queue = asyncio.Queue(maxsize=self.maxsize)
         return queue
 
-    def put_nowait(self, item):
+    def put_nowait(self, item: _T) -> None:
         try:
-            return self._queue.put_nowait(item)
+            self._queue.put_nowait(item)
         except asyncio.QueueFull as err:
             raise Full() from err
 
-    def put(self, item, block=True, timeout=None):
+    def put(
+        self, item: _T, block: bool = True, timeout: Optional[float] = None
+    ) -> None:
         if not block:
             return self.put_nowait(item)
 
         try:
             if timeout is not None:
-                return self.await_(
-                    asyncio.wait_for(self._queue.put(item), timeout)
-                )
+                self.await_(asyncio.wait_for(self._queue.put(item), timeout))
             else:
-                return self.await_(self._queue.put(item))
+                self.await_(self._queue.put(item))
         except (asyncio.QueueFull, asyncio.TimeoutError) as err:
             raise Full() from err
 
-    def get_nowait(self):
+    def get_nowait(self) -> _T:
         try:
             return self._queue.get_nowait()
         except asyncio.QueueEmpty as err:
             raise Empty() from err
 
-    def get(self, block=True, timeout=None):
+    def get(self, block: bool = True, timeout: Optional[float] = None) -> _T:
         if not block:
             return self.get_nowait()
 
@@ -273,5 +318,6 @@ class AsyncAdaptedQueue:
             raise Empty() from err
 
 
-class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue):
-    await_ = staticmethod(await_fallback)
+class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue[_T]):
+    if not typing.TYPE_CHECKING:
+        await_ = staticmethod(await_fallback)
index 404f239c8933bb3ac0c0c97863ba816571dafcd3..ddda420db14e0fc365d8e298f1d0b686c81ada29 100644 (file)
@@ -80,25 +80,6 @@ class _TypeToInstance(Generic[_T]):
         ...
 
 
-class ReadOnlyInstanceDescriptor(Protocol[_T]):
-    """protocol representing an instance-only descriptor"""
-
-    @overload
-    def __get__(
-        self, instance: None, owner: Any
-    ) -> "ReadOnlyInstanceDescriptor[_T]":
-        ...
-
-    @overload
-    def __get__(self, instance: object, owner: Any) -> _T:
-        ...
-
-    def __get__(
-        self, instance: object, owner: Any
-    ) -> Union["ReadOnlyInstanceDescriptor[_T]", _T]:
-        ...
-
-
 def de_stringify_annotation(
     cls: Type[Any], annotation: Union[str, Type[Any]]
 ) -> Union[str, Type[Any]]:
index b6f095239035036fe21dbfda51563ecbf2ded653..f7750b6a6ba24f12aec03eb6b8988f110f986435 100644 (file)
@@ -40,6 +40,7 @@ markers = [
 
 [tool.pyright]
 include = [
+    "lib/sqlalchemy/pool/",
     "lib/sqlalchemy/event/",
     "lib/sqlalchemy/events.py",
     "lib/sqlalchemy/exc.py",
@@ -50,6 +51,9 @@ include = [
     "lib/sqlalchemy/util/",
 ]
 
+reportPrivateUsage = "none"
+reportUnusedClass = "none"
+reportUnusedFunction = "none"
 
 
 [tool.mypy]
@@ -78,6 +82,7 @@ strict = true
 # strict checking
 [[tool.mypy.overrides]]
 module = [
+    "sqlalchemy.pool.*",
     "sqlalchemy.event.*",
     "sqlalchemy.events",
     "sqlalchemy.exc",
index 0c897520254fb6492487d4ff83f5fdba111cea14..c1613069e111cd044ac39b8ac203234402fdcf1a 100644 (file)
@@ -322,7 +322,7 @@ class PoolTest(PoolTestBase):
         is_(rec.connection, rec.dbapi_connection)
         is_(rec.driver_connection, rec.dbapi_connection)
 
-        fairy = pool._ConnectionFairy(rec.dbapi_connection, rec, False)
+        fairy = pool._ConnectionFairy(p1, rec.dbapi_connection, rec, False)
 
         is_not_none(fairy.dbapi_connection)
         is_(fairy.connection, fairy.dbapi_connection)
@@ -346,12 +346,13 @@ class PoolTest(PoolTestBase):
 
         rec = pool._ConnectionRecord(p1)
 
+        assert rec.dbapi_connection is not None
         is_not_none(rec.dbapi_connection)
 
         is_(rec.connection, rec.dbapi_connection)
         is_(rec.driver_connection, mock_dc)
 
-        fairy = pool._ConnectionFairy(rec.dbapi_connection, rec, False)
+        fairy = pool._ConnectionFairy(p1, rec.dbapi_connection, rec, False)
 
         is_not_none(fairy.dbapi_connection)
         is_(fairy.connection, fairy.dbapi_connection)