From: Mike Bayer Date: Sun, 9 Jan 2022 16:49:02 +0000 (-0500) Subject: mypy: sqlalchemy.util X-Git-Tag: rel_2_0_0b1~514 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ff1ab665cb1694b85085680d1a02c7c11fa2a6d4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git mypy: sqlalchemy.util Starting to set up practices and conventions to get the library typed. Key goals for typing are: 1. whole library can pass mypy without any strict turned on. 2. we can incrementally turn on some strict flags on a per-package/ module basis, as here we turn on more strictness for sqlalchemy.util, exc, and log 3. mypy ORM plugin tests work fully without sqlalchemy2-stubs installed 4. public facing methods all have return types, major parameter signatures filled in also 5. Foundational elements like util etc. are typed enough so that we can use them in fully typed internals higher up the stack. Conventions set up here: 1. we can use lots of config in setup.cfg to limit where mypy is throwing errors and how detailed it should be in different packages / modules. We can use this to push up gerrits that will pass tests fully without everything being typed. 2. a new tox target pep484 is added. this links to a new jenkins pep484 job that works across all projects (alembic, dogpile, etc.) We've worked around some mypy bugs that will likely be around for awhile, and also set up some core practices for how to deal with certain things such as public_factory modules (mypy won't accept a module from a callable at all, so need to use simple type checking conditionals). References: #6810 Change-Id: I80be58029896a29fd9f491aa3215422a8b705e12 --- diff --git a/.github/workflows/run-test.yaml b/.github/workflows/run-test.yaml index 6fbb29bdc9..196e3c1b15 100644 --- a/.github/workflows/run-test.yaml +++ b/.github/workflows/run-test.yaml @@ -188,3 +188,36 @@ jobs: - name: Run tests run: tox -e pep8 + + run-pep484: + name: pep484-${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + # run this job using this matrix, excluding some combinations below. + matrix: + os: + - "ubuntu-latest" + python-version: + - "3.10" + + fail-fast: false + + # steps to run in each job. Some are github actions, others run shell commands + steps: + - name: Checkout repo + uses: actions/checkout@v2 + + - name: Set up python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + architecture: ${{ matrix.architecture }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install --upgrade tox setuptools + pip list + + - name: Run tests + run: tox -e pep484 diff --git a/MANIFEST.in b/MANIFEST.in index 0cb6133851..eb447a0bd0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,6 +5,11 @@ recursive-include doc *.html *.css *.txt *.js *.png *.py Makefile *.rst *.sty recursive-include examples *.py *.xml recursive-include test *.py *.dat *.testpatch +# for some reason in some environments stale Cython .c files +# are being pulled in, these should never be in a dist +exclude lib/sqlalchemy/cyextension/*.c +exclude lib/sqlalchemy/cyextension/*.so + # include the pyx and pxd extensions, which otherwise # don't come in if --with-cextensions isn't specified. recursive-include lib *.pyx *.pxd *.txt *.typed diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index e695d4c62d..5a344da432 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -22,52 +22,53 @@ cdef list cunique_list(seq, hashfunc=None): def unique_list(seq, hashfunc=None): return cunique_list(seq, hashfunc) -cdef class OrderedSet(set): +cdef class OrderedSet: cdef list _list + cdef set _set def __init__(self, d=None): - set.__init__(self) if d is not None: self._list = cunique_list(d) - set.update(self, self._list) + self._set = set(self._list) else: self._list = [] + self._set = set() cdef OrderedSet _copy(self): cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) cp._list = list(self._list) - set.update(cp, cp._list) + cp._set = set(cp._list) return cp cdef OrderedSet _from_list(self, list new_list): cdef OrderedSet new = OrderedSet.__new__(OrderedSet) new._list = new_list - set.update(new, new_list) + new._set = set(new_list) return new def add(self, element): if element not in self: self._list.append(element) - PySet_Add(self, element) + PySet_Add(self._set, element) def remove(self, element): # set.remove will raise if element is not in self - set.remove(self, element) + self._set.remove(element) self._list.remove(element) def insert(self, Py_ssize_t pos, element): if element not in self: self._list.insert(pos, element) - PySet_Add(self, element) + PySet_Add(self._set, element) def discard(self, element): if element in self: - set.remove(self, element) + self._set.remove(element) self._list.remove(element) def clear(self): - set.clear(self) + self._set.clear() self._list = [] def __getitem__(self, key): @@ -84,21 +85,34 @@ cdef class OrderedSet(set): __str__ = __repr__ - def update(self, iterable): - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - return self + def update(self, *iterables): + for iterable in iterables: + for e in iterable: + if e not in self: + self._list.append(e) + self._set.add(e) def __ior__(self, iterable): - return self.update(iterable) + self.update(iterable) + return self def union(self, other): result = self._copy() result.update(other) return result + def __len__(self) -> int: + return len(self._set) + + def __eq__(self, other): + return self._set == other + + def __ne__(self, other): + return self._set != other + + def __contains__(self, element): + return element in self._set + def __or__(self, other): return self.union(other) @@ -138,27 +152,27 @@ cdef class OrderedSet(set): cdef set other_set = self._to_set(other) set.intersection_update(self, other_set) self._list = [a for a in self._list if a in other_set] - return self def __iand__(self, other): - return self.intersection_update(other) + self.intersection_update(other) + return self def symmetric_difference_update(self, other): set.symmetric_difference_update(self, other) self._list = [a for a in self._list if a in self] self._list += [a for a in other if a in self] - return self def __ixor__(self, other): - return self.symmetric_difference_update(other) + self.symmetric_difference_update(other) + return self def difference_update(self, other): set.difference_update(self, other) self._list = [a for a in self._list if a in self] - return self def __isub__(self, other): - return self.difference_update(other) + self.difference_update(other) + return self cdef object cy_id(object item): diff --git a/lib/sqlalchemy/cyextension/immutabledict.pyx b/lib/sqlalchemy/cyextension/immutabledict.pyx index 89bcf3ed6c..d07c81bd49 100644 --- a/lib/sqlalchemy/cyextension/immutabledict.pyx +++ b/lib/sqlalchemy/cyextension/immutabledict.pyx @@ -12,10 +12,25 @@ class ImmutableContainer: __delitem__ = __setitem__ = __setattr__ = _immutable +class ImmutableDictBase(dict): + def _immutable(self, *a,**kw): + _immutable_fn(self) + + @classmethod + def __class_getitem__(cls, key): + return cls + + __delitem__ = __setitem__ = __setattr__ = _immutable + clear = pop = popitem = setdefault = update = _immutable + cdef class immutabledict(dict): def __repr__(self): return f"immutabledict({dict.__repr__(self)})" + @classmethod + def __class_getitem__(cls, key): + return cls + def union(self, *args, **kw): cdef dict to_merge = None cdef immutabledict result diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 40fc5d1620..f7d02e3b0c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -19,13 +19,17 @@ from .util import _distill_params_20 from .util import _distill_raw_params from .util import TransactionalContext from .. import exc +from .. import inspection from .. import log from .. import util from ..sql import compiler from ..sql import util as sql_util +from ..sql._typing import _ExecuteOptions +from ..sql._typing import _ExecuteParams if typing.TYPE_CHECKING: from .interfaces import Dialect + from .reflection import Inspector # noqa from .url import URL from ..pool import Pool from ..pool import PoolProxiedConnection @@ -38,7 +42,7 @@ _EMPTY_EXECUTION_OPTS = util.immutabledict() NO_OPTIONS = util.immutabledict() -class Connection(ConnectionEventsTarget): +class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """Provides high-level functionality for a wrapped DB-API connection. The :class:`_engine.Connection` object is procured by calling @@ -1079,7 +1083,12 @@ class Connection(ConnectionEventsTarget): return self.execute(statement, parameters, execution_options).scalars() - def execute(self, statement, parameters=None, execution_options=None): + def execute( + self, + statement, + parameters: Optional[_ExecuteParams] = None, + execution_options: Optional[_ExecuteOptions] = None, + ): r"""Executes a SQL statement construct and returns a :class:`_engine.Result`. @@ -2270,7 +2279,9 @@ class TwoPhaseTransaction(RootTransaction): self.connection._commit_twophase_impl(self.xid, self._is_prepared) -class Engine(ConnectionEventsTarget, log.Identified): +class Engine( + ConnectionEventsTarget, log.Identified, inspection.Inspectable["Inspector"] +): """ Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 7eebb1f019..6fb8279894 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -5,6 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import Any from . import base from . import url as _url @@ -40,7 +41,7 @@ from ..sql import compiler "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url, **kwargs): +def create_engine(url: "_url.URL", **kwargs: Any) -> "base.Engine": """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL ` as the diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 371b9c7624..df7a53ab7d 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -57,7 +57,7 @@ def cache(fn, self, con, *args, **kw): @inspection._self_inspects -class Inspector: +class Inspector(inspection.Inspectable["Inspector"]): """Performs database schema inspection. The Inspector acts as a proxy to the reflection methods of the diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 48ce1629ae..a059662224 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -30,13 +30,13 @@ as well as support for subclass propagation (e.g. events assigned to """ import collections from itertools import chain +import threading import weakref from . import legacy from . import registry from .. import exc from .. import util -from ..util import threading from ..util.concurrency import AsyncAdaptedLock diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 8fdacbdf2e..6732edd4e8 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -12,25 +12,39 @@ raised as a result of DBAPI exceptions are all subclasses of :exc:`.DBAPIError`. """ +import typing +from typing import Any +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union from .util import _preloaded from .util import compat +if typing.TYPE_CHECKING: + from .engine.interfaces import Dialect + from .sql._typing import _ExecuteParams + from .sql.compiler import Compiled + from .sql.elements import ClauseElement + _version_token = None class HasDescriptionCode: """helper which adds 'code' as an attribute and '_code_str' as a method""" - code = None + code: Optional[str] = None - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any): code = kw.pop("code", None) if code is not None: self.code = code super(HasDescriptionCode, self).__init__(*arg, **kw) - def _code_str(self): + def _code_str(self) -> str: if not self.code: return "" else: @@ -43,7 +57,7 @@ class HasDescriptionCode: ) ) - def __str__(self): + def __str__(self) -> str: message = super(HasDescriptionCode, self).__str__() if self.code: message = "%s %s" % (message, self._code_str()) @@ -53,7 +67,7 @@ class HasDescriptionCode: class SQLAlchemyError(HasDescriptionCode, Exception): """Generic error class.""" - def _message(self): + def _message(self) -> str: # rules: # # 1. single arg string will usually be a unicode @@ -64,16 +78,18 @@ class SQLAlchemyError(HasDescriptionCode, Exception): # SQLAlchemy though this is happening in at least one known external # library, call str() which does a repr(). # + text: str + if len(self.args) == 1: - text = self.args[0] + arg_text = self.args[0] - if isinstance(text, bytes): - text = compat.decode_backslashreplace(text, "utf-8") + if isinstance(arg_text, bytes): + text = compat.decode_backslashreplace(arg_text, "utf-8") # This is for when the argument is not a string of any sort. # Otherwise, converting this exception to string would fail for # non-string arguments. else: - text = str(text) + text = str(arg_text) return text else: @@ -82,7 +98,7 @@ class SQLAlchemyError(HasDescriptionCode, Exception): # a repr() of the tuple return str(self.args) - def _sql_message(self): + def _sql_message(self) -> str: message = self._message() if self.code: @@ -90,7 +106,7 @@ class SQLAlchemyError(HasDescriptionCode, Exception): return message - def __str__(self): + def __str__(self) -> str: return self._sql_message() @@ -110,13 +126,13 @@ class ObjectNotExecutableError(ArgumentError): """ - def __init__(self, target): + def __init__(self, target: Any): super(ObjectNotExecutableError, self).__init__( "Not an executable object: %r" % target ) self.target = target - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return self.__class__, (self.target,) @@ -154,7 +170,14 @@ class CircularDependencyError(SQLAlchemyError): """ - def __init__(self, message, cycles, edges, msg=None, code=None): + def __init__( + self, + message: str, + cycles: Any, + edges: Any, + msg: Optional[str] = None, + code: Optional[str] = None, + ): if msg is None: message += " (%s)" % ", ".join(repr(s) for s in cycles) else: @@ -163,7 +186,7 @@ class CircularDependencyError(SQLAlchemyError): self.cycles = cycles self.edges = edges - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return ( self.__class__, (None, self.cycles, self.edges, self.args[0]), @@ -187,7 +210,12 @@ class UnsupportedCompilationError(CompileError): code = "l7de" - def __init__(self, compiler, element_type, message=None): + def __init__( + self, + compiler: "Compiled", + element_type: Type["ClauseElement"], + message: Optional[str] = None, + ): super(UnsupportedCompilationError, self).__init__( "Compiler %r can't render element of type %s%s" % (compiler, element_type, ": %s" % message if message else "") @@ -196,7 +224,7 @@ class UnsupportedCompilationError(CompileError): self.element_type = element_type self.message = message - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return self.__class__, (self.compiler, self.element_type, self.message) @@ -216,7 +244,7 @@ class DisconnectionError(SQLAlchemyError): """ - invalidate_pool = False + invalidate_pool: bool = False class InvalidatePoolError(DisconnectionError): @@ -234,7 +262,7 @@ class InvalidatePoolError(DisconnectionError): """ - invalidate_pool = True + invalidate_pool: bool = True class TimeoutError(SQLAlchemyError): # noqa @@ -332,11 +360,11 @@ class NoReferencedTableError(NoReferenceError): """ - def __init__(self, message, tname): + def __init__(self, message: str, tname: str): NoReferenceError.__init__(self, message) self.table_name = tname - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return self.__class__, (self.args[0], self.table_name) @@ -346,12 +374,12 @@ class NoReferencedColumnError(NoReferenceError): """ - def __init__(self, message, tname, cname): + def __init__(self, message: str, tname: str, cname: str): NoReferenceError.__init__(self, message) self.table_name = tname self.column_name = cname - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return ( self.__class__, (self.args[0], self.table_name, self.column_name), @@ -409,26 +437,29 @@ class StatementError(SQLAlchemyError): """ - statement = None + statement: Optional[str] = None """The string SQL statement being invoked when this exception occurred.""" - params = None + params: Optional["_ExecuteParams"] = None """The parameter list being used when this exception occurred.""" - orig = None - """The DBAPI exception object.""" + orig: Optional[BaseException] = None + """The original exception that was thrown. + + """ - ismulti = None + ismulti: Optional[bool] = None + """multi parameter passed to repr_params(). None is meaningful.""" def __init__( self, - message, - statement, - params, - orig, - hide_parameters=False, - code=None, - ismulti=None, + message: str, + statement: Optional[str], + params: Optional["_ExecuteParams"], + orig: Optional[BaseException], + hide_parameters: bool = False, + code: Optional[str] = None, + ismulti: Optional[bool] = None, ): SQLAlchemyError.__init__(self, message, code=code) self.statement = statement @@ -436,12 +467,12 @@ class StatementError(SQLAlchemyError): self.orig = orig self.ismulti = ismulti self.hide_parameters = hide_parameters - self.detail = [] + self.detail: List[str] = [] - def add_detail(self, msg): + def add_detail(self, msg: str) -> None: self.detail.append(msg) - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return ( self.__class__, ( @@ -457,8 +488,11 @@ class StatementError(SQLAlchemyError): ) @_preloaded.preload_module("sqlalchemy.sql.util") - def _sql_message(self): - util = _preloaded.preloaded.sql_util + def _sql_message(self) -> str: + if typing.TYPE_CHECKING: + from .sql import util + else: + util = _preloaded.preloaded.sql_util details = [self._message()] if self.statement: @@ -505,18 +539,67 @@ class DBAPIError(StatementError): code = "dbapi" + # I dont think I'm going to try to do overloads like this everywhere + # in the library, but as this module is early days for me typing everything + # I am sort of just practicing + + @overload @classmethod def instance( cls, - statement, - params, - orig, - dbapi_base_err, - hide_parameters=False, - connection_invalidated=False, - dialect=None, - ismulti=None, - ): + statement: str, + params: "_ExecuteParams", + orig: DontWrapMixin, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional["Dialect"] = None, + ismulti: Optional[bool] = None, + ) -> DontWrapMixin: + ... + + @overload + @classmethod + def instance( + cls, + statement: str, + params: "_ExecuteParams", + orig: Exception, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional["Dialect"] = None, + ismulti: Optional[bool] = None, + ) -> StatementError: + ... + + @overload + @classmethod + def instance( + cls, + statement: str, + params: "_ExecuteParams", + orig: BaseException, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional["Dialect"] = None, + ismulti: Optional[bool] = None, + ) -> BaseException: + ... + + @classmethod + def instance( + cls, + statement: str, + params: "_ExecuteParams", + orig: Union[BaseException, DontWrapMixin], + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional["Dialect"] = None, + ismulti: Optional[bool] = None, + ) -> Union[BaseException, DontWrapMixin]: # Don't ever wrap these, just return them directly as if # DBAPIError didn't exist. if ( @@ -578,7 +661,7 @@ class DBAPIError(StatementError): ismulti=ismulti, ) - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return ( self.__class__, ( @@ -595,13 +678,13 @@ class DBAPIError(StatementError): def __init__( self, - statement, - params, - orig, - hide_parameters=False, - connection_invalidated=False, - code=None, - ismulti=None, + statement: str, + params: "_ExecuteParams", + orig: BaseException, + hide_parameters: bool = False, + connection_invalidated: bool = False, + code: Optional[str] = None, + ismulti: Optional[bool] = None, ): try: text = str(orig) @@ -684,7 +767,7 @@ class SATestSuiteWarning(Warning): class SADeprecationWarning(HasDescriptionCode, DeprecationWarning): """Issued for usage of deprecated APIs.""" - deprecated_since = None + deprecated_since: Optional[str] = None "Indicates the version that started raising this deprecation warning" @@ -700,10 +783,10 @@ class Base20DeprecationWarning(SADeprecationWarning): """ - deprecated_since = "1.4" + deprecated_since: Optional[str] = "1.4" "Indicates the version that started raising this deprecation warning" - def __str__(self): + def __str__(self) -> str: return ( super(Base20DeprecationWarning, self).__str__() + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)" @@ -724,7 +807,7 @@ class SAPendingDeprecationWarning(PendingDeprecationWarning): """ - deprecated_since = None + deprecated_since: Optional[str] = None "Indicates the version that started raising this deprecation warning" diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index 7f9822d02e..c6e9ca69af 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -28,15 +28,43 @@ tools which build on top of SQLAlchemy configurations to be constructed in a forwards-compatible way. """ +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import overload +from typing import Type +from typing import TypeVar +from typing import Union from . import exc -from . import util +from .util.typing import Literal +_T = TypeVar("_T", bound=Any) -_registrars = util.defaultdict(list) +_registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {} -def inspect(subject, raiseerr=True): +class Inspectable(Generic[_T]): + """define a class as inspectable. + + This allows typing to set up a linkage between an object that + can be inspected and the type of inspection it returns. + + """ + + +@overload +def inspect(subject: Inspectable[_T], raiseerr: bool = True) -> _T: + ... + + +@overload +def inspect(subject: Any, raiseerr: bool = True) -> Any: + ... + + +def inspect(subject: Any, raiseerr: bool = True) -> Any: """Produce an inspection object for the given target. The returned value in some cases may be the @@ -58,12 +86,14 @@ def inspect(subject, raiseerr=True): type_ = type(subject) for cls in type_.__mro__: if cls in _registrars: - reg = _registrars[cls] - if reg is True: + reg = _registrars.get(cls, None) + if reg is None: + continue + elif reg is True: return subject ret = reg(subject) if ret is not None: - break + return ret else: reg = ret = None @@ -75,8 +105,10 @@ def inspect(subject, raiseerr=True): return ret -def _inspects(*types): - def decorate(fn_or_cls): +def _inspects( + *types: type, +) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]: + def decorate(fn_or_cls: Callable[[Any], Any]) -> Callable[[Any], Any]: for type_ in types: if type_ in _registrars: raise AssertionError( @@ -88,6 +120,8 @@ def _inspects(*types): return decorate -def _self_inspects(cls): - _inspects(cls)(True) +def _self_inspects(cls: Type[_T]) -> Type[_T]: + if cls in _registrars: + raise AssertionError("Type %s is already " "registered" % cls) + _registrars[cls] = True return cls diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 6431053a85..e9ab8f4236 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -17,10 +17,21 @@ and :class:`_pool.Pool` objects, corresponds to a logger specific to that instance only. """ - import logging import sys +from typing import Any +from typing import Optional +from typing import overload +from typing import Set +from typing import Type +from typing import TypeVar +from typing import Union + +from .util.typing import Literal + +_IT = TypeVar("_IT", bound="Identified") +_EchoFlagType = Union[None, bool, Literal["debug"]] # set initial level to WARN. This so that # log statements don't occur in the absence of explicit @@ -30,7 +41,7 @@ if rootlogger.level == logging.NOTSET: rootlogger.setLevel(logging.WARN) -def _add_default_handler(logger): +def _add_default_handler(logger: logging.Logger) -> None: handler = logging.StreamHandler(sys.stdout) handler.setFormatter( logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s") @@ -38,32 +49,40 @@ def _add_default_handler(logger): logger.addHandler(handler) -_logged_classes = set() +_logged_classes: Set[Type["Identified"]] = set() -def _qual_logger_name_for_cls(cls): +def _qual_logger_name_for_cls(cls: Type["Identified"]) -> str: return ( getattr(cls, "_sqla_logger_namespace", None) or cls.__module__ + "." + cls.__name__ ) -def class_logger(cls): +def class_logger(cls: Type[_IT]) -> Type[_IT]: logger = logging.getLogger(_qual_logger_name_for_cls(cls)) - cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG) - cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO) + cls._should_log_debug = lambda self: logger.isEnabledFor( # type: ignore[assignment] # noqa E501 + logging.DEBUG + ) + cls._should_log_info = lambda self: logger.isEnabledFor( # type: ignore[assignment] # noqa E501 + logging.INFO + ) cls.logger = logger _logged_classes.add(cls) return cls class Identified: - logging_name = None + logging_name: Optional[str] = None - def _should_log_debug(self): + logger: Union[logging.Logger, "InstanceLogger"] + + _echo: _EchoFlagType + + def _should_log_debug(self) -> bool: return self.logger.isEnabledFor(logging.DEBUG) - def _should_log_info(self): + def _should_log_info(self) -> bool: return self.logger.isEnabledFor(logging.INFO) @@ -94,7 +113,9 @@ class InstanceLogger: "debug": logging.DEBUG, } - def __init__(self, echo, name): + _echo: _EchoFlagType + + def __init__(self, echo: _EchoFlagType, name: str): self.echo = echo self.logger = logging.getLogger(name) @@ -106,41 +127,41 @@ class InstanceLogger: # # Boilerplate convenience methods # - def debug(self, msg, *args, **kwargs): + def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: """Delegate a debug call to the underlying logger.""" self.log(logging.DEBUG, msg, *args, **kwargs) - def info(self, msg, *args, **kwargs): + def info(self, msg: str, *args: Any, **kwargs: Any) -> None: """Delegate an info call to the underlying logger.""" self.log(logging.INFO, msg, *args, **kwargs) - def warning(self, msg, *args, **kwargs): + def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: """Delegate a warning call to the underlying logger.""" self.log(logging.WARNING, msg, *args, **kwargs) warn = warning - def error(self, msg, *args, **kwargs): + def error(self, msg: str, *args: Any, **kwargs: Any) -> None: """ Delegate an error call to the underlying logger. """ self.log(logging.ERROR, msg, *args, **kwargs) - def exception(self, msg, *args, **kwargs): + def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: """Delegate an exception call to the underlying logger.""" kwargs["exc_info"] = 1 self.log(logging.ERROR, msg, *args, **kwargs) - def critical(self, msg, *args, **kwargs): + def critical(self, msg: str, *args: Any, **kwargs: Any) -> None: """Delegate a critical call to the underlying logger.""" self.log(logging.CRITICAL, msg, *args, **kwargs) - def log(self, level, msg, *args, **kwargs): + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: """Delegate a log call to the underlying logger. The level here is determined by the echo @@ -162,14 +183,14 @@ class InstanceLogger: if level >= selected_level: self.logger._log(level, msg, args, **kwargs) - def isEnabledFor(self, level): + def isEnabledFor(self, level: int) -> bool: """Is this logger enabled for level 'level'?""" if self.logger.manager.disable >= level: return False return level >= self.getEffectiveLevel() - def getEffectiveLevel(self): + def getEffectiveLevel(self) -> int: """What's the effective level for this logger?""" level = self._echo_map[self.echo] @@ -178,7 +199,9 @@ class InstanceLogger: return level -def instance_logger(instance, echoflag=None): +def instance_logger( + instance: Identified, echoflag: _EchoFlagType = None +) -> None: """create a logger for an instance that implements :class:`.Identified`.""" if instance.logging_name: @@ -191,6 +214,8 @@ def instance_logger(instance, echoflag=None): instance._echo = echoflag + logger: Union[logging.Logger, InstanceLogger] + if echoflag in (False, None): # if no echo setting or False, return a Logger directly, # avoiding overhead of filtering @@ -215,11 +240,25 @@ class echo_property: ``logging.DEBUG``. """ - def __get__(self, instance, owner): + @overload + def __get__( + self, instance: "Literal[None]", owner: "echo_property" + ) -> "echo_property": + ... + + @overload + def __get__( + self, instance: Identified, owner: "echo_property" + ) -> _EchoFlagType: + ... + + def __get__( + self, instance: Optional[Identified], owner: "echo_property" + ) -> Union["echo_property", _EchoFlagType]: if instance is None: return self else: return instance._echo - def __set__(self, instance, value): + def __set__(self, instance: Identified, value: _EchoFlagType) -> None: instance_logger(instance, echoflag=value) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 260ad1f990..75ce8216f6 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -104,6 +104,7 @@ through the adapter, allowing for some very sophisticated behavior. """ import operator +import threading import weakref from sqlalchemy.util.compat import inspect_getfullargspec @@ -122,7 +123,7 @@ __all__ = [ "attribute_mapped_collection", ] -__instrumentation_mutex = util.threading.Lock() +__instrumentation_mutex = threading.Lock() class _PlainColumnGetter: diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 29a9c9edf7..59fabb9b6b 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -42,6 +42,9 @@ from ..sql.selectable import FromClause from ..util import hybridmethod from ..util import hybridproperty +if typing.TYPE_CHECKING: + from .state import InstanceState # noqa + _T = TypeVar("_T", bound=Any) @@ -64,7 +67,9 @@ def has_inherited_table(cls): return False -class DeclarativeAttributeIntercept(type): +class DeclarativeAttributeIntercept( + type, inspection.Inspectable["Mapper[Any]"] +): """Metaclass that may be used in conjunction with the :class:`_orm.DeclarativeBase` class to support addition of class attributes dynamically. @@ -78,7 +83,7 @@ class DeclarativeAttributeIntercept(type): _del_attribute(cls, key) -class DeclarativeMeta(type): +class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]): def __init__(cls, classname, bases, dict_, **kw): # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named @@ -421,7 +426,7 @@ def _setup_declarative_base(cls): cls.metadata = cls.registry.metadata -class DeclarativeBaseNoMeta: +class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]): """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass to intercept new attributes. @@ -451,7 +456,10 @@ class DeclarativeBaseNoMeta: cls._sa_registry.map_declaratively(cls) -class DeclarativeBase(metaclass=DeclarativeAttributeIntercept): +class DeclarativeBase( + inspection.Inspectable["InstanceState"], + metaclass=DeclarativeAttributeIntercept, +): """Base class used for declarative class definitions. The :class:`_orm.DeclarativeBase` allows for the creation of new diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e9a89d102b..fdf065488a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -18,6 +18,7 @@ from collections import deque from functools import reduce from itertools import chain import sys +import threading from typing import Generic from typing import Type from typing import TypeVar @@ -83,7 +84,7 @@ _already_compiling = False NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE") # lock used to synchronize the "mapper configure" step -_CONFIGURE_MUTEX = util.threading.RLock() +_CONFIGURE_MUTEX = threading.RLock() @inspection._self_inspects @@ -93,6 +94,7 @@ class Mapper( ORMEntityColumnsClauseRole, sql_base.MemoizedHasCacheKey, InspectionAttr, + log.Identified, Generic[_MC], ): """Defines an association between a Python class and a database table or @@ -2361,7 +2363,7 @@ class Mapper( yield c @HasMemoized.memoized_attribute - def attrs(self): + def attrs(self) -> util.ImmutableProperties["MapperProperty"]: """A namespace of all :class:`.MapperProperty` objects associated this mapper. diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index c7408b00b8..7a422cd2ac 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -10,6 +10,7 @@ """ +import threading import traceback import weakref @@ -21,7 +22,6 @@ from .. import exc from .. import util from ..util import chop_traceback from ..util import queue as sqla_queue -from ..util import threading class QueuePool(Pool): diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py index e9357bf7d8..594967a40b 100644 --- a/lib/sqlalchemy/sql/_py_util.py +++ b/lib/sqlalchemy/sql/_py_util.py @@ -5,8 +5,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import Dict -class prefix_anon_map(dict): + +class prefix_anon_map(Dict[str, str]): """A map that creates new keys for missing key access. Considers keys of the form " " to produce @@ -27,7 +29,7 @@ class prefix_anon_map(dict): return value -class cache_anon_map(dict): +class cache_anon_map(Dict[int, str]): """A map that creates new keys for missing key access. Produces an incrementing sequence given a series of unique keys. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py new file mode 100644 index 0000000000..b5b0efb21a --- /dev/null +++ b/lib/sqlalchemy/sql/_typing.py @@ -0,0 +1,9 @@ +from typing import Any +from typing import Mapping +from typing import Sequence +from typing import Union + +_SingleExecuteParams = Mapping[str, Any] +_MultiExecuteParams = Sequence[_SingleExecuteParams] +_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams] +_ExecuteOptions = Mapping[str, Any] diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index fa3bae8353..c0de1902ff 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -8,14 +8,20 @@ """High level utilities which build upon other modules here. """ - from collections import deque from itertools import chain +import typing +from typing import Any +from typing import cast +from typing import Optional from . import coercions from . import operators from . import roles from . import visitors +from ._typing import _ExecuteParams +from ._typing import _MultiExecuteParams +from ._typing import _SingleExecuteParams from .annotation import _deep_annotate # noqa from .annotation import _deep_deannotate # noqa from .annotation import _shallow_annotate # noqa @@ -45,6 +51,9 @@ from .selectable import TableClause from .. import exc from .. import util +if typing.TYPE_CHECKING: + from ..engine.row import Row + def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None): """Create a join condition between two tables or selectables. @@ -488,13 +497,13 @@ def _quote_ddl_expr(element): class _repr_base: - _LIST = 0 - _TUPLE = 1 - _DICT = 2 + _LIST: int = 0 + _TUPLE: int = 1 + _DICT: int = 2 __slots__ = ("max_chars",) - def trunc(self, value): + def trunc(self, value: Any) -> str: rep = repr(value) lenrep = len(rep) if lenrep > self.max_chars: @@ -515,11 +524,11 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row, max_chars=300): + def __init__(self, row: "Row", max_chars: int = 300): self.row = row self.max_chars = max_chars - def __repr__(self): + def __repr__(self) -> str: trunc = self.trunc return "(%s%s)" % ( ", ".join(trunc(value) for value in self.row), @@ -537,13 +546,19 @@ class _repr_params(_repr_base): __slots__ = "params", "batches", "ismulti" - def __init__(self, params, batches, max_chars=300, ismulti=None): - self.params = params + def __init__( + self, + params: _ExecuteParams, + batches: int, + max_chars: int = 300, + ismulti: Optional[bool] = None, + ): + self.params: _ExecuteParams = params self.ismulti = ismulti self.batches = batches self.max_chars = max_chars - def __repr__(self): + def __repr__(self) -> str: if self.ismulti is None: return self.trunc(self.params) @@ -557,23 +572,31 @@ class _repr_params(_repr_base): else: return self.trunc(self.params) - if self.ismulti and len(self.params) > self.batches: - msg = " ... displaying %i of %i total bound parameter sets ... " - return " ".join( - ( - self._repr_multi(self.params[: self.batches - 2], typ)[ - 0:-1 - ], - msg % (self.batches, len(self.params)), - self._repr_multi(self.params[-2:], typ)[1:], + if self.ismulti: + multi_params = cast(_MultiExecuteParams, self.params) + + if len(self.params) > self.batches: + msg = ( + " ... displaying %i of %i total bound parameter sets ... " ) - ) - elif self.ismulti: - return self._repr_multi(self.params, typ) + return " ".join( + ( + self._repr_multi( + multi_params[: self.batches - 2], + typ, + )[0:-1], + msg % (self.batches, len(self.params)), + self._repr_multi(multi_params[-2:], typ)[1:], + ) + ) + else: + return self._repr_multi(multi_params, typ) else: - return self._repr_params(self.params, typ) + return self._repr_params( + cast(_SingleExecuteParams, self.params), typ + ) - def _repr_multi(self, multi_params, typ): + def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str: if multi_params: if isinstance(multi_params[0], list): elem_type = self._LIST @@ -597,7 +620,7 @@ class _repr_params(_repr_base): else: return "(%s)" % elements - def _repr_params(self, params, typ): + def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str: trunc = self.trunc if typ is self._DICT: return "{%s}" % ( diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 203460c266..91d15aae08 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -62,7 +62,6 @@ from .compat import osx from .compat import py38 from .compat import py39 from .compat import pypy -from .compat import threading from .compat import win32 from .concurrency import asyncio from .concurrency import await_fallback diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index e53bc9c43a..3e4ef1310d 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -8,26 +8,53 @@ """Collection classes and helpers.""" import collections.abc as collections_abc import operator +import threading import types +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TypeVar +from typing import Union +from typing import ValuesView import weakref -from .compat import threading +from ._has_cy import HAS_CYEXTENSION +from .typing import Literal -try: - from sqlalchemy.cyextension.immutabledict import ImmutableContainer - from sqlalchemy.cyextension.immutabledict import immutabledict - from sqlalchemy.cyextension.collections import IdentitySet - from sqlalchemy.cyextension.collections import OrderedSet - from sqlalchemy.cyextension.collections import unique_list # noqa -except ImportError: +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_collections import immutabledict from ._py_collections import IdentitySet from ._py_collections import ImmutableContainer + from ._py_collections import ImmutableDictBase from ._py_collections import OrderedSet from ._py_collections import unique_list # noqa +else: + from sqlalchemy.cyextension.immutabledict import ImmutableContainer + from sqlalchemy.cyextension.immutabledict import ImmutableDictBase + from sqlalchemy.cyextension.immutabledict import immutabledict + from sqlalchemy.cyextension.collections import IdentitySet + from sqlalchemy.cyextension.collections import OrderedSet + from sqlalchemy.cyextension.collections import unique_list # noqa + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) -EMPTY_SET = frozenset() + +EMPTY_SET: FrozenSet[Any] = frozenset() def coerce_to_immutabledict(d): @@ -39,14 +66,12 @@ def coerce_to_immutabledict(d): return immutabledict(d) -EMPTY_DICT = immutabledict() +EMPTY_DICT: immutabledict[Any, Any] = immutabledict() -class FacadeDict(ImmutableContainer, dict): +class FacadeDict(ImmutableDictBase[Any, Any]): """A dictionary that is not publicly mutable.""" - clear = pop = popitem = setdefault = update = ImmutableContainer._immutable - def __new__(cls, *args): new = dict.__new__(cls) return new @@ -68,18 +93,23 @@ class FacadeDict(ImmutableContainer, dict): return "FacadeDict(%s)" % dict.__repr__(self) -class Properties: +_DT = TypeVar("_DT", bound=Any) + + +class Properties(Generic[_T]): """Provide a __getattr__/__setattr__ interface over a dict.""" __slots__ = ("_data",) + _data: Dict[str, _T] + def __init__(self, data): object.__setattr__(self, "_data", data) - def __len__(self): + def __len__(self) -> int: return len(self._data) - def __iter__(self): + def __iter__(self) -> Iterator[_T]: return iter(list(self._data.values())) def __dir__(self): @@ -93,7 +123,7 @@ class Properties: def __setitem__(self, key, obj): self._data[key] = obj - def __getitem__(self, key): + def __getitem__(self, key: str) -> _T: return self._data[key] def __delitem__(self, key): @@ -108,16 +138,16 @@ class Properties: def __setstate__(self, state): object.__setattr__(self, "_data", state["_data"]) - def __getattr__(self, key): + def __getattr__(self, key: str) -> _T: try: return self._data[key] except KeyError: raise AttributeError(key) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self._data - def as_immutable(self): + def as_immutable(self) -> "ImmutableProperties[_T]": """Return an immutable proxy for this :class:`.Properties`.""" return ImmutableProperties(self._data) @@ -125,29 +155,39 @@ class Properties: def update(self, value): self._data.update(value) - def get(self, key, default=None): + @overload + def get(self, key: str) -> Optional[_T]: + ... + + @overload + def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: + ... + + def get( + self, key: str, default: Optional[Union[_DT, _T]] = None + ) -> Optional[Union[_T, _DT]]: if key in self: return self[key] else: return default - def keys(self): + def keys(self) -> List[str]: return list(self._data) - def values(self): + def values(self) -> List[_T]: return list(self._data.values()) - def items(self): + def items(self) -> List[Tuple[str, _T]]: return list(self._data.items()) - def has_key(self, key): + def has_key(self, key: str) -> bool: return key in self._data def clear(self): self._data.clear() -class OrderedProperties(Properties): +class OrderedProperties(Properties[_T]): """Provide a __getattr__/__setattr__ interface with an OrderedDict as backing store.""" @@ -157,7 +197,7 @@ class OrderedProperties(Properties): Properties.__init__(self, OrderedDict()) -class ImmutableProperties(ImmutableContainer, Properties): +class ImmutableProperties(ImmutableContainer, Properties[_T]): """Provide immutable dict/object attribute to an underlying dictionary.""" __slots__ = () @@ -220,7 +260,7 @@ class OrderedIdentitySet(IdentitySet): self.add(o) -class PopulateDict(dict): +class PopulateDict(Dict[_KT, _VT]): """A dict which populates missing values via a creation function. Note the creation function takes a key, unlike @@ -228,26 +268,26 @@ class PopulateDict(dict): """ - def __init__(self, creator): + def __init__(self, creator: Callable[[_KT], _VT]): self.creator = creator - def __missing__(self, key): + def __missing__(self, key: Any) -> Any: self[key] = val = self.creator(key) return val -class WeakPopulateDict(dict): +class WeakPopulateDict(Dict[_KT, _VT]): """Like PopulateDict, but assumes a self + a method and does not create a reference cycle. """ - def __init__(self, creator_method): + def __init__(self, creator_method: types.MethodType): self.creator = creator_method.__func__ weakself = creator_method.__self__ self.weakself = weakref.ref(weakself) - def __missing__(self, key): + def __missing__(self, key: Any) -> Any: self[key] = val = self.creator(self.weakself(), key) return val @@ -261,37 +301,40 @@ column_dict = dict ordered_column_set = OrderedSet -_getters = PopulateDict(operator.itemgetter) - -_property_getters = PopulateDict( - lambda idx: property(operator.itemgetter(idx)) -) - - -class UniqueAppender: +class UniqueAppender(Generic[_T]): """Appends items to a collection ensuring uniqueness. Additional appends() of the same object are ignored. Membership is determined by identity (``is a``) not equality (``==``). """ - def __init__(self, data, via=None): + __slots__ = "data", "_data_appender", "_unique" + + data: Union[Iterable[_T], Set[_T], List[_T]] + _data_appender: Callable[[_T], None] + _unique: Dict[int, Literal[True]] + + def __init__( + self, + data: Union[Iterable[_T], Set[_T], List[_T]], + via: Optional[str] = None, + ): self.data = data self._unique = {} if via: - self._data_appender = getattr(data, via) + self._data_appender = getattr(data, via) # type: ignore[assignment] # noqa E501 elif hasattr(data, "append"): - self._data_appender = data.append + self._data_appender = cast("List[_T]", data).append # type: ignore[assignment] # noqa E501 elif hasattr(data, "add"): - self._data_appender = data.add + self._data_appender = cast("Set[_T]", data).add # type: ignore[assignment] # noqa E501 - def append(self, item): + def append(self, item: _T) -> None: id_ = id(item) if id_ not in self._unique: - self._data_appender(item) + self._data_appender(item) # type: ignore[call-arg] self._unique[id_] = True - def __iter__(self): + def __iter__(self) -> Iterator[_T]: return iter(self.data) @@ -302,13 +345,27 @@ def coerce_generator_arg(arg): return arg -def to_list(x, default=None): +@overload +def to_list(x: Sequence[_T], default: Optional[List[_T]] = None) -> List[_T]: + ... + + +@overload +def to_list( + x: Optional[Sequence[_T]], default: Optional[List[_T]] = None +) -> Optional[List[_T]]: + ... + + +def to_list( + x: Optional[Sequence[_T]], default: Optional[List[_T]] = None +) -> Optional[List[_T]]: if x is None: return default if not isinstance(x, collections_abc.Iterable) or isinstance( x, (str, bytes) ): - return [x] + return [cast(_T, x)] elif isinstance(x, list): return x else: @@ -367,7 +424,7 @@ def flatten_iterator(x): yield elem -class LRUCache(dict): +class LRUCache(typing.MutableMapping[_KT, _VT]): """Dictionary with 'squishy' removal of least recently used items. @@ -377,7 +434,18 @@ class LRUCache(dict): """ - __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex" + __slots__ = ( + "capacity", + "threshold", + "size_alert", + "_data", + "_counter", + "_mutex", + ) + + capacity: int + threshold: float + size_alert: Callable[["LRUCache[_KT, _VT]"], None] def __init__(self, capacity=100, threshold=0.5, size_alert=None): self.capacity = capacity @@ -385,48 +453,56 @@ class LRUCache(dict): self.size_alert = size_alert self._counter = 0 self._mutex = threading.Lock() + self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {} def _inc_counter(self): self._counter += 1 return self._counter - def get(self, key, default=None): - item = dict.get(self, key, default) - if item is not default: - item[2] = self._inc_counter() + @overload + def get(self, key: _KT) -> Optional[_VT]: + ... + + @overload + def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: + ... + + def get( + self, key: _KT, default: Optional[Union[_VT, _T]] = None + ) -> Optional[Union[_VT, _T]]: + item = self._data.get(key, default) + if item is not default and item is not None: + item[2][0] = self._inc_counter() return item[1] else: return default - def __getitem__(self, key): - item = dict.__getitem__(self, key) - item[2] = self._inc_counter() + def __getitem__(self, key: _KT) -> _VT: + item = self._data[key] + item[2][0] = self._inc_counter() return item[1] - def values(self): - return [i[1] for i in dict.values(self)] + def __iter__(self) -> Iterator[_KT]: + return iter(self._data) - def setdefault(self, key, value): - if key in self: - return self[key] - else: - self[key] = value - return value - - def __setitem__(self, key, value): - item = dict.get(self, key) - if item is None: - item = [key, value, self._inc_counter()] - dict.__setitem__(self, key, item) - else: - item[1] = value + def __len__(self) -> int: + return len(self._data) + + def values(self) -> ValuesView[_VT]: + return typing.ValuesView({k: i[1] for k, i in self._data.items()}) + + def __setitem__(self, key: _KT, value: _VT) -> None: + self._data[key] = (key, value, [self._inc_counter()]) self._manage_size() + def __delitem__(self, __v: _KT) -> None: + del self._data[__v] + @property - def size_threshold(self): + def size_threshold(self) -> float: return self.capacity + self.capacity * self.threshold - def _manage_size(self): + def _manage_size(self) -> None: if not self._mutex.acquire(False): return try: @@ -434,13 +510,15 @@ class LRUCache(dict): while len(self) > self.capacity + self.capacity * self.threshold: if size_alert: size_alert = False - self.size_alert(self) + self.size_alert(self) # type: ignore by_counter = sorted( - dict.values(self), key=operator.itemgetter(2), reverse=True + self._data.values(), + key=operator.itemgetter(2), + reverse=True, ) for item in by_counter[self.capacity :]: try: - del self[item[0]] + del self._data[item[0]] except KeyError: # deleted elsewhere; skip continue @@ -463,6 +541,8 @@ class ScopedRegistry: a callable that will return a key to store/retrieve an object. """ + __slots__ = "createfunc", "scopefunc", "registry" + def __init__(self, createfunc, scopefunc): """Construct a new :class:`.ScopedRegistry`. @@ -529,7 +609,7 @@ class ThreadLocalRegistry(ScopedRegistry): def clear(self): try: - del self.registry.value + del self.registry.value # type: ignore except AttributeError: pass diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index ac678f8a98..b9e58e68cd 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -8,23 +8,25 @@ import asyncio from contextvars import copy_context as _copy_context import sys +import typing from typing import Any from typing import Callable from typing import Coroutine -import greenlet +import greenlet # type: ignore # noqa from .langhelpers import memoized_property from .. import exc -try: +if not typing.TYPE_CHECKING: + try: - # If greenlet.gr_context is present in current version of greenlet, - # it will be set with a copy of the current context on creation. - # Refs: https://github.com/python-greenlet/greenlet/pull/198 - getattr(greenlet.greenlet, "gr_context") -except (ImportError, AttributeError): - _copy_context = None # noqa + # If greenlet.gr_context is present in current version of greenlet, + # it will be set with a copy of the current context on creation. + # Refs: https://github.com/python-greenlet/greenlet/pull/198 + getattr(greenlet.greenlet, "gr_context") + except (ImportError, AttributeError): + _copy_context = None # noqa def is_exit_exception(e): @@ -40,7 +42,7 @@ def is_exit_exception(e): # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 -class _AsyncIoGreenlet(greenlet.greenlet): +class _AsyncIoGreenlet(greenlet.greenlet): # type: ignore def __init__(self, fn, driver): greenlet.greenlet.__init__(self, fn, driver) self.driver = driver @@ -48,7 +50,7 @@ class _AsyncIoGreenlet(greenlet.greenlet): self.gr_context = _copy_context() -def await_only(awaitable: Coroutine) -> Any: +def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. @@ -72,7 +74,7 @@ def await_only(awaitable: Coroutine) -> Any: return current.driver.switch(awaitable) -def await_fallback(awaitable: Coroutine) -> Any: +def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. @@ -97,7 +99,10 @@ def await_fallback(awaitable: Coroutine) -> Any: async def greenlet_spawn( - fn: Callable, *args, _require_await=False, **kwargs + fn: Callable[..., Any], + *args: Any, + _require_await: bool = False, + **kwargs: Any, ) -> Any: """Runs a sync function ``fn`` in a new greenlet. diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/_preloaded.py index 9448ed33de..b0f8ab444a 100644 --- a/lib/sqlalchemy/util/_preloaded.py +++ b/lib/sqlalchemy/util/_preloaded.py @@ -9,8 +9,14 @@ runtime. """ - import sys +from types import ModuleType +import typing +from typing import Any +from typing import Callable +from typing import TypeVar + +_FN = TypeVar("_FN", bound=Callable[..., Any]) class _ModuleRegistry: @@ -37,7 +43,7 @@ class _ModuleRegistry: self.module_registry = set() self.prefix = prefix - def preload_module(self, *deps): + def preload_module(self, *deps: str) -> Callable[[_FN], _FN]: """Adds the specified modules to the list to load. This method can be used both as a normal function and as a decorator. @@ -46,7 +52,7 @@ class _ModuleRegistry: self.module_registry.update(deps) return lambda fn: fn - def import_prefix(self, path): + def import_prefix(self, path: str) -> None: """Resolve all the modules in the registry that start with the specified path. """ @@ -61,6 +67,11 @@ class _ModuleRegistry: __import__(module, globals(), locals()) self.__dict__[key] = sys.modules[module] + if typing.TYPE_CHECKING: + + def __getattr__(self, key: str) -> ModuleType: + ... + preloaded = _ModuleRegistry() preload_module = preloaded.preload_module diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index ff61f6ca90..a4e4b8b5db 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -1,17 +1,52 @@ from itertools import filterfalse +from typing import Any +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import TypeVar + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) class ImmutableContainer: - def _immutable(self, *arg, **kw): + def _immutable(self, *arg: Any, **kw: Any) -> NoReturn: raise TypeError("%s object is immutable" % self.__class__.__name__) - __delitem__ = __setitem__ = __setattr__ = _immutable + def __delitem__(self, key: Any) -> NoReturn: + self._immutable() + def __setitem__(self, key: Any, value: Any) -> NoReturn: + self._immutable() -class immutabledict(ImmutableContainer, dict): + def __setattr__(self, key: str, value: Any) -> NoReturn: + self._immutable() - clear = pop = popitem = setdefault = update = ImmutableContainer._immutable +class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]): + def clear(self) -> NoReturn: + self._immutable() + + def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn: + self._immutable() + + def popitem(self) -> NoReturn: + self._immutable() + + def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn: + self._immutable() + + def update(self, *arg: Any, **kw: Any) -> NoReturn: + self._immutable() + + +class immutabledict(ImmutableDictBase[_KT, _VT]): def __new__(cls, *args): new = dict.__new__(cls) dict.__init__(new, *args) @@ -41,7 +76,7 @@ class immutabledict(ImmutableContainer, dict): dict.__init__(new, self) if __d: dict.update(new, __d) - dict.update(new, kw) + dict.update(new, kw) # type: ignore return new def merge_with(self, *dicts): @@ -61,110 +96,145 @@ class immutabledict(ImmutableContainer, dict): return "immutabledict(%s)" % dict.__repr__(self) -class OrderedSet(set): +class OrderedSet(Generic[_T]): + __slots__ = ("_list", "_set", "__weakref__") + + _list: List[_T] + _set: Set[_T] + def __init__(self, d=None): - set.__init__(self) if d is not None: self._list = unique_list(d) - set.update(self, self._list) + self._set = set(self._list) else: self._list = [] + self._set = set() + + def __reduce__(self): + return (OrderedSet, (self._list,)) - def add(self, element): + def add(self, element: _T) -> None: if element not in self: self._list.append(element) - set.add(self, element) + self._set.add(element) - def remove(self, element): - set.remove(self, element) + def remove(self, element: _T) -> None: + self._set.remove(element) self._list.remove(element) - def insert(self, pos, element): + def insert(self, pos: int, element: _T) -> None: if element not in self: self._list.insert(pos, element) - set.add(self, element) + self._set.add(element) - def discard(self, element): + def discard(self, element: _T) -> None: if element in self: self._list.remove(element) - set.remove(self, element) + self._set.remove(element) - def clear(self): - set.clear(self) + def clear(self) -> None: + self._set.clear() self._list = [] - def __getitem__(self, key): + def __len__(self) -> int: + return len(self._set) + + def __eq__(self, other): + if not isinstance(other, OrderedSet): + return self._set == other + else: + return self._set == other._set + + def __ne__(self, other): + if not isinstance(other, OrderedSet): + return self._set != other + else: + return self._set != other._set + + def __contains__(self, element: Any) -> bool: + return element in self._set + + def __getitem__(self, key: int) -> _T: return self._list[key] - def __iter__(self): + def __iter__(self) -> Iterator[_T]: return iter(self._list) - def __add__(self, other): + def __add__(self, other: Iterator[_T]) -> "OrderedSet[_T]": return self.union(other) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r)" % (self.__class__.__name__, self._list) __str__ = __repr__ - def update(self, iterable): - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - return self + def update(self, *iterables: Iterable[_T]) -> None: + for iterable in iterables: + for e in iterable: + if e not in self: + self._list.append(e) + self._set.add(e) - __ior__ = update + def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.update(other) + return self - def union(self, other): + def union(self, other: Iterable[_T]) -> "OrderedSet[_T]": result = self.__class__(self) result.update(other) return result - __or__ = union + def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.union(other) - def intersection(self, other): + def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]": other = other if isinstance(other, set) else set(other) return self.__class__(a for a in self if a in other) - __and__ = intersection + def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.intersection(other) - def symmetric_difference(self, other): + def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": other_set = other if isinstance(other, set) else set(other) result = self.__class__(a for a in self if a not in other_set) result.update(a for a in other if a not in self) return result - __xor__ = symmetric_difference + def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.symmetric_difference(other) - def difference(self, other): + def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": other = other if isinstance(other, set) else set(other) return self.__class__(a for a in self if a not in other) - __sub__ = difference + def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.difference(other) - def intersection_update(self, other): + def intersection_update(self, other: Iterable[_T]) -> None: other = other if isinstance(other, set) else set(other) - set.intersection_update(self, other) + self._set.intersection_update(other) self._list = [a for a in self._list if a in other] - return self - __iand__ = intersection_update + def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.intersection_update(other) + return self - def symmetric_difference_update(self, other): - set.symmetric_difference_update(self, other) + def symmetric_difference_update(self, other: Iterable[_T]) -> None: + self._set.symmetric_difference_update(other) self._list = [a for a in self._list if a in self] self._list += [a for a in other if a in self] - return self - __ixor__ = symmetric_difference_update + def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.symmetric_difference_update(other) + return self - def difference_update(self, other): - set.difference_update(self, other) + def difference_update(self, other: Iterable[_T]) -> None: + self._set.difference_update(other) self._list = [a for a in self._list if a in self] - return self - __isub__ = difference_update + def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.difference_update(other) + return self class IdentitySet: diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 679df73c70..0f4befbb1f 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -6,14 +6,23 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php """Handle Python version/platform incompatibilities.""" +from __future__ import annotations + import base64 -import collections import dataclasses import inspect import operator import platform import sys import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple py311 = sys.version_info >= (3, 11) @@ -32,27 +41,18 @@ has_refcount_gc = bool(cpython) dottedgetter = operator.attrgetter next = next # noqa -FullArgSpec = collections.namedtuple( - "FullArgSpec", - [ - "args", - "varargs", - "varkw", - "defaults", - "kwonlyargs", - "kwonlydefaults", - "annotations", - ], -) - -try: - import threading -except ImportError: - import dummy_threading as threading # noqa +class FullArgSpec(typing.NamedTuple): + args: List[str] + varargs: Optional[str] + varkw: Optional[str] + defaults: Optional[Tuple[Any, ...]] + kwonlyargs: List[str] + kwonlydefaults: Dict[str, Any] + annotations: Dict[str, Any] -def inspect_getfullargspec(func): +def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec: """Fully vendored version of getfullargspec from Python 3.3.""" if inspect.ismethod(func): @@ -90,13 +90,13 @@ def inspect_getfullargspec(func): ) -if py38: +if typing.TYPE_CHECKING or py38: from importlib import metadata as importlib_metadata else: import importlib_metadata # noqa -if py39: +if typing.TYPE_CHECKING or py39: # pep 584 dict union dict_union = operator.or_ # noqa else: @@ -109,7 +109,7 @@ else: def importlib_metadata_get(group): ep = importlib_metadata.entry_points() - if hasattr(ep, "select"): + if not typing.TYPE_CHECKING and hasattr(ep, "select"): return ep.select(group=group) else: return ep.get(group, ()) @@ -119,15 +119,15 @@ def b(s): return s.encode("latin-1") -def b64decode(x): +def b64decode(x: str) -> bytes: return base64.b64decode(x.encode("ascii")) -def b64encode(x): +def b64encode(x: bytes) -> str: return base64.b64encode(x).decode("ascii") -def decode_backslashreplace(text, encoding): +def decode_backslashreplace(text: bytes, encoding: str) -> str: return text.decode(encoding, errors="backslashreplace") @@ -150,20 +150,20 @@ def _formatannotation(annotation, base_module=None): def inspect_formatargspec( - args, - varargs=None, - varkw=None, - defaults=None, - kwonlyargs=(), - kwonlydefaults={}, - annotations={}, - formatarg=str, - formatvarargs=lambda name: "*" + name, - formatvarkw=lambda name: "**" + name, - formatvalue=lambda value: "=" + repr(value), - formatreturns=lambda text: " -> " + text, - formatannotation=_formatannotation, -): + args: List[str], + varargs: Optional[str] = None, + varkw: Optional[str] = None, + defaults: Optional[Sequence[Any]] = None, + kwonlyargs: Optional[Sequence[str]] = (), + kwonlydefaults: Optional[Mapping[str, Any]] = {}, + annotations: Mapping[str, Any] = {}, + formatarg: Callable[[str], str] = str, + formatvarargs: Callable[[str], str] = lambda name: "*" + name, + formatvarkw: Callable[[str], str] = lambda name: "**" + name, + formatvalue: Callable[[Any], str] = lambda value: "=" + repr(value), + formatreturns: Callable[[Any], str] = lambda text: " -> " + str(text), + formatannotation: Callable[[Any], str] = _formatannotation, +) -> str: """Copy formatargspec from python 3.7 standard library. Python 3 has deprecated formatargspec and requested that Signature @@ -190,6 +190,9 @@ def inspect_formatargspec( specs = [] if defaults: firstdefault = len(args) - len(defaults) + else: + firstdefault = -1 + for i, arg in enumerate(args): spec = formatargandannotation(arg) if defaults and i >= firstdefault: diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index e5183a542a..57ef230062 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -5,11 +5,12 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +import asyncio # noqa have_greenlet = False greenlet_error = None try: - import greenlet # noqa F401 + import greenlet # type: ignore # noqa F401 except ImportError as e: greenlet_error = str(e) pass @@ -24,12 +25,9 @@ else: from ._concurrency_py3k import ( _util_async_run_coroutine_function, ) # noqa F401, E501 - from ._concurrency_py3k import asyncio # noqa F401 if not have_greenlet: - asyncio = None # noqa F811 - def _not_implemented(): # this conditional is to prevent pylance from considering # greenlet_spawn() etc as "no return" and dimming out code below it @@ -46,20 +44,20 @@ if not have_greenlet: def is_exit_exception(e): # noqa F811 return not isinstance(e, Exception) - def await_only(thing): # noqa F811 + def await_only(thing): # type: ignore # noqa F811 _not_implemented() - def await_fallback(thing): # noqa F81 + def await_fallback(thing): # type: ignore # noqa F81 return thing - def greenlet_spawn(fn, *args, **kw): # noqa F81 + def greenlet_spawn(fn, *args, **kw): # type: ignore # noqa F81 _not_implemented() - def AsyncAdaptedLock(*args, **kw): # noqa F81 + def AsyncAdaptedLock(*args, **kw): # type: ignore # noqa F81 _not_implemented() - def _util_async_run(fn, *arg, **kw): # noqa F81 + def _util_async_run(fn, *arg, **kw): # type: ignore # noqa F81 return fn(*arg, **kw) - def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa F81 + def _util_async_run_coroutine_function(fn, *arg, **kw): # type: ignore # noqa F81 _not_implemented() diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index e5d5d54619..565cbafe26 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -11,6 +11,8 @@ functionality.""" import re from typing import Any from typing import Callable +from typing import cast +from typing import Optional from typing import TypeVar from . import compat @@ -67,11 +69,11 @@ def deprecated_cls(version, message, constructor="__init__"): def deprecated_property( - version, - message=None, - add_deprecation_to_docstring=True, - warning=None, - enable_warnings=True, + version: str, + message: Optional[str] = None, + add_deprecation_to_docstring: bool = True, + warning: Optional[str] = None, + enable_warnings: bool = True, ) -> Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]]: """the @deprecated decorator with a @property. @@ -99,14 +101,17 @@ def deprecated_property( great! now it is. """ - return lambda fn: property( - deprecated( - version, - message=message, - add_deprecation_to_docstring=add_deprecation_to_docstring, - warning=warning, - enable_warnings=enable_warnings, - )(fn) + return cast( + Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]], + lambda fn: property( + deprecated( + version, + message=message, + add_deprecation_to_docstring=add_deprecation_to_docstring, + warning=warning, + enable_warnings=enable_warnings, + )(fn) + ), ) @@ -325,11 +330,12 @@ def _decorate_cls_with_warning( ) doc = inject_docstring_text(doc, docstring_header, 1) + constructor_fn = None if type(cls) is type: clsdict = dict(cls.__dict__) clsdict["__doc__"] = doc clsdict.pop("__dict__", None) - cls = type(cls.__name__, cls.__bases__, clsdict) + cls = type(cls.__name__, cls.__bases__, clsdict) # type: ignore if constructor is not None: constructor_fn = clsdict[constructor] @@ -339,6 +345,7 @@ def _decorate_cls_with_warning( constructor_fn = getattr(cls, constructor) if constructor is not None: + assert constructor_fn is not None setattr( cls, constructor, diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 8b65fb4cf6..9401c249fe 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -19,13 +19,23 @@ import operator import re import sys import textwrap +import threading import types import typing from typing import Any from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet from typing import Generic +from typing import Iterator +from typing import List from typing import Optional from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type from typing import TypeVar from typing import Union import warnings @@ -33,6 +43,7 @@ import warnings from . import _collections from . import compat from . import typing as compat_typing +from ._has_cy import HAS_CYEXTENSION from .. import exc _T = TypeVar("_T") @@ -43,7 +54,7 @@ _HP = TypeVar("_HP", bound="hybridproperty") _HM = TypeVar("_HM", bound="hybridmethod") -def md5_hex(x): +def md5_hex(x: Any) -> str: x = x.encode("utf-8") m = hashlib.md5() m.update(x) @@ -70,26 +81,44 @@ class safe_reraise: __slots__ = ("warn_only", "_exc_info") - def __init__(self, warn_only=False): + _exc_info: Union[ + None, + Tuple[ + Type[BaseException], + BaseException, + types.TracebackType, + ], + Tuple[None, None, None], + ] + + def __init__(self, warn_only: bool = False): self.warn_only = warn_only - def __enter__(self): + def __enter__(self) -> None: self._exc_info = sys.exc_info() - def __exit__(self, type_, value, traceback): + def __exit__( + self, + type_: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: + assert self._exc_info is not None # see #2703 for notes if type_ is None: exc_type, exc_value, exc_tb = self._exc_info + assert exc_value is not None self._exc_info = None # remove potential circular references if not self.warn_only: raise exc_value.with_traceback(exc_tb) else: self._exc_info = None # remove potential circular references + assert value is not None raise value.with_traceback(traceback) -def walk_subclasses(cls): - seen = set() +def walk_subclasses(cls: type) -> Iterator[type]: + seen: Set[Any] = set() stack = [cls] while stack: @@ -102,7 +131,7 @@ def walk_subclasses(cls): yield cls -def string_or_unprintable(element): +def string_or_unprintable(element: Any) -> str: if isinstance(element, str): return element else: @@ -112,13 +141,15 @@ def string_or_unprintable(element): return "unprintable element %r" % element -def clsname_as_plain_name(cls): +def clsname_as_plain_name(cls: Type[Any]) -> str: return " ".join( n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__) ) -def method_is_overridden(instance_or_cls, against_method): +def method_is_overridden( + instance_or_cls: Union[Type[Any], object], against_method: types.MethodType +) -> bool: """Return True if the two class methods don't match.""" if not isinstance(instance_or_cls, type): @@ -128,18 +159,18 @@ def method_is_overridden(instance_or_cls, against_method): method_name = against_method.__name__ - current_method = getattr(current_cls, method_name) + current_method: types.MethodType = getattr(current_cls, method_name) return current_method != against_method -def decode_slice(slc): +def decode_slice(slc: slice) -> Tuple[Any, ...]: """decode a slice object as sent to __getitem__. takes into account the 2.5 __index__() method, basically. """ - ret = [] + ret: List[Any] = [] for x in slc.start, slc.stop, slc.step: if hasattr(x, "__index__"): x = x.__index__() @@ -147,23 +178,23 @@ def decode_slice(slc): return tuple(ret) -def _unique_symbols(used, *bases): - used = set(used) +def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]: + used_set = set(used) for base in bases: pool = itertools.chain( (base,), map(lambda i: base + str(i), range(1000)), ) for sym in pool: - if sym not in used: - used.add(sym) + if sym not in used_set: + used_set.add(sym) yield sym break else: raise NameError("exhausted namespace for symbol base %s" % base) -def map_bits(fn, n): +def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]: """Call the given function given each nonzero bit from n.""" while n: @@ -172,28 +203,34 @@ def map_bits(fn, n): n ^= b -_Fn = typing.TypeVar("_Fn", bound=typing.Callable) +_Fn = typing.TypeVar("_Fn", bound=typing.Callable[..., Any]) _Args = compat_typing.ParamSpec("_Args") def decorator( - target: typing.Callable[compat_typing.Concatenate[_Fn, _Args], typing.Any] + target: typing.Callable[ # type: ignore + compat_typing.Concatenate[_Fn, _Args], typing.Any + ] ) -> _Fn: """A signature-matching decorator factory.""" - def decorate(fn): + def decorate(fn: typing.Callable[..., Any]) -> typing.Callable[..., Any]: if not inspect.isfunction(fn) and not inspect.ismethod(fn): raise Exception("not a decoratable function") spec = compat.inspect_getfullargspec(fn) - env = {} + env: Dict[str, Any] = {} spec = _update_argspec_defaults_into_env(spec, env) - names = tuple(spec[0]) + spec[1:3] + (fn.__name__,) + names = ( + tuple(cast("Tuple[str, ...]", spec[0])) + + cast("Tuple[str, ...]", spec[1:3]) + + (fn.__name__,) + ) targ_name, fn_name = _unique_symbols(names, "target", "fn") - metadata = dict(target=targ_name, fn=fn_name) + metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name) metadata.update(format_argspec_plus(spec, grouped=False)) metadata["name"] = fn.__name__ code = ( @@ -205,9 +242,15 @@ def %(name)s%(grouped_args)s: ) env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__}) - decorated = _exec_code_in_env(code, env, fn.__name__) + decorated = cast( + types.FunctionType, + _exec_code_in_env(code, env, fn.__name__), + ) decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ - decorated.__wrapped__ = fn + + # claims to be fixed? + # https://github.com/python/mypy/issues/11896 + decorated.__wrapped__ = fn # type: ignore return update_wrapper(decorated, fn) return typing.cast(_Fn, update_wrapper(decorate, target)) @@ -303,7 +346,26 @@ def _inspect_func_args(fn): ) -def get_cls_kwargs(cls, _set=None): +@overload +def get_cls_kwargs( + cls: type, + *, + _set: Optional[Set[str]] = None, + raiseerr: compat_typing.Literal[True] = ..., +) -> Set[str]: + ... + + +@overload +def get_cls_kwargs( + cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False +) -> Optional[Set[str]]: + ... + + +def get_cls_kwargs( + cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False +) -> Optional[Set[str]]: r"""Return the full set of inherited kwargs for the given `cls`. Probes a class's __init__ method, collecting all named arguments. If the @@ -321,6 +383,7 @@ def get_cls_kwargs(cls, _set=None): toplevel = _set is None if toplevel: _set = set() + assert _set is not None ctr = cls.__dict__.get("__init__", False) @@ -335,11 +398,18 @@ def get_cls_kwargs(cls, _set=None): _set.update(names) if not has_kw and not toplevel: - return None + if raiseerr: + raise TypeError( + f"given cls {cls} doesn't have an __init__ method" + ) + else: + return None + else: + has_kw = False if not has_init or has_kw: for c in cls.__bases__: - if get_cls_kwargs(c, _set) is None: + if get_cls_kwargs(c, _set=_set) is None: break _set.discard("self") @@ -411,7 +481,9 @@ def get_callable_argspec(fn, no_self=False, _is_init=False): raise TypeError("Can't inspect callable: %s" % fn) -def format_argspec_plus(fn, grouped=True): +def format_argspec_plus( + fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True +) -> Dict[str, Optional[str]]: """Returns a dictionary of formatted, introspected function arguments. A enhanced variant of inspect.formatargspec to support code generation. @@ -474,11 +546,14 @@ def format_argspec_plus(fn, grouped=True): num_defaults = 0 if spec[3]: - num_defaults += len(spec[3]) + num_defaults += len(cast(Tuple[Any], spec[3])) if spec[4]: num_defaults += len(spec[4]) + name_args = spec[0] + spec[4] + defaulted_vals: Union[List[str], Tuple[()]] + if num_defaults: defaulted_vals = name_args[0 - num_defaults :] else: @@ -489,7 +564,7 @@ def format_argspec_plus(fn, grouped=True): spec[1], spec[2], defaulted_vals, - formatvalue=lambda x: "=" + x, + formatvalue=lambda x: "=" + str(x), ) if spec[0]: @@ -498,7 +573,7 @@ def format_argspec_plus(fn, grouped=True): spec[1], spec[2], defaulted_vals, - formatvalue=lambda x: "=" + x, + formatvalue=lambda x: "=" + str(x), ) else: apply_kw_proxied = apply_kw @@ -570,7 +645,7 @@ def create_proxy_methods( def decorate(cls): def instrument(name, clslevel=False): - fn = getattr(target_cls, name) + fn = cast(Callable[..., Any], getattr(target_cls, name)) spec = compat.inspect_getfullargspec(fn) env = {"__name__": fn.__module__} @@ -599,7 +674,9 @@ def create_proxy_methods( % metadata ) - proxy_fn = _exec_code_in_env(code, env, fn.__name__) + proxy_fn = cast( + Callable[..., Any], _exec_code_in_env(code, env, fn.__name__) + ) proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__ proxy_fn.__doc__ = inject_docstring_text( fn.__doc__, @@ -721,7 +798,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): except TypeError: continue else: - default_len = spec.defaults and len(spec.defaults) or 0 + default_len = len(spec.defaults) if spec.defaults else 0 if i == 0: if spec.varargs: vargs = spec.varargs @@ -735,6 +812,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): ) if default_len: + assert spec.defaults kw_args.update( [ (arg, default) @@ -811,9 +889,6 @@ def class_hierarchy(cls): class_hierarchy(class A(object)) returns (A, object), not A plus every class systemwide that derives from object. - Old-style classes are discarded and hierarchies rooted on them - will not be descended. - """ hier = {cls} @@ -829,7 +904,15 @@ def class_hierarchy(cls): if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"): continue - for s in [_ for _ in c.__subclasses__() if _ not in hier]: + for s in [ + _ + for _ in ( + c.__subclasses__() + if not issubclass(c, type) + else c.__subclasses__(c) + ) + if _ not in hier + ]: process.append(s) hier.add(s) return list(hier) @@ -886,10 +969,12 @@ def monkeypatch_proxied_specials( for method in dunders: try: - fn = getattr(from_cls, method) - if not hasattr(fn, "__call__"): + maybe_fn = getattr(from_cls, method) + if not hasattr(maybe_fn, "__call__"): continue - fn = getattr(fn, "__func__", fn) + maybe_fn = getattr(maybe_fn, "__func__", maybe_fn) + fn = cast(Callable[..., Any], maybe_fn) + except AttributeError: continue try: @@ -1021,7 +1106,7 @@ class memoized_property(Generic[_T]): __name__: str def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None): - self.fget = fget + self.fget = fget # type: ignore[assignment] self.__doc__ = doc or fget.__doc__ self.__name__ = fget.__name__ @@ -1041,7 +1126,7 @@ class memoized_property(Generic[_T]): if obj is None: return self obj.__dict__[self.__name__] = result = self.fget(obj) - return result + return result # type: ignore def _reset(self, obj): memoized_property.reset(obj, self.__name__) @@ -1082,7 +1167,7 @@ class HasMemoized: __slots__ = () - _memoized_keys = frozenset() + _memoized_keys: FrozenSet[str] = frozenset() def _reset_memoizations(self): for elem in self._memoized_keys: @@ -1104,7 +1189,8 @@ class HasMemoized: __name__: str def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None): - self.fget = fget + # https://github.com/python/mypy/issues/708 + self.fget = fget # type: ignore self.__doc__ = doc or fget.__doc__ self.__name__ = fget.__name__ @@ -1268,7 +1354,7 @@ def constructor_copy(obj, cls, *args, **kw): def counter(): """Return a threadsafe counter function.""" - lock = compat.threading.Lock() + lock = threading.Lock() counter = itertools.count(1) # avoid the 2to3 "next" transformation... @@ -1362,12 +1448,14 @@ class classproperty(property): """ - def __init__(self, fget, *arg, **kw): + fget: Callable[[Any], Any] + + def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any): super(classproperty, self).__init__(fget, *arg, **kw) self.__doc__ = fget.__doc__ - def __get__(desc, self, cls): - return desc.fget(cls) + def __get__(self, obj: Any, cls: Optional[type] = None) -> Any: + return self.fget(cls) # type: ignore class hybridproperty: @@ -1406,7 +1494,9 @@ class hybridmethod: class _symbol(int): - def __new__(self, name, doc=None, canonical=None): + name: str + + def __new__(cls, name, doc=None, canonical=None): """Construct a new named symbol.""" assert isinstance(name, str) if canonical is None: @@ -1452,8 +1542,8 @@ class symbol: """ - symbols = {} - _lock = compat.threading.Lock() + symbols: Dict[str, "_symbol"] = {} + _lock = threading.Lock() def __new__(cls, name, doc=None, canonical=None): with cls._lock: @@ -1546,6 +1636,8 @@ class _hash_limit_string(str): """ + _hash: int + def __new__(cls, value, num, args): interpolated = (value % args) + ( " (this warning may be suppressed after %d occurrences)" % num @@ -1731,8 +1823,8 @@ class EnsureKWArg: super().__init_subclass__() @classmethod - def _wrap_w_kw(cls, fn): - def wrap(*arg, **kw): + def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]: + def wrap(*arg: Any, **kw: Any) -> Any: return fn(*arg) return update_wrapper(wrap, fn) @@ -1910,15 +2002,12 @@ def repr_tuple_names(names): def has_compiled_ext(raise_=False): - try: - from sqlalchemy.cyextension import collections # noqa F401 - from sqlalchemy.cyextension import immutabledict # noqa F401 - from sqlalchemy.cyextension import processors # noqa F401 - from sqlalchemy.cyextension import resultproxy # noqa F401 - from sqlalchemy.cyextension import util # noqa F401 - + if HAS_CYEXTENSION: return True - except ImportError: - if raise_: - raise + elif raise_: + raise ImportError( + "cython extensions were expected to be installed, " + "but are not present" + ) + else: return False diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index d2cd0a1a71..3062d9d8ab 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -17,12 +17,11 @@ producing a ``put()`` inside the ``get()`` and therefore a reentrant condition. """ - +import asyncio from collections import deque +import threading from time import time as _time -from .compat import threading -from .concurrency import asyncio from .concurrency import await_fallback from .concurrency import await_only from .langhelpers import memoized_property diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 5767d258b0..62a9f6c8a8 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -1,3 +1,4 @@ +import typing from typing import Any from typing import Callable # noqa from typing import Generic @@ -12,27 +13,29 @@ from . import compat _T = TypeVar("_T", bound=Any) -if compat.py38: - from typing import Literal - from typing import Protocol - from typing import TypedDict +if typing.TYPE_CHECKING or not compat.py38: + from typing_extensions import Literal # noqa F401 + from typing_extensions import Protocol # noqa F401 + from typing_extensions import TypedDict # noqa F401 else: - from typing_extensions import Literal # noqa - from typing_extensions import Protocol # noqa - from typing_extensions import TypedDict # noqa + from typing import Literal # noqa F401 + from typing import Protocol # noqa F401 + from typing import TypedDict # noqa F401 -if compat.py310: - from typing import Concatenate - from typing import ParamSpec +if typing.TYPE_CHECKING or not compat.py310: + from typing_extensions import Concatenate # noqa F401 + from typing_extensions import ParamSpec # noqa F401 else: - from typing_extensions import Concatenate # noqa - from typing_extensions import ParamSpec # noqa + from typing import Concatenate # noqa F401 + from typing import ParamSpec # noqa F401 -_T = TypeVar("_T") +class _TypeToInstance(Generic[_T]): + """describe a variable that moves between a class and an instance of + that class. + """ -class _TypeToInstance(Generic[_T]): @overload def __get__(self, instance: None, owner: Any) -> Type[_T]: ... @@ -41,6 +44,9 @@ class _TypeToInstance(Generic[_T]): def __get__(self, instance: object, owner: Any) -> _T: ... + def __get__(self, instance: object, owner: Any) -> Union[Type[_T], _T]: + ... + @overload def __set__(self, instance: None, value: Type[_T]) -> None: ... @@ -49,6 +55,9 @@ class _TypeToInstance(Generic[_T]): def __set__(self, instance: object, value: _T) -> None: ... + def __set__(self, instance: object, value: Union[Type[_T], _T]) -> None: + ... + class ReadOnlyInstanceDescriptor(Protocol[_T]): """protocol representing an instance-only descriptor""" diff --git a/pyproject.toml b/pyproject.toml index 2707bea97c..036892d45b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,3 +23,67 @@ filterwarnings = [ "error::DeprecationWarning:test", "error::DeprecationWarning:sqlalchemy" ] + + +[tool.pyright] +include = [ + "lib/sqlalchemy/events.py", + "lib/sqlalchemy/exc.py", + "lib/sqlalchemy/log.py", + "lib/sqlalchemy/inspection.py", + "lib/sqlalchemy/schema.py", + "lib/sqlalchemy/types.py", + "lib/sqlalchemy/util/", +] + + + +[tool.mypy] +mypy_path = "./lib/" +show_error_codes = true +strict = false +incremental = true + +# disabled checking +[[tool.mypy.overrides]] +module="sqlalchemy.*" +ignore_errors = true +warn_unused_ignores = false + +strict = true + +# https://github.com/python/mypy/issues/8754 +# we are a pep-561 package, so implicit-rexport should be +# enabled +implicit_reexport = true + +# individual packages or even modules should be listed here +# with strictness-specificity set up. there's no way we are going to get +# the whole library 100% strictly typed, so we have to tune this based on +# the type of module or package we are dealing with + +# strict checking +[[tool.mypy.overrides]] +module = [ + "sqlalchemy.events", + "sqlalchemy.events", + "sqlalchemy.exc", + "sqlalchemy.inspection", + "sqlalchemy.schema", + "sqlalchemy.types", +] +ignore_errors = false +strict = true + +# partial checking, internals can be untyped +[[tool.mypy.overrides]] +module="sqlalchemy.util.*" +ignore_errors = false + +# util is for internal use so we can get by without everything +# being typed +allow_untyped_defs = true +check_untyped_defs = false +allow_untyped_calls = true + + diff --git a/setup.cfg b/setup.cfg index 3f903eb62f..2eceb0b816 100644 --- a/setup.cfg +++ b/setup.cfg @@ -114,18 +114,6 @@ per-file-ignores = lib/sqlalchemy/types.py:F401 lib/sqlalchemy/sql/expression.py:F401 -[mypy] -mypy_path = ./lib/ -strict = True -incremental = True -#plugins = sqlalchemy.ext.mypy.plugin - -[mypy-sqlalchemy.*] -ignore_errors = True - -[mypy-sqlalchemy.ext.mypy.*] -ignore_errors = False - [sqla_testing] requirement_cls = test.requirements:DefaultRequirements profile_file = test/profiles.txt diff --git a/test/base/test_utils.py b/test/base/test_utils.py index a88b7c56c5..dc02c37cb0 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -162,7 +162,20 @@ class OrderedSetTest(fixtures.TestBase): eq_(o.difference(iter([3, 4])), util.OrderedSet([2, 5])) eq_(o.intersection(iter([3, 4, 6])), util.OrderedSet([3, 4])) - eq_(o.union(iter([3, 4, 6])), util.OrderedSet([2, 3, 4, 5, 6])) + eq_(o.union(iter([3, 4, 6])), util.OrderedSet([3, 2, 4, 5, 6])) + + def test_len(self): + eq_(len(util.OrderedSet([1, 2, 3])), 3) + + def test_eq_no_insert_order(self): + eq_(util.OrderedSet([3, 2, 4, 5]), util.OrderedSet([2, 3, 4, 5])) + + ne_(util.OrderedSet([3, 2, 4, 5]), util.OrderedSet([3, 2, 4, 5, 6])) + + def test_eq_non_ordered_set(self): + eq_(util.OrderedSet([3, 2, 4, 5]), {2, 3, 4, 5}) + + ne_(util.OrderedSet([3, 2, 4, 5]), {3, 2, 4, 5, 6}) def test_repr(self): o = util.OrderedSet([]) @@ -295,7 +308,7 @@ class ImmutableTest(fixtures.TestBase): lambda: d.update({2: 4}), ) if hasattr(d, "pop"): - calls += (d.pop, d.popitem) + calls += (lambda: d.pop(2), d.popitem) for m in calls: with expect_raises_message(TypeError, "object is immutable"): m() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 59ebc87e2b..59bc4863fb 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -3333,6 +3333,8 @@ class OnConnectTest(fixtures.TestBase): cls_ = testing.db.dialect.__class__ class SomeDialect(cls_): + supports_statement_cache = True + def initialize(self, connection): super(SomeDialect, self).initialize(connection) m1.append("initialize") diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index dc04d4da65..a83ca41947 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -31,20 +31,9 @@ from sqlalchemy.testing import not_in from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -from sqlalchemy.util import has_compiled_ext -from sqlalchemy.util import OrderedSet from test.orm import _fixtures -if has_compiled_ext(): - # cython ordered set is immutable, subclass it with a python - # class so that its method can be replaced - _OrderedSet = OrderedSet - - class OrderedSet(_OrderedSet): - pass - - class MergeTest(_fixtures.FixtureTest): """Session.merge() functionality""" @@ -167,7 +156,7 @@ class MergeTest(_fixtures.FixtureTest): users, properties={ "addresses": relationship( - Address, backref="user", collection_class=OrderedSet + Address, backref="user", collection_class=set ) }, ) @@ -178,12 +167,10 @@ class MergeTest(_fixtures.FixtureTest): u = User( id=7, name="fred", - addresses=OrderedSet( - [ - Address(id=1, email_address="fred1"), - Address(id=2, email_address="fred2"), - ] - ), + addresses={ + Address(id=1, email_address="fred1"), + Address(id=2, email_address="fred2"), + }, ) eq_(load.called, 0) @@ -203,12 +190,10 @@ class MergeTest(_fixtures.FixtureTest): User( id=7, name="fred", - addresses=OrderedSet( - [ - Address(id=1, email_address="fred1"), - Address(id=2, email_address="fred2"), - ] - ), + addresses={ + Address(id=1, email_address="fred1"), + Address(id=2, email_address="fred2"), + }, ), ) @@ -258,7 +243,7 @@ class MergeTest(_fixtures.FixtureTest): users, properties={ "addresses": relationship( - Address, backref="user", collection_class=OrderedSet + Address, backref="user", collection_class=set ) }, ) @@ -269,12 +254,10 @@ class MergeTest(_fixtures.FixtureTest): u = User( id=None, name="fred", - addresses=OrderedSet( - [ - Address(id=None, email_address="fred1"), - Address(id=None, email_address="fred2"), - ] - ), + addresses={ + Address(id=None, email_address="fred1"), + Address(id=None, email_address="fred2"), + }, ) eq_(load.called, 0) @@ -293,12 +276,10 @@ class MergeTest(_fixtures.FixtureTest): sess.query(User).one(), User( name="fred", - addresses=OrderedSet( - [ - Address(email_address="fred1"), - Address(email_address="fred2"), - ] - ), + addresses={ + Address(email_address="fred1"), + Address(email_address="fred2"), + }, ), ) @@ -341,8 +322,7 @@ class MergeTest(_fixtures.FixtureTest): "addresses": relationship( Address, backref="user", - collection_class=OrderedSet, - order_by=addresses.c.id, + collection_class=set, cascade="all, delete-orphan", ) }, @@ -355,12 +335,10 @@ class MergeTest(_fixtures.FixtureTest): u = User( id=7, name="fred", - addresses=OrderedSet( - [ - Address(id=1, email_address="fred1"), - Address(id=2, email_address="fred2"), - ] - ), + addresses={ + Address(id=1, email_address="fred1"), + Address(id=2, email_address="fred2"), + }, ) sess = fixture_session() sess.add(u) @@ -372,12 +350,10 @@ class MergeTest(_fixtures.FixtureTest): u = User( id=7, name="fred", - addresses=OrderedSet( - [ - Address(id=3, email_address="fred3"), - Address(id=4, email_address="fred4"), - ] - ), + addresses={ + Address(id=3, email_address="fred3"), + Address(id=4, email_address="fred4"), + }, ) u = sess.merge(u) @@ -393,12 +369,10 @@ class MergeTest(_fixtures.FixtureTest): User( id=7, name="fred", - addresses=OrderedSet( - [ - Address(id=3, email_address="fred3"), - Address(id=4, email_address="fred4"), - ] - ), + addresses={ + Address(id=3, email_address="fred3"), + Address(id=4, email_address="fred4"), + }, ), ) sess.flush() @@ -408,12 +382,10 @@ class MergeTest(_fixtures.FixtureTest): User( id=7, name="fred", - addresses=OrderedSet( - [ - Address(id=3, email_address="fred3"), - Address(id=4, email_address="fred4"), - ] - ), + addresses={ + Address(id=3, email_address="fred3"), + Address(id=4, email_address="fred4"), + }, ), ) @@ -433,7 +405,7 @@ class MergeTest(_fixtures.FixtureTest): Address, backref="user", order_by=addresses.c.id, - collection_class=OrderedSet, + collection_class=set, ) }, ) @@ -445,7 +417,7 @@ class MergeTest(_fixtures.FixtureTest): u = User( id=7, name="fred", - addresses=OrderedSet([a, Address(id=2, email_address="fred2")]), + addresses={a, Address(id=2, email_address="fred2")}, ) sess = fixture_session() sess.add(u) @@ -467,12 +439,10 @@ class MergeTest(_fixtures.FixtureTest): User( id=7, name="fred jones", - addresses=OrderedSet( - [ - Address(id=2, email_address="fred2"), - Address(id=3, email_address="fred3"), - ] - ), + addresses={ + Address(id=2, email_address="fred2"), + Address(id=3, email_address="fred3"), + }, ), ) diff --git a/tox.ini b/tox.ini index e55a43cbbe..2100aa507e 100644 --- a/tox.ini +++ b/tox.ini @@ -128,6 +128,16 @@ commands= oracle,mssql,sqlite_file: python reap_dbs.py db_idents.txt +[testenv:pep484] +deps= + greenlet != 0.4.17 + importlib_metadata; python_version < '3.8' + mypy + pyright +commands = + mypy ./lib/sqlalchemy + pyright + [testenv:mypy] deps= pytest>=7.0.0rc1,<8 @@ -158,6 +168,7 @@ commands = flake8 ./lib/ ./test/ ./examples/ setup.py doc/build/conf.py {posargs} black --check ./lib/ ./test/ ./examples/ setup.py doc/build/conf.py + # command run in the github action when cext are active. [testenv:github-cext] deps = {[testenv]deps}