From: Mike Bayer Date: Sun, 13 Feb 2022 21:45:18 +0000 (-0500) Subject: pep-484 for sqlalchemy.event; use future annotations X-Git-Tag: rel_2_0_0b1~482^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5c6081ddb03447697f909a03572b6d6d79e61b71;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484 for sqlalchemy.event; use future annotations __future__.annotations mode allows us to use non-string annotations for argument and return types in most cases, but more importantly it removes a large amount of runtime overhead that would be spent in evaluating the annotations. Change-Id: I2f5b6126fe0019713fc50001be3627b664019ede References: #6810 --- diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0693914684..b888441fd4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,7 @@ repos: additional_dependencies: - flake8-import-order - flake8-builtins + - flake8-future-annotations - flake8-docstrings>=1.6.0 - flake8-rst-docstrings # flake8-rst-docstrings dependency, leaving it here diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index eadb427d0d..dc1c536c8d 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from . import util as _util from .engine import AdaptedConnection as AdaptedConnection from .engine import BaseCursorResult as BaseCursorResult diff --git a/lib/sqlalchemy/engine/_py_processors.py b/lib/sqlalchemy/engine/_py_processors.py index 66c915a8fb..e3024471a2 100644 --- a/lib/sqlalchemy/engine/_py_processors.py +++ b/lib/sqlalchemy/engine/_py_processors.py @@ -13,6 +13,8 @@ They all share one common characteristic: None is passed through unchanged. """ +from __future__ import annotations + import datetime import re diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py index 981b6e0b27..a6d5b79d59 100644 --- a/lib/sqlalchemy/engine/_py_row.py +++ b/lib/sqlalchemy/engine/_py_row.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import operator MD_INDEX = 0 # integer index in cursor.description diff --git a/lib/sqlalchemy/engine/_py_util.py b/lib/sqlalchemy/engine/_py_util.py index 2db6c049bb..ff03a47613 100644 --- a/lib/sqlalchemy/engine/_py_util.py +++ b/lib/sqlalchemy/engine/_py_util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import abc as collections_abc from .. import exc diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4045eae907..4fd2739484 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -4,6 +4,8 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import contextlib import sys import typing diff --git a/lib/sqlalchemy/engine/characteristics.py b/lib/sqlalchemy/engine/characteristics.py index 10455451fd..c3674c931e 100644 --- a/lib/sqlalchemy/engine/characteristics.py +++ b/lib/sqlalchemy/engine/characteristics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 2f8ce17df9..a252b7cfeb 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from typing import Any from typing import Union diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index f372b88985..2b077056fa 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -9,6 +9,8 @@ :class:`.BaseCursorResult`, :class:`.CursorResult`.""" +from __future__ import annotations + import collections import functools diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4861214c4a..b7dbfc52ee 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -13,6 +13,8 @@ as the base class for their own corresponding classes. """ +from __future__ import annotations + import functools import random import re diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index 3af46c119b..ab462bbe1f 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -6,6 +6,8 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from .base import Engine from .interfaces import ConnectionEventsTarget from .interfaces import Dialect diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 2bbe23e042..ce884614c0 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -7,6 +7,8 @@ """Define core interfaces used by the engine system.""" +from __future__ import annotations + from enum import Enum from typing import Any from typing import Callable diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index cee4db8026..76e77a3f3d 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from operator import attrgetter from . import url as _url diff --git a/lib/sqlalchemy/engine/processors.py b/lib/sqlalchemy/engine/processors.py index 829af67963..398c1fa361 100644 --- a/lib/sqlalchemy/engine/processors.py +++ b/lib/sqlalchemy/engine/processors.py @@ -12,6 +12,8 @@ processors. They all share one common characteristic: None is passed through unchanged. """ +from __future__ import annotations + from ._py_processors import str_to_datetime_processor_factory # noqa try: diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 882392e9c2..e1281365e9 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -24,6 +24,8 @@ methods such as get_table_names, get_columns, etc. use the key 'name'. So for most return values, each record will have a 'name' attribute.. """ +from __future__ import annotations + import contextlib from typing import List from typing import Optional diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 5970e2448f..2e54c87dbd 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -6,6 +6,9 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php """Define generic result set constructs.""" + +from __future__ import annotations + import collections.abc as collections_abc import functools import itertools diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 75c56450e2..29b2f338b6 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -7,6 +7,8 @@ """Define row constructs including :class:`.Row`.""" +from __future__ import annotations + import collections.abc as collections_abc import operator import typing diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 8042acd39a..7f291af823 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -10,6 +10,8 @@ """ +from __future__ import annotations + from .mock import MockConnection # noqa diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index ec5ab2bec7..a55233397e 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -14,6 +14,8 @@ argument; alternatively, the URL is a public-facing construct which can be used directly and is also accepted directly by ``create_engine()``. """ +from __future__ import annotations + import collections.abc as collections_abc import re from typing import Dict diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index f74cd3f847..f9ee65befe 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from .. import exc from .. import util diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index a89bea894e..2d10372ab1 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -5,13 +5,15 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -from .api import CANCEL -from .api import contains -from .api import listen -from .api import listens_for -from .api import NO_RETVAL -from .api import remove -from .attr import RefCollection -from .base import dispatcher -from .base import Events +from __future__ import annotations + +from .api import CANCEL as CANCEL +from .api import contains as contains +from .api import listen as listen +from .api import listens_for as listens_for +from .api import NO_RETVAL as NO_RETVAL +from .api import remove as remove +from .attr import RefCollection as RefCollection +from .base import dispatcher as dispatcher +from .base import Events as Events from .legacy import _legacy_signature diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py index d2fd9473cc..52f796b19d 100644 --- a/lib/sqlalchemy/event/api.py +++ b/lib/sqlalchemy/event/api.py @@ -8,8 +8,15 @@ """Public API functions for the event system. """ +from __future__ import annotations + +from typing import Any +from typing import Callable + from .base import _registrars +from .registry import _ET from .registry import _EventKey +from .registry import _ListenerFnType from .. import exc from .. import util @@ -18,7 +25,9 @@ CANCEL = util.symbol("CANCEL") NO_RETVAL = util.symbol("NO_RETVAL") -def _event_key(target, identifier, fn): +def _event_key( + target: _ET, identifier: str, fn: _ListenerFnType +) -> _EventKey[_ET]: for evt_cls in _registrars[identifier]: tgt = evt_cls._accept_with(target) if tgt is not None: @@ -29,7 +38,9 @@ def _event_key(target, identifier, fn): ) -def listen(target, identifier, fn, *args, **kw): +def listen( + target: Any, identifier: str, fn: Callable[..., Any], *args: Any, **kw: Any +) -> None: """Register a listener function for the given target. The :func:`.listen` function is part of the primary interface for the @@ -113,7 +124,9 @@ def listen(target, identifier, fn, *args, **kw): _event_key(target, identifier, fn).listen(*args, **kw) -def listens_for(target, identifier, *args, **kw): +def listens_for( + target: Any, identifier: str, *args: Any, **kw: Any +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Decorate a function as a listener for the given target + identifier. The :func:`.listens_for` decorator is part of the primary interface for the @@ -154,14 +167,14 @@ def listens_for(target, identifier, *args, **kw): """ - def decorate(fn): + def decorate(fn: Callable[..., Any]) -> Callable[..., Any]: listen(target, identifier, fn, *args, **kw) return fn return decorate -def remove(target, identifier, fn): +def remove(target: Any, identifier: str, fn: Callable[..., Any]) -> None: """Remove an event listener. The arguments here should match exactly those which were sent to @@ -211,7 +224,7 @@ def remove(target, identifier, fn): _event_key(target, identifier, fn).remove() -def contains(target, identifier, fn): +def contains(target: Any, identifier: str, fn: Callable[..., Any]) -> bool: """Return True if the given target/ident/fn is set up to listen.""" return _event_key(target, identifier, fn).contains() diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index a059662224..d1ae7a8452 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -28,43 +28,89 @@ as well as support for subclass propagation (e.g. events assigned to ``Pool`` vs. ``QueuePool``) are all implemented here. """ +from __future__ import annotations + import collections from itertools import chain import threading +from types import TracebackType +import typing +from typing import Any +from typing import cast +from typing import Collection +from typing import Deque +from typing import FrozenSet +from typing import Generic +from typing import Iterator +from typing import MutableMapping +from typing import MutableSequence +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union import weakref from . import legacy from . import registry +from .registry import _ET +from .registry import _EventKey +from .registry import _ListenerFnType from .. import exc from .. import util from ..util.concurrency import AsyncAdaptedLock +from ..util.typing import Protocol + +_T = TypeVar("_T", bound=Any) + +if typing.TYPE_CHECKING: + from .base import _Dispatch + from .base import _HasEventsDispatch + from .base import _JoinedDispatcher -class RefCollection(util.MemoizedSlots): +class RefCollection(util.MemoizedSlots, Generic[_ET]): __slots__ = ("ref",) - def _memoized_attr_ref(self): + ref: weakref.ref[RefCollection[_ET]] + + def _memoized_attr_ref(self) -> weakref.ref[RefCollection[_ET]]: return weakref.ref(self, registry._collection_gced) -class _empty_collection: - def append(self, element): +class _empty_collection(Collection[_T]): + def append(self, element: _T) -> None: + pass + + def appendleft(self, element: _T) -> None: pass - def extend(self, other): + def extend(self, other: Sequence[_T]) -> None: pass - def remove(self, element): + def remove(self, element: _T) -> None: pass - def __iter__(self): + def __contains__(self, element: Any) -> bool: + return False + + def __iter__(self) -> Iterator[_T]: return iter([]) - def clear(self): + def clear(self) -> None: pass + def __len__(self) -> int: + return 0 + + +_ListenerFnSequenceType = Union[Deque[_T], _empty_collection[_T]] + -class _ClsLevelDispatch(RefCollection): +class _ClsLevelDispatch(RefCollection[_ET]): """Class-level events on :class:`._Dispatch` classes.""" __slots__ = ( @@ -77,7 +123,20 @@ class _ClsLevelDispatch(RefCollection): "__weakref__", ) - def __init__(self, parent_dispatch_cls, fn): + clsname: str + name: str + arg_names: Sequence[str] + has_kw: bool + legacy_signatures: MutableSequence[legacy._LegacySignatureType] + _clslevel: MutableMapping[ + Type[_ET], _ListenerFnSequenceType[_ListenerFnType] + ] + + def __init__( + self, + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + fn: _ListenerFnType, + ): self.name = fn.__name__ self.clsname = parent_dispatch_cls.__name__ argspec = util.inspect_getfullargspec(fn) @@ -94,7 +153,9 @@ class _ClsLevelDispatch(RefCollection): self._clslevel = weakref.WeakKeyDictionary() - def _adjust_fn_spec(self, fn, named): + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: if named: fn = self._wrap_fn_for_kw(fn) if self.legacy_signatures: @@ -106,15 +167,15 @@ class _ClsLevelDispatch(RefCollection): fn = legacy._wrap_fn_for_legacy(self, fn, argspec) return fn - def _wrap_fn_for_kw(self, fn): - def wrap_kw(*args, **kw): + def _wrap_fn_for_kw(self, fn: _ListenerFnType) -> _ListenerFnType: + def wrap_kw(*args: Any, **kw: Any) -> Any: argdict = dict(zip(self.arg_names, args)) argdict.update(kw) return fn(**argdict) return wrap_kw - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: target = event_key.dispatch_target assert isinstance( target, type @@ -125,6 +186,7 @@ class _ClsLevelDispatch(RefCollection): ) for cls in util.walk_subclasses(target): + cls = cast(Type[_ET], cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: @@ -133,7 +195,7 @@ class _ClsLevelDispatch(RefCollection): self._clslevel[cls].appendleft(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: target = event_key.dispatch_target assert isinstance( target, type @@ -143,6 +205,7 @@ class _ClsLevelDispatch(RefCollection): "Can't assign an event directly to the %s class" % target ) for cls in util.walk_subclasses(target): + cls = cast("Type[_ET]", cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: @@ -151,39 +214,41 @@ class _ClsLevelDispatch(RefCollection): self._clslevel[cls].append(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def _assign_cls_collection(self, target): + def _assign_cls_collection(self, target: Type[_ET]) -> None: if getattr(target, "_sa_propagate_class_events", True): self._clslevel[target] = collections.deque() else: self._clslevel[target] = _empty_collection() - def update_subclass(self, target): + def update_subclass(self, target: Type[_ET]) -> None: if target not in self._clslevel: self._assign_cls_collection(target) clslevel = self._clslevel[target] for cls in target.__mro__[1:]: + cls = cast("Type[_ET]", cls) if cls in self._clslevel: clslevel.extend( [fn for fn in self._clslevel[cls] if fn not in clslevel] ) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: target = event_key.dispatch_target for cls in util.walk_subclasses(target): + cls = cast("Type[_ET]", cls) if cls in self._clslevel: self._clslevel[cls].remove(event_key._listen_fn) registry._removed_from_collection(event_key, self) - def clear(self): + def clear(self) -> None: """Clear all class level listeners""" - to_clear = set() + to_clear: Set[_ListenerFnType] = set() for dispatcher in self._clslevel.values(): to_clear.update(dispatcher) dispatcher.clear() registry._clear(self, to_clear) - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ClsLevelDispatch[_ET]: """Return an event collection which can be modified. For _ClsLevelDispatch at the class level of @@ -193,14 +258,30 @@ class _ClsLevelDispatch(RefCollection): return self -class _InstanceLevelDispatch(RefCollection): +class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]): __slots__ = () - def _adjust_fn_spec(self, fn, named): + parent: _ClsLevelDispatch[_ET] + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: return self.parent._adjust_fn_spec(fn, named) + def __contains__(self, item: Any) -> bool: + raise NotImplementedError() + + def __len__(self) -> int: + raise NotImplementedError() + + def __iter__(self) -> Iterator[_ListenerFnType]: + raise NotImplementedError() + + def __bool__(self) -> bool: + raise NotImplementedError() + -class _EmptyListener(_InstanceLevelDispatch): +class _EmptyListener(_InstanceLevelDispatch[_ET]): """Serves as a proxy interface to the events served by a _ClsLevelDispatch, when there are no instance-level events present. @@ -210,19 +291,22 @@ class _EmptyListener(_InstanceLevelDispatch): """ - propagate = frozenset() - listeners = () - __slots__ = "parent", "parent_listeners", "name" - def __init__(self, parent, target_cls): + propagate: FrozenSet[_ListenerFnType] = frozenset() + listeners: Tuple[()] = () + parent: _ClsLevelDispatch[_ET] + parent_listeners: _ListenerFnSequenceType[_ListenerFnType] + name: str + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) - self.parent = parent # _ClsLevelDispatch + self.parent = parent self.parent_listeners = parent._clslevel[target_cls] self.name = parent.name - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _EmptyListener at the instance level of @@ -231,6 +315,7 @@ class _EmptyListener(_InstanceLevelDispatch): and returns it. """ + assert obj._instance_cls is not None result = _ListenerCollection(self.parent, obj._instance_cls) if getattr(obj, self.name) is self: setattr(obj, self.name, result) @@ -238,41 +323,79 @@ class _EmptyListener(_InstanceLevelDispatch): assert isinstance(getattr(obj, self.name), _JoinedListener) return result - def _needs_modify(self, *args, **kw): + def _needs_modify(self, *args: Any, **kw: Any) -> NoReturn: raise NotImplementedError("need to call for_modify()") - exec_once = ( - exec_once_unless_exception - ) = insert = append = remove = clear = _needs_modify + def exec_once(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def insert(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) - def __call__(self, *args, **kw): + def append(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def remove(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def clear(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def __call__(self, *args: Any, **kw: Any) -> None: """Execute this event.""" for fn in self.parent_listeners: fn(*args, **kw) - def __len__(self): + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners + + def __len__(self) -> int: return len(self.parent_listeners) - def __iter__(self): + def __iter__(self) -> Iterator[_ListenerFnType]: return iter(self.parent_listeners) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.parent_listeners) __nonzero__ = __bool__ -class _CompoundListener(_InstanceLevelDispatch): +class _MutexProtocol(Protocol): + def __enter__(self) -> bool: + ... + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + ... + + +class _CompoundListener(_InstanceLevelDispatch[_ET]): __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once" - def _set_asyncio(self): + _exec_once_mutex: _MutexProtocol + parent_listeners: Collection[_ListenerFnType] + listeners: Collection[_ListenerFnType] + _exec_once: bool + _exec_w_sync_once: bool + + def _set_asyncio(self) -> None: self._exec_once_mutex = AsyncAdaptedLock() - def _memoized_attr__exec_once_mutex(self): + def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol: return threading.Lock() - def _exec_once_impl(self, retry_on_exception, *args, **kw): + def _exec_once_impl( + self, retry_on_exception: bool, *args: Any, **kw: Any + ) -> None: with self._exec_once_mutex: if not self._exec_once: try: @@ -285,14 +408,14 @@ class _CompoundListener(_InstanceLevelDispatch): if not exception or not retry_on_exception: self._exec_once = True - def exec_once(self, *args, **kw): + def exec_once(self, *args: Any, **kw: Any) -> None: """Execute this event, but only if it has not been executed already for this collection.""" if not self._exec_once: self._exec_once_impl(False, *args, **kw) - def exec_once_unless_exception(self, *args, **kw): + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: """Execute this event, but only if it has not been executed already for this collection, or was called by a previous exec_once_unless_exception call and @@ -307,7 +430,7 @@ class _CompoundListener(_InstanceLevelDispatch): if not self._exec_once: self._exec_once_impl(True, *args, **kw) - def _exec_w_sync_on_first_run(self, *args, **kw): + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: """Execute this event, and use a mutex if it has not been executed already for this collection, or was called by a previous _exec_w_sync_on_first_run call and @@ -330,7 +453,7 @@ class _CompoundListener(_InstanceLevelDispatch): else: self(*args, **kw) - def __call__(self, *args, **kw): + def __call__(self, *args: Any, **kw: Any) -> None: """Execute this event.""" for fn in self.parent_listeners: @@ -338,19 +461,22 @@ class _CompoundListener(_InstanceLevelDispatch): for fn in self.listeners: fn(*args, **kw) - def __len__(self): + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners or item in self.listeners + + def __len__(self) -> int: return len(self.parent_listeners) + len(self.listeners) - def __iter__(self): + def __iter__(self) -> Iterator[_ListenerFnType]: return chain(self.parent_listeners, self.listeners) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.listeners or self.parent_listeners) __nonzero__ = __bool__ -class _ListenerCollection(_CompoundListener): +class _ListenerCollection(_CompoundListener[_ET]): """Instance-level attributes on instances of :class:`._Dispatch`. Represents a collection of listeners. @@ -369,7 +495,13 @@ class _ListenerCollection(_CompoundListener): "__weakref__", ) - def __init__(self, parent, target_cls): + parent_listeners: Collection[_ListenerFnType] + parent: _ClsLevelDispatch[_ET] + name: str + listeners: Deque[_ListenerFnType] + propagate: Set[_ListenerFnType] + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) self._exec_once = False @@ -380,7 +512,7 @@ class _ListenerCollection(_CompoundListener): self.listeners = collections.deque() self.propagate = set() - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _ListenerCollection at the instance level of @@ -389,10 +521,11 @@ class _ListenerCollection(_CompoundListener): """ return self - def _update(self, other, only_propagate=True): + def _update( + self, other: _ListenerCollection[_ET], only_propagate: bool = True + ) -> None: """Populate from the listeners in another :class:`_Dispatch` object.""" - existing_listeners = self.listeners existing_listener_set = set(existing_listeners) self.propagate.update(other.propagate) @@ -409,56 +542,75 @@ class _ListenerCollection(_CompoundListener): to_associate = other.propagate.union(other_listeners) registry._stored_in_collection_multi(self, other, to_associate) - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: if event_key.prepend_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: if event_key.append_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: self.listeners.remove(event_key._listen_fn) self.propagate.discard(event_key._listen_fn) registry._removed_from_collection(event_key, self) - def clear(self): + def clear(self) -> None: registry._clear(self, self.listeners) self.propagate.clear() self.listeners.clear() -class _JoinedListener(_CompoundListener): - __slots__ = "parent", "name", "local", "parent_listeners" +class _JoinedListener(_CompoundListener[_ET]): + __slots__ = "parent_dispatch", "name", "local", "parent_listeners" + + parent_dispatch: _Dispatch[_ET] + name: str + local: _InstanceLevelDispatch[_ET] + parent_listeners: Collection[_ListenerFnType] - def __init__(self, parent, name, local): + def __init__( + self, + parent_dispatch: _Dispatch[_ET], + name: str, + local: _EmptyListener[_ET], + ): self._exec_once = False - self.parent = parent + self.parent_dispatch = parent_dispatch self.name = name self.local = local self.parent_listeners = self.local - @property - def listeners(self): - return getattr(self.parent, self.name) - - def _adjust_fn_spec(self, fn, named): + if not typing.TYPE_CHECKING: + # first error, I don't really understand: + # Signature of "listeners" incompatible with + # supertype "_CompoundListener" [override] + # the name / return type are exactly the same + # second error is getattr_isn't typed, the cast() here + # adds too much method overhead + @property + def listeners(self) -> Collection[_ListenerFnType]: + return getattr(self.parent_dispatch, self.name) + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: return self.local._adjust_fn_spec(fn, named) - def for_modify(self, obj): + def for_modify(self, obj: _JoinedDispatcher[_ET]) -> _JoinedListener[_ET]: self.local = self.parent_listeners = self.local.for_modify(obj) return self - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: self.local.insert(event_key, propagate) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: self.local.append(event_key, propagate) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: self.local.remove(event_key) - def clear(self): + def clear(self) -> None: raise NotImplementedError() diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 25d3692408..0e0647036f 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -15,21 +15,37 @@ at the class level of a particular ``_Dispatch`` class as well as within instances of ``_Dispatch``. """ -from typing import ClassVar +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import MutableMapping from typing import Optional +from typing import overload +from typing import Tuple from typing import Type +from typing import Union import weakref from .attr import _ClsLevelDispatch from .attr import _EmptyListener +from .attr import _InstanceLevelDispatch from .attr import _JoinedListener +from .registry import _ET +from .registry import _EventKey from .. import util -from ..util.typing import Protocol +from ..util.typing import Literal -_registrars = util.defaultdict(list) +_registrars: MutableMapping[ + str, List[Type[_HasEventsDispatch[Any]]] +] = util.defaultdict(list) -def _is_event_name(name): +def _is_event_name(name: str) -> bool: # _sa_event prefix is special to support internal-only event names. # most event names are just plain method names that aren't # underscored. @@ -45,17 +61,17 @@ class _UnpickleDispatch: """ - def __call__(self, _instance_cls): + def __call__(self, _instance_cls: Type[_ET]) -> _Dispatch[_ET]: for cls in _instance_cls.__mro__: if "dispatch" in cls.__dict__: - return cls.__dict__["dispatch"].dispatch._for_class( - _instance_cls - ) + return cast( + "_Dispatch[_ET]", cls.__dict__["dispatch"].dispatch + )._for_class(_instance_cls) else: raise AttributeError("No class with a 'dispatch' member present.") -class _Dispatch: +class _Dispatch(Generic[_ET]): """Mirror the event listening definitions of an Events class with listener collections. @@ -79,20 +95,35 @@ class _Dispatch: # so __dict__ is used in just that case and potentially others. __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners" - _empty_listener_reg = weakref.WeakKeyDictionary() + _empty_listener_reg: MutableMapping[ + Type[_ET], Dict[str, _EmptyListener[_ET]] + ] = weakref.WeakKeyDictionary() + + _empty_listeners: Dict[str, _EmptyListener[_ET]] + + _event_names: List[str] + + _instance_cls: Optional[Type[_ET]] - _events: Type["_HasEventsDispatch"] + _joined_dispatch_cls: Type[_JoinedDispatcher[_ET]] + + _events: Type[_HasEventsDispatch[_ET]] """reference back to the Events class. Bidirectional against _HasEventsDispatch.dispatch """ - def __init__(self, parent, instance_cls=None): + def __init__( + self, + parent: Optional[_Dispatch[_ET]], + instance_cls: Optional[Type[_ET]] = None, + ): self._parent = parent self._instance_cls = instance_cls if instance_cls: + assert parent is not None try: self._empty_listeners = self._empty_listener_reg[instance_cls] except KeyError: @@ -105,7 +136,7 @@ class _Dispatch: else: self._empty_listeners = {} - def __getattr__(self, name): + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: # Assign EmptyListeners as attributes on demand # to reduce startup time for new dispatch objects. try: @@ -117,24 +148,23 @@ class _Dispatch: return ls @property - def _event_descriptors(self): + def _event_descriptors(self) -> Iterator[_ClsLevelDispatch[_ET]]: for k in self._event_names: # Yield _ClsLevelDispatch related # to relevant event name. yield getattr(self, k) - @property - def _listen(self): - return self._events._listen + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self._events._listen(event_key, **kw) - def _for_class(self, instance_cls): + def _for_class(self, instance_cls: Type[_ET]) -> _Dispatch[_ET]: return self.__class__(self, instance_cls) - def _for_instance(self, instance): + def _for_instance(self, instance: _ET) -> _Dispatch[_ET]: instance_cls = instance.__class__ return self._for_class(instance_cls) - def _join(self, other): + def _join(self, other: _Dispatch[_ET]) -> _JoinedDispatcher[_ET]: """Create a 'join' of this :class:`._Dispatch` and another. This new dispatcher will dispatch events to both @@ -147,14 +177,15 @@ class _Dispatch: (_JoinedDispatcher,), {"__slots__": self._event_names}, ) - self.__class__._joined_dispatch_cls = cls return self._joined_dispatch_cls(self, other) - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return _UnpickleDispatch(), (self._instance_cls,) - def _update(self, other, only_propagate=True): + def _update( + self, other: _Dispatch[_ET], only_propagate: bool = True + ) -> None: """Populate from the listeners in another :class:`_Dispatch` object.""" for ls in other._event_descriptors: @@ -164,32 +195,23 @@ class _Dispatch: ls, only_propagate=only_propagate ) - def _clear(self): + def _clear(self) -> None: for ls in self._event_descriptors: ls.for_modify(self).clear() -def _remove_dispatcher(cls): +def _remove_dispatcher(cls: Type[_HasEventsDispatch[_ET]]) -> None: for k in cls.dispatch._event_names: _registrars[k].remove(cls) if not _registrars[k]: del _registrars[k] -class _HasEventsDispatchProto(Protocol): - """protocol for non-event classes that will also receive the 'dispatch' - attribute in the form of a descriptor. - - """ - - dispatch: ClassVar["dispatcher"] - - -class _HasEventsDispatch: - _dispatch_target: Optional[Type[_HasEventsDispatchProto]] +class _HasEventsDispatch(Generic[_ET]): + _dispatch_target: Optional[Type[_ET]] """class which will receive the .dispatch collection""" - dispatch: _Dispatch + dispatch: _Dispatch[_ET] """reference back to the _Dispatch class. Bidirectional against _Dispatch._events @@ -202,19 +224,41 @@ class _HasEventsDispatch: cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__) + @classmethod + def _accept_with( + cls, target: Union[_ET, Type[_ET]] + ) -> Optional[Union[_ET, Type[_ET]]]: + raise NotImplementedError() + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: + raise NotImplementedError() + @staticmethod - def _set_dispatch(cls, dispatch_cls): + def _set_dispatch( + klass: Type[_HasEventsDispatch[_ET]], + dispatch_cls: Type[_Dispatch[_ET]], + ) -> _Dispatch[_ET]: # This allows an Events subclass to define additional utility # methods made available to the target via # "self.dispatch._events." # @staticmethod to allow easy "super" calls while in a metaclass # constructor. - cls.dispatch = dispatch_cls(None) - dispatch_cls._events = cls - return cls.dispatch + klass.dispatch = dispatch_cls(None) + dispatch_cls._events = klass + return klass.dispatch @classmethod - def _create_dispatcher_class(cls, classname, bases, dict_): + def _create_dispatcher_class( + cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any] + ) -> None: """Create a :class:`._Dispatch` class corresponding to an :class:`.Events` class.""" @@ -227,14 +271,16 @@ class _HasEventsDispatch: dispatch_base = _Dispatch event_names = [k for k in dict_ if _is_event_name(k)] - dispatch_cls = type( - "%sDispatch" % classname, - (dispatch_base,), - {"__slots__": event_names}, + dispatch_cls = cast( + "Type[_Dispatch[_ET]]", + type( + "%sDispatch" % classname, + (dispatch_base,), # type: ignore + {"__slots__": event_names}, + ), ) dispatch_cls._event_names = event_names - dispatch_inst = cls._set_dispatch(cls, dispatch_cls) for k in dispatch_cls._event_names: setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) @@ -251,23 +297,28 @@ class _HasEventsDispatch: assert dispatch_target_cls is not None if ( hasattr(dispatch_target_cls, "__slots__") - and "_slots_dispatch" in dispatch_target_cls.__slots__ + and "_slots_dispatch" in dispatch_target_cls.__slots__ # type: ignore # noqa E501 ): dispatch_target_cls.dispatch = slots_dispatcher(cls) else: dispatch_target_cls.dispatch = dispatcher(cls) -class Events(_HasEventsDispatch): +class Events(_HasEventsDispatch[_ET]): """Define event listening functions for a particular target type.""" @classmethod - def _accept_with(cls, target): - def dispatch_is(*types): + def _accept_with( + cls, target: Union[_ET, Type[_ET]] + ) -> Optional[Union[_ET, Type[_ET]]]: + def dispatch_is(*types: Type[Any]) -> bool: return all(isinstance(target.dispatch, t) for t in types) - def dispatch_parent_is(t): - return isinstance(target.dispatch.parent, t) + def dispatch_parent_is(t: Type[Any]) -> bool: + + return isinstance( + cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t + ) # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. @@ -282,39 +333,45 @@ class Events(_HasEventsDispatch): ): return target + return None + @classmethod def _listen( cls, - event_key, - propagate=False, - insert=False, - named=False, - asyncio=False, - ): + event_key: _EventKey[_ET], + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: event_key.base_listen( propagate=propagate, insert=insert, named=named, asyncio=asyncio ) @classmethod - def _remove(cls, event_key): + def _remove(cls, event_key: _EventKey[_ET]) -> None: event_key.remove() @classmethod - def _clear(cls): + def _clear(cls) -> None: cls.dispatch._clear() -class _JoinedDispatcher: +class _JoinedDispatcher(Generic[_ET]): """Represent a connection between two _Dispatch objects.""" __slots__ = "local", "parent", "_instance_cls" - def __init__(self, local, parent): + local: _Dispatch[_ET] + parent: _Dispatch[_ET] + _instance_cls: Optional[Type[_ET]] + + def __init__(self, local: _Dispatch[_ET], parent: _Dispatch[_ET]): self.local = local self.parent = parent self._instance_cls = self.local._instance_cls - def __getattr__(self, name): + def __getattr__(self, name: str) -> _JoinedListener[_ET]: # Assign _JoinedListeners as attributes on demand # to reduce startup time for new dispatch objects. ls = getattr(self.local, name) @@ -322,16 +379,15 @@ class _JoinedDispatcher: setattr(self, ls.name, jl) return jl - @property - def _listen(self): - return self.parent._listen + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self.parent._listen(event_key, **kw) @property - def _events(self): + def _events(self) -> Type[_HasEventsDispatch[_ET]]: return self.parent._events -class dispatcher: +class dispatcher(Generic[_ET]): """Descriptor used by target classes to deliver the _Dispatch class at the class level and produce new _Dispatch instances for target @@ -339,11 +395,21 @@ class dispatcher: """ - def __init__(self, events): + def __init__(self, events: Type[_HasEventsDispatch[_ET]]): self.dispatch = events.dispatch self.events = events - def __get__(self, obj, cls): + @overload + def __get__( + self, obj: Literal[None], cls: Type[Any] + ) -> Type[_HasEventsDispatch[_ET]]: + ... + + @overload + def __get__(self, obj: Any, cls: Type[Any]) -> _HasEventsDispatch[_ET]: + ... + + def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: return self.dispatch @@ -358,8 +424,8 @@ class dispatcher: return disp -class slots_dispatcher(dispatcher): - def __get__(self, obj, cls): +class slots_dispatcher(dispatcher[_ET]): + def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: return self.dispatch diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 053b47eaac..75e5be7fe0 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -9,11 +9,34 @@ generation of deprecation notes and docstrings. """ - +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type + +from .registry import _ET +from .registry import _ListenerFnType from .. import util +from ..util.compat import FullArgSpec + +if typing.TYPE_CHECKING: + from .attr import _ClsLevelDispatch + from .base import _HasEventsDispatch + + +_LegacySignatureType = Tuple[str, List[str], Optional[Callable[..., Any]]] -def _legacy_signature(since, argnames, converter=None): +def _legacy_signature( + since: str, + argnames: List[str], + converter: Optional[Callable[..., Any]] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """legacy sig decorator @@ -25,16 +48,20 @@ def _legacy_signature(since, argnames, converter=None): """ - def leg(fn): + def leg(fn: Callable[..., Any]) -> Callable[..., Any]: if not hasattr(fn, "_legacy_signatures"): - fn._legacy_signatures = [] - fn._legacy_signatures.append((since, argnames, converter)) + fn._legacy_signatures = [] # type: ignore[attr-defined] + fn._legacy_signatures.append((since, argnames, converter)) # type: ignore[attr-defined] # noqa E501 return fn return leg -def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): +def _wrap_fn_for_legacy( + dispatch_collection: "_ClsLevelDispatch[_ET]", + fn: _ListenerFnType, + argspec: FullArgSpec, +) -> _ListenerFnType: for since, argnames, conv in dispatch_collection.legacy_signatures: if argnames[-1] == "**kw": has_kw = True @@ -64,34 +91,39 @@ def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): ) ) - if conv: + if conv is not None: assert not has_kw - def wrap_leg(*args): + def wrap_leg(*args: Any, **kw: Any) -> Any: util.warn_deprecated(warning_txt, version=since) + assert conv is not None return fn(*conv(*args)) else: - def wrap_leg(*args, **kw): + def wrap_leg(*args: Any, **kw: Any) -> Any: util.warn_deprecated(warning_txt, version=since) argdict = dict(zip(dispatch_collection.arg_names, args)) - args = [argdict[name] for name in argnames] + args_from_dict = [argdict[name] for name in argnames] if has_kw: - return fn(*args, **kw) + return fn(*args_from_dict, **kw) else: - return fn(*args) + return fn(*args_from_dict) return wrap_leg else: return fn -def _indent(text, indent): +def _indent(text: str, indent: str) -> str: return "\n".join(indent + line for line in text.split("\n")) -def _standard_listen_example(dispatch_collection, sample_target, fn): +def _standard_listen_example( + dispatch_collection: "_ClsLevelDispatch[_ET]", + sample_target: Any, + fn: _ListenerFnType, +) -> str: example_kw_arg = _indent( "\n".join( "%(arg)s = kw['%(arg)s']" % {"arg": arg} @@ -128,7 +160,11 @@ def _standard_listen_example(dispatch_collection, sample_target, fn): return text -def _legacy_listen_examples(dispatch_collection, sample_target, fn): +def _legacy_listen_examples( + dispatch_collection: "_ClsLevelDispatch[_ET]", + sample_target: str, + fn: _ListenerFnType, +) -> str: text = "" for since, args, conv in dispatch_collection.legacy_signatures: text += ( @@ -152,7 +188,10 @@ def _legacy_listen_examples(dispatch_collection, sample_target, fn): return text -def _version_signature_changes(parent_dispatch_cls, dispatch_collection): +def _version_signature_changes( + parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], + dispatch_collection: "_ClsLevelDispatch[_ET]", +) -> str: since, args, conv = dispatch_collection.legacy_signatures[0] return ( "\n.. deprecated:: %(since)s\n" @@ -171,7 +210,11 @@ def _version_signature_changes(parent_dispatch_cls, dispatch_collection): ) -def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn): +def _augment_fn_docs( + dispatch_collection: "_ClsLevelDispatch[_ET]", + parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], + fn: _ListenerFnType, +) -> str: header = ( ".. container:: event_signatures\n\n" " Example argument forms::\n" diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index d831a332fc..e20d3e0b53 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -14,15 +14,60 @@ membership in all those collections can be revoked at once, based on an equivalent :class:`._EventKey`. """ +from __future__ import annotations + import collections import types +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Deque +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union import weakref from .. import exc from .. import util +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from .attr import RefCollection + from .base import dispatcher + +_ListenerFnType = Callable[..., Any] +_ListenerFnKeyType = Union[int, Tuple[int, int]] +_EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType] + + +class _EventTargetType(Protocol): + """represents an event target, that is, something we can listen on + either with that target as a class or as an instance. + + Examples include: Connection, Mapper, Table, Session, + InstrumentedAttribute, Engine, Pool, Dialect. + """ -_key_to_collection = collections.defaultdict(dict) + dispatch: ClassVar[dispatcher[Any]] + + +_ET = TypeVar("_ET", bound=_EventTargetType) + +_RefCollectionToListenerType = Dict[ + "weakref.ref[RefCollection[Any]]", + "weakref.ref[_ListenerFnType]", +] + +_key_to_collection: Dict[ + _EventKeyTupleType, _RefCollectionToListenerType +] = collections.defaultdict(dict) """ Given an original listen() argument, can locate all listener collections and the listener fn contained @@ -34,7 +79,14 @@ listener collections and the listener fn contained } """ -_collection_to_key = collections.defaultdict(dict) +_ListenerToEventKeyType = Dict[ + "weakref.ref[_ListenerFnType]", + _EventKeyTupleType, +] +_collection_to_key: Dict[ + weakref.ref[RefCollection[Any]], + _ListenerToEventKeyType, +] = collections.defaultdict(dict) """ Given a _ListenerCollection or _ClsLevelListener, can locate all the original listen() arguments and the listener fn contained @@ -47,10 +99,13 @@ ref(listenercollection) -> { """ -def _collection_gced(ref): +def _collection_gced(ref: weakref.ref[Any]) -> None: # defaultdict, so can't get a KeyError if not _collection_to_key or ref not in _collection_to_key: return + + ref = cast("weakref.ref[RefCollection[_EventTargetType]]", ref) + listener_to_key = _collection_to_key.pop(ref) for key in listener_to_key.values(): if key in _key_to_collection: @@ -61,7 +116,9 @@ def _collection_gced(ref): _key_to_collection.pop(key) -def _stored_in_collection(event_key, owner): +def _stored_in_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> bool: key = event_key._key dispatch_reg = _key_to_collection[key] @@ -80,7 +137,9 @@ def _stored_in_collection(event_key, owner): return True -def _removed_from_collection(event_key, owner): +def _removed_from_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> None: key = event_key._key dispatch_reg = _key_to_collection[key] @@ -97,15 +156,19 @@ def _removed_from_collection(event_key, owner): listener_to_key.pop(listen_ref) -def _stored_in_collection_multi(newowner, oldowner, elements): +def _stored_in_collection_multi( + newowner: RefCollection[_ET], + oldowner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: if not elements: return - oldowner = oldowner.ref - newowner = newowner.ref + oldowner_ref = oldowner.ref + newowner_ref = newowner.ref - old_listener_to_key = _collection_to_key[oldowner] - new_listener_to_key = _collection_to_key[newowner] + old_listener_to_key = _collection_to_key[oldowner_ref] + new_listener_to_key = _collection_to_key[newowner_ref] for listen_fn in elements: listen_ref = weakref.ref(listen_fn) @@ -121,31 +184,34 @@ def _stored_in_collection_multi(newowner, oldowner, elements): except KeyError: continue - if newowner in dispatch_reg: - assert dispatch_reg[newowner] == listen_ref + if newowner_ref in dispatch_reg: + assert dispatch_reg[newowner_ref] == listen_ref else: - dispatch_reg[newowner] = listen_ref + dispatch_reg[newowner_ref] = listen_ref new_listener_to_key[listen_ref] = key -def _clear(owner, elements): +def _clear( + owner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: if not elements: return - owner = owner.ref - listener_to_key = _collection_to_key[owner] + owner_ref = owner.ref + listener_to_key = _collection_to_key[owner_ref] for listen_fn in elements: listen_ref = weakref.ref(listen_fn) key = listener_to_key[listen_ref] dispatch_reg = _key_to_collection[key] - dispatch_reg.pop(owner, None) + dispatch_reg.pop(owner_ref, None) if not dispatch_reg: del _key_to_collection[key] -class _EventKey: +class _EventKey(Generic[_ET]): """Represent :func:`.listen` arguments.""" __slots__ = ( @@ -157,10 +223,24 @@ class _EventKey: "dispatch_target", ) - def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None): + target: _ET + identifier: str + fn: _ListenerFnType + fn_key: _ListenerFnKeyType + dispatch_target: Any + _fn_wrap: Optional[_ListenerFnType] + + def __init__( + self, + target: _ET, + identifier: str, + fn: _ListenerFnType, + dispatch_target: Any, + _fn_wrap: Optional[_ListenerFnType] = None, + ): self.target = target self.identifier = identifier - self.fn = fn + self.fn = fn # type: ignore[assignment] if isinstance(fn, types.MethodType): self.fn_key = id(fn.__func__), id(fn.__self__) else: @@ -169,10 +249,10 @@ class _EventKey: self.dispatch_target = dispatch_target @property - def _key(self): + def _key(self) -> _EventKeyTupleType: return (id(self.target), self.identifier, self.fn_key) - def with_wrapper(self, fn_wrap): + def with_wrapper(self, fn_wrap: _ListenerFnType) -> _EventKey[_ET]: if fn_wrap is self._listen_fn: return self else: @@ -184,7 +264,7 @@ class _EventKey: _fn_wrap=fn_wrap, ) - def with_dispatch_target(self, dispatch_target): + def with_dispatch_target(self, dispatch_target: Any) -> _EventKey[_ET]: if dispatch_target is self.dispatch_target: return self else: @@ -196,7 +276,7 @@ class _EventKey: _fn_wrap=self.fn_wrap, ) - def listen(self, *args, **kw): + def listen(self, *args: Any, **kw: Any) -> None: once = kw.pop("once", False) once_unless_exception = kw.pop("_once_unless_exception", False) named = kw.pop("named", False) @@ -228,7 +308,7 @@ class _EventKey: else: self.dispatch_target.dispatch._listen(self, *args, **kw) - def remove(self): + def remove(self) -> None: key = self._key if key not in _key_to_collection: @@ -245,18 +325,18 @@ class _EventKey: if collection is not None and listener_fn is not None: collection.remove(self.with_wrapper(listener_fn)) - def contains(self): + def contains(self) -> bool: """Return True if this event key is registered to listen.""" return self._key in _key_to_collection def base_listen( self, - propagate=False, - insert=False, - named=False, - retval=None, - asyncio=False, - ): + propagate: bool = False, + insert: bool = False, + named: bool = False, + retval: Optional[bool] = None, + asyncio: bool = False, + ) -> None: target, identifier = self.dispatch_target, self.identifier @@ -272,21 +352,33 @@ class _EventKey: for_modify.append(self, propagate) @property - def _listen_fn(self): + def _listen_fn(self) -> _ListenerFnType: return self.fn_wrap or self.fn - def append_to_list(self, owner, list_): + def append_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: if _stored_in_collection(self, owner): list_.append(self._listen_fn) return True else: return False - def remove_from_list(self, owner, list_): + def remove_from_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> None: _removed_from_collection(self, owner) list_.remove(self._listen_fn) - def prepend_to_list(self, owner, list_): + def prepend_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: if _stored_in_collection(self, owner): list_.appendleft(self._listen_fn) return True diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index d17b0b12f5..4a8a523375 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -7,6 +7,8 @@ """Core event interfaces.""" +from __future__ import annotations + from .engine.events import ConnectionEvents from .engine.events import DialectEvents from .pool.events import PoolEvents diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 6732edd4e8..f39f4cd8fa 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -12,6 +12,8 @@ raised as a result of DBAPI exceptions are all subclasses of :exc:`.DBAPIError`. """ +from __future__ import annotations + import typing from typing import Any from typing import List diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 4e244b5b9e..bfc3459d03 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from typing import List from typing import Optional from typing import Union diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index bd6c6f41e8..13ba2e6662 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from typing import List from typing import Optional from typing import Union diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index 6a5e99e480..08035d74cd 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from typing import Optional from typing import Sequence diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index ad4449e5bb..8232ca6dbd 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from typing import Dict from typing import List from typing import Optional diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index c9520fef33..3a78ab188c 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -9,6 +9,8 @@ Mypy plugin for SQLAlchemy ORM. """ +from __future__ import annotations + from typing import Callable from typing import List from typing import Optional diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 741772eacd..943c71b367 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from typing import Any from typing import Iterable diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 5384851b10..612b627245 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -119,6 +119,8 @@ start numbering at 1 or some other integer, provide ``count_from=1``. """ +from __future__ import annotations + from typing import Callable from typing import List from typing import Optional diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index c6e9ca69af..6b06c0d6b6 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -28,6 +28,8 @@ tools which build on top of SQLAlchemy configurations to be constructed in a forwards-compatible way. """ +from __future__ import annotations + from typing import Any from typing import Callable from typing import Dict diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index c6a8b6ea7f..2f63b8569d 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -17,6 +17,8 @@ and :class:`_pool.Pool` objects, corresponds to a logger specific to that instance only. """ +from __future__ import annotations + import logging import sys from typing import Any diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index bbed933104..5a8a0f6cf6 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -13,6 +13,8 @@ documentation for an overview of how this module is used. """ +from __future__ import annotations + from . import exc as exc from . import mapper as mapperlib from . import strategy_options as strategy_options diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index a1f1faa053..8e05c6ef28 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import typing from typing import Any from typing import Collection @@ -59,7 +61,7 @@ SynonymProperty = Synonym "for entities to be matched up to a query that is established " "via :meth:`.Query.from_statement` and now does nothing.", ) -def contains_alias(alias) -> "AliasOption": +def contains_alias(alias) -> AliasOption: r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. @@ -74,7 +76,7 @@ def contains_alias(alias) -> "AliasOption": @overload def mapped_column( - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: Literal[None] = ..., primary_key: Literal[None] = ..., @@ -87,7 +89,7 @@ def mapped_column( @overload def mapped_column( __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: Literal[None] = ..., primary_key: Literal[None] = ..., @@ -100,7 +102,7 @@ def mapped_column( @overload def mapped_column( __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: Literal[True] = ..., primary_key: Literal[None] = ..., @@ -112,7 +114,7 @@ def mapped_column( @overload def mapped_column( - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: Literal[True] = ..., primary_key: Literal[None] = ..., @@ -125,7 +127,7 @@ def mapped_column( @overload def mapped_column( __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: Literal[False] = ..., primary_key: Literal[None] = ..., @@ -137,7 +139,7 @@ def mapped_column( @overload def mapped_column( - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: Literal[False] = ..., primary_key: Literal[None] = ..., @@ -149,7 +151,7 @@ def mapped_column( @overload def mapped_column( - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: bool = ..., primary_key: Literal[True] = ..., @@ -162,7 +164,7 @@ def mapped_column( @overload def mapped_column( __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], *args: SchemaEventTarget, nullable: bool = ..., primary_key: Literal[True] = ..., diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index fbfb2b2ee0..c4afdb3a9e 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -13,6 +13,9 @@ defines a large part of the ORM's interactivity. """ + +from __future__ import annotations + from collections import namedtuple import operator from typing import Any diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index e6d4a67298..33367c0c65 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -9,6 +9,8 @@ """ +from __future__ import annotations + import operator import typing from typing import Any diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 037b70257b..d0cb53e29b 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -10,6 +10,9 @@ This system allows specification of classes and expressions used in :func:`_orm.relationship` using strings. """ + +from __future__ import annotations + import re from typing import MutableMapping from typing import Union diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index ba4225563d..00ae9dac75 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -102,6 +102,8 @@ The owning object and :class:`.CollectionAttributeImpl` are also reachable through the adapter, allowing for some very sophisticated behavior. """ +from __future__ import annotations + import operator import threading import typing diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 34f291864f..f51abde0c3 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -4,6 +4,9 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + import itertools from typing import List diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 5ac9966dd0..4e28eeff7f 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -5,6 +5,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php """Public API functions and helpers for declarative.""" + +from __future__ import annotations + import itertools import re import typing diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 14812f2c29..d05d27b0e0 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -9,6 +9,8 @@ """ +from __future__ import annotations + from . import attributes from . import exc from . import sync diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 4526a8b332..4616e40945 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -10,6 +10,8 @@ that exist as configurational elements, but don't participate as actively in the load/persist ORM loop. """ +from __future__ import annotations + import inspect import itertools import operator diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 3d9c61c205..52a6ec4b00 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -12,6 +12,8 @@ basic add/delete mutation. """ +from __future__ import annotations + from . import attributes from . import exc as orm_exc from . import interfaces diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 61e3f6e909..453fc8903c 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import operator from .. import exc diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index cc4bedc3f8..e62a833975 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -8,6 +8,8 @@ """ORM event interfaces. """ +from __future__ import annotations + import weakref from . import instrumentation diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index 8dd4d90d68..f70ea78373 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -6,6 +6,9 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php """SQLAlchemy ORM exceptions.""" + +from __future__ import annotations + from .. import exc as sa_exc from .. import util from ..exc import MultipleResultsFound # noqa diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index f8204ec775..3caf0b22fb 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import weakref from . import util as orm_util diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 5d0b572060..a050c533a5 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -30,6 +30,8 @@ alternate instrumentation forms. """ +from __future__ import annotations + from . import base from . import collections from . import exc diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 1f9ec78f76..eed9735263 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -16,6 +16,8 @@ are exposed when inspecting mappings. """ +from __future__ import annotations + import collections import typing from typing import Any diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index a40437c67e..16ab1f4c84 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -13,6 +13,8 @@ as well as some of the attribute loading strategies. """ +from __future__ import annotations + from . import attributes from . import exc as orm_exc from . import path_registry diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py index 75abeef4cd..4324a000d1 100644 --- a/lib/sqlalchemy/orm/mapped_collection.py +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import operator from typing import Any from typing import Callable diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index cd0d1e8203..15e9b84311 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -14,6 +14,8 @@ This is a semi-private module; the main configurational API of the ORM is available in :class:`~sqlalchemy.orm.`. """ +from __future__ import annotations + from collections import deque from functools import reduce from itertools import chain diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 0d87739cc1..9a7aa91a03 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -8,6 +8,8 @@ """ +from __future__ import annotations + from functools import reduce from itertools import chain import logging diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index b3381b0390..519eb393f6 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -13,6 +13,8 @@ The functions here are called only by the unit of work functions in unitofwork.py. """ +from __future__ import annotations + from itertools import chain from itertools import groupby from itertools import zip_longest diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index f28c45fab8..9f9ca90cb4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -12,6 +12,8 @@ mapped attributes. """ +from __future__ import annotations + from typing import Any from typing import cast from typing import List diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 61174487ad..3b4b082a4c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -18,6 +18,8 @@ ORM session, whereas the ``Select`` construct interacts directly with the database to return iterable result sets. """ +from __future__ import annotations + import collections.abc as collections_abc import itertools import operator diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 47da85716f..9eb80a1072 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from . import exc as orm_exc from .base import class_mapper from .session import Session diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 6911ab5058..4140d52c5c 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -6,6 +6,8 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php """Provides the Session class and related utilities.""" +from __future__ import annotations + import contextlib import itertools import sys diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 01ee16a13f..58fa3e41a1 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -12,6 +12,8 @@ defines a large part of the ORM's interactivity. """ +from __future__ import annotations + import weakref from . import base diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py index 1421c1ae93..1afeab05bc 100644 --- a/lib/sqlalchemy/orm/state_changes.py +++ b/lib/sqlalchemy/orm/state_changes.py @@ -8,6 +8,8 @@ """ +from __future__ import annotations + import contextlib from enum import Enum from typing import Any diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 316aa7ed73..85e0151937 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -8,6 +8,8 @@ """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" +from __future__ import annotations + import collections import itertools diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 3f093e543d..63679dd275 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -8,6 +8,8 @@ """ +from __future__ import annotations + import typing from typing import Any from typing import cast diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 2994841f58..a49bd6f8ee 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -10,6 +10,8 @@ between instances based on join conditions. """ +from __future__ import annotations + from . import attributes from . import exc from . import util as orm_util diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index b478f427cc..da098e8c5f 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -13,6 +13,8 @@ organizes them in order of dependency, and executes. """ +from __future__ import annotations + from . import attributes from . import exc as orm_exc from . import util as orm_util diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 45c578355a..e00e059546 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import re import types import typing diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 6c770e201c..72c56716f1 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -10,6 +10,8 @@ """ +from __future__ import annotations + from collections import deque import time from typing import Any diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index b2ca1cfefa..70a982ce24 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -9,6 +9,8 @@ """ +from __future__ import annotations + from .sql.base import SchemaVisitor as SchemaVisitor from .sql.ddl import _CreateDropBase as _CreateDropBase from .sql.ddl import _DDLCompiles as _DDLCompiles diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py index e62edf5e61..a8c24413fc 100644 --- a/lib/sqlalchemy/sql/_dml_constructors.py +++ b/lib/sqlalchemy/sql/_dml_constructors.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from .dml import Delete from .dml import Insert from .dml import Update diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index a8c9372e0f..4132ac6798 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import typing from typing import Any from typing import cast as _typing_cast diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py index 594967a40b..96e8f6b2c7 100644 --- a/lib/sqlalchemy/sql/_py_util.py +++ b/lib/sqlalchemy/sql/_py_util.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from typing import Dict diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index d3cf207da0..9043aa6d05 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from typing import Any from typing import Union diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 4d2dd26884..7d8b9ee5c4 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any from typing import Mapping from typing import Sequence diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index c879bfc2d3..b76393ad6b 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -11,6 +11,8 @@ associations. """ +from __future__ import annotations + from . import operators from .base import HasCacheKey from .traversals import anon_map diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 5828f9369d..3936ed9c63 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -10,6 +10,8 @@ """ +from __future__ import annotations + import collections.abc as collections_abc from functools import reduce import itertools diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 42bd603537..49f1899d5a 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from collections import namedtuple import enum from itertools import zip_longest diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index d5a75a1658..d616417ab3 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import collections.abc as collections_abc import numbers import re diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index bf78b4231a..4a169f719d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -22,6 +22,8 @@ To generate user-defined SQL strings, see :doc:`/ext/compiler`. """ +from __future__ import annotations + import collections import collections.abc as collections_abc import contextlib diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 4a0555bf48..4292aa9162 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -9,6 +9,8 @@ within INSERT and UPDATE statements. """ +from __future__ import annotations + import functools import operator diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index f622023b02..7acb69bebb 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -9,6 +9,8 @@ Provides the hierarchy of DDL-defining schema items as well as routines to invoke them for a create/drop call. """ +from __future__ import annotations + import typing from typing import Callable from typing import List diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 1759e686ef..001710d7bb 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -8,6 +8,8 @@ """Default implementation of SQL comparison operations. """ +from __future__ import annotations + import typing from typing import Any from typing import Callable diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 33dca66cd5..5aded307b6 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,6 +9,8 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and :class:`_expression.Delete`. """ +from __future__ import annotations + import collections.abc as collections_abc import typing diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index d14521ba73..0c532a135a 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -10,6 +10,8 @@ """ +from __future__ import annotations + import itertools import operator import re diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py index c42578986d..1a1fc4c417 100644 --- a/lib/sqlalchemy/sql/events.py +++ b/lib/sqlalchemy/sql/events.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from .base import SchemaEventTarget from .. import event diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 22195cd7c5..36ddbf309b 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -11,6 +11,8 @@ """ +from __future__ import annotations + from ._dml_constructors import delete as delete from ._dml_constructors import insert as insert from ._dml_constructors import update as update diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 2e6d64c552..eb3d17ee46 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -9,6 +9,8 @@ """ +from __future__ import annotations + from typing import Any from typing import TypeVar diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index ae7358870f..9d011ef539 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import collections.abc as collections_abc import inspect import itertools diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index 15a1566a6f..9b6fcdbae8 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -10,6 +10,8 @@ """ +from __future__ import annotations + import re from . import events # noqa diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 255e77b7f9..d4fa8042dd 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -10,6 +10,8 @@ """Defines operators used in SQL expressions.""" +from __future__ import annotations + from operator import add as _uncast_add from operator import and_ as _uncast_and_ from operator import contains as _uncast_contains diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index b41ef7a5d1..64bd4b951b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -4,6 +4,8 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import typing from sqlalchemy.util.langhelpers import TypingOnly diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 9387ae030c..5286917959 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -28,6 +28,8 @@ Since these objects are part of the SQL expression language, they are usable as components in SQL expressions. """ +from __future__ import annotations + import collections import typing from typing import Any diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index b0985f75d8..7f6360edb0 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -11,6 +11,8 @@ SQL tables and derived rowsets. """ +from __future__ import annotations + import collections from enum import Enum import itertools diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 0ec771cb43..819f1dc9a8 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -8,6 +8,8 @@ """SQL specific types. """ +from __future__ import annotations + import collections.abc as collections_abc import datetime as dt import decimal diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 18fd1d4b81..4fa23d370c 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from collections import deque import collections.abc as collections_abc import itertools diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 6b878dc70b..f76b4e4621 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -9,6 +9,8 @@ """ +from __future__ import annotations + import typing from typing import Any from typing import Callable diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index c0de1902ff..e3e358cdb9 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -8,6 +8,8 @@ """High level utilities which build upon other modules here. """ +from __future__ import annotations + from collections import deque from itertools import chain import typing diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 640c07d618..523426d092 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -23,6 +23,8 @@ https://techspot.zzzeek.org/2008/01/23/expression-transformations/ . """ +from __future__ import annotations + from collections import deque import itertools import operator diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 5c79422dd1..2e46ed8245 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import contextlib from itertools import filterfalse import re diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 5c3634c7b5..e6c00de097 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import collections import contextlib import re diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py index 21890604a3..0acec0def9 100644 --- a/lib/sqlalchemy/testing/asyncio.py +++ b/lib/sqlalchemy/testing/asyncio.py @@ -17,6 +17,8 @@ # would run in the real world. +from __future__ import annotations + from functools import wraps import inspect diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index c1ca670dac..04a6a1d3ac 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import collections import typing from typing import Any diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 1d337b28f9..79adb8c3cd 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import collections import re import typing diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index 8578ca9300..67a3095706 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import sqlalchemy as sa from .. import exc as sa_exc diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 7228e5afeb..6c6b21fcec 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import re import sys diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index 5d32231889..f336444a26 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -9,6 +9,8 @@ unpickling. """ +from __future__ import annotations + from . import fixtures diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index d02b94de6f..6fc5efc50a 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -12,6 +12,8 @@ in a more fine-grained way than nose's profiling plugin. """ +from __future__ import annotations + import collections import contextlib import os diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index e51eb172e4..6e5555b330 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections import logging diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 41e5d6772d..a5c601995e 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -15,6 +15,8 @@ to provide specific inclusion/exclusions. """ +from __future__ import annotations + import platform from . import config diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 78bc4d2693..ca725976bc 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import sys from . import config diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 52f30e1890..0cba4e16b1 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import decimal import gc import random diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 2d65e68ec4..e82566be76 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -4,6 +4,8 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + from . import assertions from .. import exc as sa_exc from ..exc import SATestSuiteWarning diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 07263c5b9e..9464cc9c4e 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -9,56 +9,8 @@ """ -__all__ = [ - "TypeEngine", - "TypeDecorator", - "UserDefinedType", - "ExternalType", - "INT", - "CHAR", - "VARCHAR", - "NCHAR", - "NVARCHAR", - "TEXT", - "Text", - "FLOAT", - "NUMERIC", - "REAL", - "DECIMAL", - "TIMESTAMP", - "DATETIME", - "CLOB", - "BLOB", - "BINARY", - "VARBINARY", - "BOOLEAN", - "BIGINT", - "SMALLINT", - "INTEGER", - "DATE", - "TIME", - "TupleType", - "String", - "Integer", - "SmallInteger", - "BigInteger", - "Numeric", - "Float", - "DateTime", - "Date", - "Time", - "LargeBinary", - "Boolean", - "Unicode", - "Concatenable", - "UnicodeText", - "PickleType", - "Interval", - "Enum", - "Indexable", - "ARRAY", - "JSON", -] + +from __future__ import annotations from .sql.sqltypes import _Binary from .sql.sqltypes import ARRAY @@ -117,3 +69,54 @@ from .sql.type_api import TypeDecorator from .sql.type_api import TypeEngine from .sql.type_api import UserDefinedType from .sql.type_api import Variant + +__all__ = [ + "TypeEngine", + "TypeDecorator", + "UserDefinedType", + "ExternalType", + "INT", + "CHAR", + "VARCHAR", + "NCHAR", + "NVARCHAR", + "TEXT", + "Text", + "FLOAT", + "NUMERIC", + "REAL", + "DECIMAL", + "TIMESTAMP", + "DATETIME", + "CLOB", + "BLOB", + "BINARY", + "VARBINARY", + "BOOLEAN", + "BIGINT", + "SMALLINT", + "INTEGER", + "DATE", + "TIME", + "TupleType", + "String", + "Integer", + "SmallInteger", + "BigInteger", + "Numeric", + "Float", + "DateTime", + "Date", + "Time", + "LargeBinary", + "Boolean", + "Unicode", + "Concatenable", + "UnicodeText", + "PickleType", + "Interval", + "Enum", + "Indexable", + "ARRAY", + "JSON", +] diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 8509868028..bbb08d91f2 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -6,6 +6,8 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php """Collection classes and helpers.""" +from __future__ import annotations + import collections.abc as collections_abc import operator import threading diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index b9e58e68cd..fa20667027 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -157,8 +157,7 @@ class AsyncAdaptedLock: def __enter__(self): # await is used to acquire the lock only after the first calling # coroutine has created the mutex. - await_fallback(self.mutex.acquire()) - return self + return await_fallback(self.mutex.acquire()) def __exit__(self, *arg, **kw): self.mutex.release() diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/_preloaded.py index b0f8ab444a..511b93351d 100644 --- a/lib/sqlalchemy/util/_preloaded.py +++ b/lib/sqlalchemy/util/_preloaded.py @@ -9,6 +9,8 @@ runtime. """ +from __future__ import annotations + import sys from types import ModuleType import typing diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 9bf5c3546d..ee54180ac4 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -1,3 +1,12 @@ +# util/_py_collections.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + from itertools import filterfalse from typing import AbstractSet from typing import Any diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 62cffa556e..d30236dd9d 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -6,6 +6,7 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php """Handle Python version/platform incompatibilities.""" + from __future__ import annotations import base64 @@ -137,6 +138,9 @@ def cmp(a, b): def _formatannotation(annotation, base_module=None): """vendored from python 3.7""" + if isinstance(annotation, str): + return f'"{annotation}"' + if getattr(annotation, "__module__", None) == "typing": return f'"{repr(annotation).replace("typing.", "").replace("~", "")}"' if isinstance(annotation, type): diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 6b94a22948..778d1275b9 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -4,8 +4,10 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations import asyncio # noqa +import typing have_greenlet = False greenlet_error = None @@ -28,7 +30,7 @@ else: _util_async_run_coroutine_function as _util_async_run_coroutine_function, # noqa F401, E501 ) -if not have_greenlet: +if not typing.TYPE_CHECKING and not have_greenlet: def _not_implemented(): # this conditional is to prevent pylance from considering diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 7c25861665..f91d902dae 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -8,12 +8,15 @@ """Helpers related to deprecation of functions, methods, classes, other functionality.""" +from __future__ import annotations + import re from typing import Any from typing import Callable from typing import cast from typing import Optional from typing import Tuple +from typing import Type from typing import TypeVar from . import compat @@ -28,14 +31,22 @@ from .. import exc _T = TypeVar("_T", bound=Any) -def _warn_with_version(msg, version, type_, stacklevel, code=None): +def _warn_with_version( + msg: str, + version: str, + type_: Type[exc.SADeprecationWarning], + stacklevel: int, + code: Optional[str] = None, +) -> None: warn = type_(msg, code=code) warn.deprecated_since = version _warnings_warn(warn, stacklevel=stacklevel + 1) -def warn_deprecated(msg, version, stacklevel=3, code=None): +def warn_deprecated( + msg: str, version: str, stacklevel: int = 3, code: Optional[str] = None +) -> None: _warn_with_version( msg, version, exc.SADeprecationWarning, stacklevel, code=code ) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index ed879894d5..9e024b3c03 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -9,6 +9,7 @@ modules, classes, hierarchies, attributes, functions, and methods. """ +from __future__ import annotations import collections from functools import update_wrapper @@ -452,7 +453,9 @@ def get_func_kwargs(func): return compat.inspect_getfullargspec(func)[0] -def get_callable_argspec(fn, no_self=False, _is_init=False): +def get_callable_argspec( + fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False +) -> compat.FullArgSpec: """Return the argument signature for any callable. All pure-Python callables are accepted, including @@ -496,10 +499,12 @@ def get_callable_argspec(fn, no_self=False, _is_init=False): fn.__init__, no_self=no_self, _is_init=True ) elif hasattr(fn, "__func__"): - return compat.inspect_getfullargspec(fn.__func__) + return compat.inspect_getfullargspec(fn.__func__) # type: ignore[attr-defined] # noqa E501 elif hasattr(fn, "__call__"): - if inspect.ismethod(fn.__call__): - return get_callable_argspec(fn.__call__, no_self=no_self) + if inspect.ismethod(fn.__call__): # type: ignore [operator] + return get_callable_argspec( + fn.__call__, no_self=no_self # type: ignore [operator] + ) else: raise TypeError("Can't inspect callable: %s" % fn) else: @@ -1521,7 +1526,12 @@ class hybridmethod: class _symbol(int): name: str - def __new__(cls, name, doc=None, canonical=None): + def __new__( + cls, + name: str, + doc: Optional[str] = None, + canonical: Optional[int] = None, + ) -> "_symbol": """Construct a new named symbol.""" assert isinstance(name, str) if canonical is None: @@ -1570,7 +1580,12 @@ class symbol: symbols: Dict[str, "_symbol"] = {} _lock = threading.Lock() - def __new__(cls, name, doc=None, canonical=None): + def __new__( # type: ignore[misc] + cls, + name: str, + doc: Optional[str] = None, + canonical: Optional[int] = None, + ) -> _symbol: with cls._lock: sym = cls.symbols.get(name) if sym is None: @@ -1730,13 +1745,15 @@ def _warnings_warn(message, category=None, stacklevel=2): warnings.warn(message, stacklevel=stacklevel + 1) -def only_once(fn, retry_on_exception): +def only_once( + fn: Callable[..., _T], retry_on_exception: bool +) -> Callable[..., Optional[_T]]: """Decorate the given function to be a no-op after it is called exactly once.""" once = [fn] - def go(*arg, **kw): + def go(*arg: Any, **kw: Any) -> Optional[_T]: # strong reference fn so that it isn't garbage collected, # which interferes with the event system's expectations strong_fn = fn # noqa @@ -1749,6 +1766,8 @@ def only_once(fn, retry_on_exception): once.insert(0, once_fn) raise + return None + return go @@ -1936,7 +1955,7 @@ def add_parameter_text(params, text): return decorate -def _dedent_docstring(text): +def _dedent_docstring(text: str) -> str: split_text = text.split("\n", 1) if len(split_text) == 1: return text @@ -1948,8 +1967,10 @@ def _dedent_docstring(text): return textwrap.dedent(text) -def inject_docstring_text(doctext, injecttext, pos): - doctext = _dedent_docstring(doctext or "") +def inject_docstring_text( + given_doctext: Optional[str], injecttext: str, pos: int +) -> str: + doctext: str = _dedent_docstring(given_doctext or "") lines = doctext.split("\n") if len(lines) == 1: lines.append("") @@ -1969,7 +1990,7 @@ def inject_docstring_text(doctext, injecttext, pos): _param_reg = re.compile(r"(\s+):param (.+?):") -def inject_param_text(doctext, inject_params): +def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str: doclines = collections.deque(doctext.splitlines()) lines = [] @@ -2012,7 +2033,7 @@ def inject_param_text(doctext, inject_params): return "\n".join(lines) -def repr_tuple_names(names): +def repr_tuple_names(names: List[str]) -> Optional[str]: """Trims a list of strings from the middle and return a string of up to four elements. Strings greater than 11 characters will be truncated""" if len(names) == 0: diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index bbc819fc31..bccb16672c 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -7,6 +7,8 @@ """Topological sorting algorithms.""" +from __future__ import annotations + from .. import util from ..exc import CircularDependencyError diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 56ea4d0e06..404f239c89 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import typing from typing import Any diff --git a/pyproject.toml b/pyproject.toml index be5dd15962..b6f0952390 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ markers = [ [tool.pyright] include = [ + "lib/sqlalchemy/event/", "lib/sqlalchemy/events.py", "lib/sqlalchemy/exc.py", "lib/sqlalchemy/log.py", @@ -77,6 +78,7 @@ strict = true # strict checking [[tool.mypy.overrides]] module = [ + "sqlalchemy.event.*", "sqlalchemy.events", "sqlalchemy.exc", "sqlalchemy.inspection", diff --git a/setup.cfg b/setup.cfg index 99abcea1c8..a8c12377da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -109,8 +109,9 @@ import-order-style = google application-import-names = sqlalchemy,test per-file-ignores = **/__init__.py:F401 - test/ext/mypy/plain_files/*:F821,E501 - test/ext/mypy/plugin_files/*:F821,E501 + test/*:FA100 + test/ext/mypy/plain_files/*:F821,E501,FA100 + test/ext/mypy/plugin_files/*:F821,E501,FA100 lib/sqlalchemy/events.py:F401 lib/sqlalchemy/schema.py:F401 lib/sqlalchemy/types.py:F401 diff --git a/tox.ini b/tox.ini index 71fef2a834..c3420a00f7 100644 --- a/tox.ini +++ b/tox.ini @@ -156,6 +156,7 @@ deps= flake8 flake8-import-order flake8-builtins + flake8-future-annotations flake8-docstrings>=1.6.0 flake8-rst-docstrings # flake8-rst-docstrings dependency, leaving it here