mypy: sqlalchemy.util
Author:     Mike Bayer <mike_mp@zzzcomputing.com>
AuthorDate: Sun, 9 Jan 2022 16:49:02 +0000 (11:49 -0500)
Commit:     Mike Bayer <mike_mp@zzzcomputing.com>
CommitDate: Mon, 24 Jan 2022 20:14:01 +0000 (15:14 -0500)
Starting to set up practices and conventions to
get the library typed.

Key goals for typing are:

1. the whole library can pass mypy without any strict flags
   turned on.
2. we can incrementally turn on strict flags on a per-package /
   per-module basis; here we turn on more strictness for sqlalchemy.util,
   exc, and log
3. mypy ORM plugin tests work fully without sqlalchemy2-stubs
   installed
4. public-facing methods all have return types, and the major
   parameter signatures are filled in as well
5. foundational elements such as util are typed enough that
   we can use them in fully typed internals higher up the stack.

Conventions set up here:

1. we can use per-package / per-module configuration in setup.cfg to
   limit where mypy reports errors and how detailed it should be for
   different packages / modules.  We can use this to push up Gerrit
   reviews that pass tests fully without everything being typed
   (see the sketch after this list).
2. a new tox target, pep484, is added.  It links to a new Jenkins
   pep484 job that works across all projects (Alembic, dogpile, etc.)
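
As a rough illustration of convention 1, mypy's standard per-module
override sections in setup.cfg can ratchet strictness up one package
at a time.  The module names and option choices below are assumptions
for the sake of example, not the exact configuration added by this
commit:

    [mypy]
    # baseline: everything must pass, but without strict flags

    [mypy-sqlalchemy.util.*]
    # packages opted in to extra strictness (goal 2 above)
    disallow_untyped_defs = True
    disallow_incomplete_defs = True

    [mypy-sqlalchemy.exc]
    disallow_untyped_defs = True

    [mypy-sqlalchemy.log]
    disallow_untyped_defs = True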

We've worked around some mypy bugs that will likely
be around for a while, and also set up some core practices
for how to deal with certain things such as public_factory
modules (mypy won't accept a module built from a callable at all,
so we need to use simple type-checking conditionals).
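
A minimal, self-contained sketch of that conditional pattern, using
made-up names rather than SQLAlchemy's actual public_factory call
sites:

    import typing


    class relationship_cls:
        """Stand-in for a construct normally exposed via a factory."""

        def __init__(self, target: str) -> None:
            self.target = target


    def _factory(
        cls: typing.Type[relationship_cls],
    ) -> typing.Callable[..., relationship_cls]:
        # stand-in for a public_factory-style helper; builds the
        # public callable dynamically at runtime
        def go(*arg: typing.Any, **kw: typing.Any) -> relationship_cls:
            return cls(*arg, **kw)

        return go


    if typing.TYPE_CHECKING:
        # what mypy sees: a plain, fully typed name
        relationship = relationship_cls
    else:
        # what actually runs: the dynamically generated callable
        relationship = _factory(relationship_cls)

Calling relationship("addresses") then type-checks against the class
while still going through the factory at runtime.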

References: #6810
Change-Id: I80be58029896a29fd9f491aa3215422a8b705e12

35 files changed:
.github/workflows/run-test.yaml
MANIFEST.in
lib/sqlalchemy/cyextension/collections.pyx
lib/sqlalchemy/cyextension/immutabledict.pyx
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/event/attr.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/inspection.py
lib/sqlalchemy/log.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/sql/_py_util.py
lib/sqlalchemy/sql/_typing.py [new file with mode: 0644]
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/_concurrency_py3k.py
lib/sqlalchemy/util/_preloaded.py
lib/sqlalchemy/util/_py_collections.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/concurrency.py
lib/sqlalchemy/util/deprecations.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/queue.py
lib/sqlalchemy/util/typing.py
pyproject.toml
setup.cfg
test/base/test_utils.py
test/engine/test_execute.py
test/orm/test_merge.py
tox.ini

index 6fbb29bdc94f26459a11e528afa7b4cbd19e2ee7..196e3c1b15ae9ac18ad74ef47751b53b8f4a3fb4 100644 (file)
@@ -188,3 +188,36 @@ jobs:
 
       - name: Run tests
         run: tox -e pep8
+
+  run-pep484:
+    name: pep484-${{ matrix.python-version }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      # run this job using this matrix, excluding some combinations below.
+      matrix:
+        os:
+          - "ubuntu-latest"
+        python-version:
+          - "3.10"
+
+      fail-fast: false
+
+    # steps to run in each job. Some are github actions, others run shell commands
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v2
+
+      - name: Set up python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: ${{ matrix.architecture }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install --upgrade tox setuptools
+          pip list
+
+      - name: Run tests
+        run: tox -e pep484
index 0cb613385140a78154404fe49bcc57002cea0a24..eb447a0bd0cafb102b0a50a25a640784ab1a62d3 100644 (file)
@@ -5,6 +5,11 @@ recursive-include doc *.html *.css *.txt *.js *.png *.py Makefile *.rst *.sty
 recursive-include examples *.py *.xml
 recursive-include test *.py *.dat *.testpatch
 
+# for some reason in some environments stale Cython .c files
+# are being pulled in, these should never be in a dist
+exclude lib/sqlalchemy/cyextension/*.c
+exclude lib/sqlalchemy/cyextension/*.so
+
 # include the pyx and pxd extensions, which otherwise
 # don't come in if --with-cextensions isn't specified.
 recursive-include lib *.pyx *.pxd *.txt *.typed
index e695d4c62d4609a656962037ef2de2aa762026c1..5a344da43224bd1348bdadcdad13c2d7d9bc4784 100644 (file)
@@ -22,52 +22,53 @@ cdef list cunique_list(seq, hashfunc=None):
 def unique_list(seq, hashfunc=None):
     return cunique_list(seq, hashfunc)
 
-cdef class OrderedSet(set):
+cdef class OrderedSet:
 
     cdef list _list
+    cdef set _set
 
     def __init__(self, d=None):
-        set.__init__(self)
         if d is not None:
             self._list = cunique_list(d)
-            set.update(self, self._list)
+            self._set = set(self._list)
         else:
             self._list = []
+            self._set = set()
 
     cdef OrderedSet _copy(self):
         cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
         cp._list = list(self._list)
-        set.update(cp, cp._list)
+        cp._set = set(cp._list)
         return cp
 
     cdef OrderedSet _from_list(self, list new_list):
         cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
         new._list = new_list
-        set.update(new, new_list)
+        new._set = set(new_list)
         return new
 
     def add(self, element):
         if element not in self:
             self._list.append(element)
-            PySet_Add(self, element)
+            PySet_Add(self._set, element)
 
     def remove(self, element):
         # set.remove will raise if element is not in self
-        set.remove(self, element)
+        self._set.remove(element)
         self._list.remove(element)
 
     def insert(self, Py_ssize_t pos, element):
         if element not in self:
             self._list.insert(pos, element)
-            PySet_Add(self, element)
+            PySet_Add(self._set, element)
 
     def discard(self, element):
         if element in self:
-            set.remove(self, element)
+            self._set.remove(element)
             self._list.remove(element)
 
     def clear(self):
-        set.clear(self)
+        self._set.clear()
         self._list = []
 
     def __getitem__(self, key):
@@ -84,21 +85,34 @@ cdef class OrderedSet(set):
 
     __str__ = __repr__
 
-    def update(self, iterable):
-        for e in iterable:
-            if e not in self:
-                self._list.append(e)
-                set.add(self, e)
-        return self
+    def update(self, *iterables):
+        for iterable in iterables:
+            for e in iterable:
+                if e not in self:
+                    self._list.append(e)
+                    self._set.add(e)
 
     def __ior__(self, iterable):
-        return self.update(iterable)
+        self.update(iterable)
+        return self
 
     def union(self, other):
         result = self._copy()
         result.update(other)
         return result
 
+    def __len__(self) -> int:
+        return len(self._set)
+
+    def __eq__(self, other):
+        return self._set == other
+
+    def __ne__(self, other):
+        return self._set != other
+
+    def __contains__(self, element):
+        return element in self._set
+
     def __or__(self, other):
         return self.union(other)
 
@@ -138,27 +152,27 @@ cdef class OrderedSet(set):
         cdef set other_set = self._to_set(other)
         set.intersection_update(self, other_set)
         self._list = [a for a in self._list if a in other_set]
-        return self
 
     def __iand__(self, other):
-        return self.intersection_update(other)
+        self.intersection_update(other)
+        return self
 
     def symmetric_difference_update(self, other):
         set.symmetric_difference_update(self, other)
         self._list = [a for a in self._list if a in self]
         self._list += [a for a in other if a in self]
-        return self
 
     def __ixor__(self, other):
-        return self.symmetric_difference_update(other)
+        self.symmetric_difference_update(other)
+        return self
 
     def difference_update(self, other):
         set.difference_update(self, other)
         self._list = [a for a in self._list if a in self]
-        return self
 
     def __isub__(self, other):
-        return self.difference_update(other)
+        self.difference_update(other)
+        return self
 
 
 cdef object cy_id(object item):
index 89bcf3ed6cfb5f151f44401373cb79c8bec2479b..d07c81bd49031abb0efc5a0e12cf983dc3b35011 100644 (file)
@@ -12,10 +12,25 @@ class ImmutableContainer:
     __delitem__ = __setitem__ = __setattr__ = _immutable
 
 
+class ImmutableDictBase(dict):
+    def _immutable(self, *a,**kw):
+        _immutable_fn(self)
+
+    @classmethod
+    def __class_getitem__(cls, key):
+        return cls
+
+    __delitem__ = __setitem__ = __setattr__ = _immutable
+    clear = pop = popitem = setdefault = update = _immutable
+
 cdef class immutabledict(dict):
     def __repr__(self):
         return f"immutabledict({dict.__repr__(self)})"
 
+    @classmethod
+    def __class_getitem__(cls, key):
+        return cls
+
     def union(self, *args, **kw):
         cdef dict to_merge = None
         cdef immutabledict result
index 40fc5d1620806548e225114c216b1aee9b2cbadd..f7d02e3b0c946d171c324ceda19a97f6de2adb20 100644 (file)
@@ -19,13 +19,17 @@ from .util import _distill_params_20
 from .util import _distill_raw_params
 from .util import TransactionalContext
 from .. import exc
+from .. import inspection
 from .. import log
 from .. import util
 from ..sql import compiler
 from ..sql import util as sql_util
+from ..sql._typing import _ExecuteOptions
+from ..sql._typing import _ExecuteParams
 
 if typing.TYPE_CHECKING:
     from .interfaces import Dialect
+    from .reflection import Inspector  # noqa
     from .url import URL
     from ..pool import Pool
     from ..pool import PoolProxiedConnection
@@ -38,7 +42,7 @@ _EMPTY_EXECUTION_OPTS = util.immutabledict()
 NO_OPTIONS = util.immutabledict()
 
 
-class Connection(ConnectionEventsTarget):
+class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
     """Provides high-level functionality for a wrapped DB-API connection.
 
     The :class:`_engine.Connection` object is procured by calling
@@ -1079,7 +1083,12 @@ class Connection(ConnectionEventsTarget):
 
         return self.execute(statement, parameters, execution_options).scalars()
 
-    def execute(self, statement, parameters=None, execution_options=None):
+    def execute(
+        self,
+        statement,
+        parameters: Optional[_ExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptions] = None,
+    ):
         r"""Executes a SQL statement construct and returns a
         :class:`_engine.Result`.
 
@@ -2270,7 +2279,9 @@ class TwoPhaseTransaction(RootTransaction):
         self.connection._commit_twophase_impl(self.xid, self._is_prepared)
 
 
-class Engine(ConnectionEventsTarget, log.Identified):
+class Engine(
+    ConnectionEventsTarget, log.Identified, inspection.Inspectable["Inspector"]
+):
     """
     Connects a :class:`~sqlalchemy.pool.Pool` and
     :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a
index 7eebb1f0197e01ef6c0ef2a09fc96ed249953f99..6fb8279894822f4c7ffa059dc8791ee353a7e716 100644 (file)
@@ -5,6 +5,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+from typing import Any
 
 from . import base
 from . import url as _url
@@ -40,7 +41,7 @@ from ..sql import compiler
         "is deprecated and will be removed in a future release. ",
     ),
 )
-def create_engine(url, **kwargs):
+def create_engine(url: "_url.URL", **kwargs: Any) -> "base.Engine":
     """Create a new :class:`_engine.Engine` instance.
 
     The standard calling form is to send the :ref:`URL <database_urls>` as the
index 371b9c7624fb3c20d81031093ac4d0bcb7196172..df7a53ab7deea48ee001f96e894b296fe333d5dc 100644 (file)
@@ -57,7 +57,7 @@ def cache(fn, self, con, *args, **kw):
 
 
 @inspection._self_inspects
-class Inspector:
+class Inspector(inspection.Inspectable["Inspector"]):
     """Performs database schema inspection.
 
     The Inspector acts as a proxy to the reflection methods of the
index 48ce1629ae4f1cedc62eabbfbb2a224041bc676d..a059662224c452ad5557edeb7941c12ca800cf0c 100644 (file)
@@ -30,13 +30,13 @@ as well as support for subclass propagation (e.g. events assigned to
 """
 import collections
 from itertools import chain
+import threading
 import weakref
 
 from . import legacy
 from . import registry
 from .. import exc
 from .. import util
-from ..util import threading
 from ..util.concurrency import AsyncAdaptedLock
 
 
index 8fdacbdf2e54de50884391a288b82c8be0ddc1dc..6732edd4e84030e3142b0b5323a165cbb190d78c 100644 (file)
@@ -12,25 +12,39 @@ raised as a result of DBAPI exceptions are all subclasses of
 :exc:`.DBAPIError`.
 
 """
+import typing
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import Union
 
 from .util import _preloaded
 from .util import compat
 
+if typing.TYPE_CHECKING:
+    from .engine.interfaces import Dialect
+    from .sql._typing import _ExecuteParams
+    from .sql.compiler import Compiled
+    from .sql.elements import ClauseElement
+
 _version_token = None
 
 
 class HasDescriptionCode:
     """helper which adds 'code' as an attribute and '_code_str' as a method"""
 
-    code = None
+    code: Optional[str] = None
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any):
         code = kw.pop("code", None)
         if code is not None:
             self.code = code
         super(HasDescriptionCode, self).__init__(*arg, **kw)
 
-    def _code_str(self):
+    def _code_str(self) -> str:
         if not self.code:
             return ""
         else:
@@ -43,7 +57,7 @@ class HasDescriptionCode:
                 )
             )
 
-    def __str__(self):
+    def __str__(self) -> str:
         message = super(HasDescriptionCode, self).__str__()
         if self.code:
             message = "%s %s" % (message, self._code_str())
@@ -53,7 +67,7 @@ class HasDescriptionCode:
 class SQLAlchemyError(HasDescriptionCode, Exception):
     """Generic error class."""
 
-    def _message(self):
+    def _message(self) -> str:
         # rules:
         #
         # 1. single arg string will usually be a unicode
@@ -64,16 +78,18 @@ class SQLAlchemyError(HasDescriptionCode, Exception):
         # SQLAlchemy though this is happening in at least one known external
         # library, call str() which does a repr().
         #
+        text: str
+
         if len(self.args) == 1:
-            text = self.args[0]
+            arg_text = self.args[0]
 
-            if isinstance(text, bytes):
-                text = compat.decode_backslashreplace(text, "utf-8")
+            if isinstance(arg_text, bytes):
+                text = compat.decode_backslashreplace(arg_text, "utf-8")
             # This is for when the argument is not a string of any sort.
             # Otherwise, converting this exception to string would fail for
             # non-string arguments.
             else:
-                text = str(text)
+                text = str(arg_text)
 
             return text
         else:
@@ -82,7 +98,7 @@ class SQLAlchemyError(HasDescriptionCode, Exception):
             # a repr() of the tuple
             return str(self.args)
 
-    def _sql_message(self):
+    def _sql_message(self) -> str:
         message = self._message()
 
         if self.code:
@@ -90,7 +106,7 @@ class SQLAlchemyError(HasDescriptionCode, Exception):
 
         return message
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self._sql_message()
 
 
@@ -110,13 +126,13 @@ class ObjectNotExecutableError(ArgumentError):
 
     """
 
-    def __init__(self, target):
+    def __init__(self, target: Any):
         super(ObjectNotExecutableError, self).__init__(
             "Not an executable object: %r" % target
         )
         self.target = target
 
-    def __reduce__(self):
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         return self.__class__, (self.target,)
 
 
@@ -154,7 +170,14 @@ class CircularDependencyError(SQLAlchemyError):
 
     """
 
-    def __init__(self, message, cycles, edges, msg=None, code=None):
+    def __init__(
+        self,
+        message: str,
+        cycles: Any,
+        edges: Any,
+        msg: Optional[str] = None,
+        code: Optional[str] = None,
+    ):
         if msg is None:
             message += " (%s)" % ", ".join(repr(s) for s in cycles)
         else:
@@ -163,7 +186,7 @@ class CircularDependencyError(SQLAlchemyError):
         self.cycles = cycles
         self.edges = edges
 
-    def __reduce__(self):
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         return (
             self.__class__,
             (None, self.cycles, self.edges, self.args[0]),
@@ -187,7 +210,12 @@ class UnsupportedCompilationError(CompileError):
 
     code = "l7de"
 
-    def __init__(self, compiler, element_type, message=None):
+    def __init__(
+        self,
+        compiler: "Compiled",
+        element_type: Type["ClauseElement"],
+        message: Optional[str] = None,
+    ):
         super(UnsupportedCompilationError, self).__init__(
             "Compiler %r can't render element of type %s%s"
             % (compiler, element_type, ": %s" % message if message else "")
@@ -196,7 +224,7 @@ class UnsupportedCompilationError(CompileError):
         self.element_type = element_type
         self.message = message
 
-    def __reduce__(self):
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         return self.__class__, (self.compiler, self.element_type, self.message)
 
 
@@ -216,7 +244,7 @@ class DisconnectionError(SQLAlchemyError):
 
     """
 
-    invalidate_pool = False
+    invalidate_pool: bool = False
 
 
 class InvalidatePoolError(DisconnectionError):
@@ -234,7 +262,7 @@ class InvalidatePoolError(DisconnectionError):
 
     """
 
-    invalidate_pool = True
+    invalidate_pool: bool = True
 
 
 class TimeoutError(SQLAlchemyError):  # noqa
@@ -332,11 +360,11 @@ class NoReferencedTableError(NoReferenceError):
 
     """
 
-    def __init__(self, message, tname):
+    def __init__(self, message: str, tname: str):
         NoReferenceError.__init__(self, message)
         self.table_name = tname
 
-    def __reduce__(self):
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         return self.__class__, (self.args[0], self.table_name)
 
 
@@ -346,12 +374,12 @@ class NoReferencedColumnError(NoReferenceError):
 
     """
 
-    def __init__(self, message, tname, cname):
+    def __init__(self, message: str, tname: str, cname: str):
         NoReferenceError.__init__(self, message)
         self.table_name = tname
         self.column_name = cname
 
-    def __reduce__(self):
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         return (
             self.__class__,
             (self.args[0], self.table_name, self.column_name),
@@ -409,26 +437,29 @@ class StatementError(SQLAlchemyError):
 
     """
 
-    statement = None
+    statement: Optional[str] = None
     """The string SQL statement being invoked when this exception occurred."""
 
-    params = None
+    params: Optional["_ExecuteParams"] = None
     """The parameter list being used when this exception occurred."""
 
-    orig = None
-    """The DBAPI exception object."""
+    orig: Optional[BaseException] = None
+    """The original exception that was thrown.
+
+    """
 
-    ismulti = None
+    ismulti: Optional[bool] = None
+    """multi parameter passed to repr_params().  None is meaningful."""
 
     def __init__(
         self,
-        message,
-        statement,
-        params,
-        orig,
-        hide_parameters=False,
-        code=None,
-        ismulti=None,
+        message: str,
+        statement: Optional[str],
+        params: Optional["_ExecuteParams"],
+        orig: Optional[BaseException],
+        hide_parameters: bool = False,
+        code: Optional[str] = None,
+        ismulti: Optional[bool] = None,
     ):
         SQLAlchemyError.__init__(self, message, code=code)
         self.statement = statement
@@ -436,12 +467,12 @@ class StatementError(SQLAlchemyError):
         self.orig = orig
         self.ismulti = ismulti
         self.hide_parameters = hide_parameters
-        self.detail = []
+        self.detail: List[str] = []
 
-    def add_detail(self, msg):
+    def add_detail(self, msg: str) -> None:
         self.detail.append(msg)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         return (
             self.__class__,
             (
@@ -457,8 +488,11 @@ class StatementError(SQLAlchemyError):
         )
 
     @_preloaded.preload_module("sqlalchemy.sql.util")
-    def _sql_message(self):
-        util = _preloaded.preloaded.sql_util
+    def _sql_message(self) -> str:
+        if typing.TYPE_CHECKING:
+            from .sql import util
+        else:
+            util = _preloaded.preloaded.sql_util
 
         details = [self._message()]
         if self.statement:
@@ -505,18 +539,67 @@ class DBAPIError(StatementError):
 
     code = "dbapi"
 
+    # I don't think I'm going to try to do overloads like this everywhere
+    # in the library, but as this module is early days for me typing everything
+    # I am sort of just practicing
+
+    @overload
     @classmethod
     def instance(
         cls,
-        statement,
-        params,
-        orig,
-        dbapi_base_err,
-        hide_parameters=False,
-        connection_invalidated=False,
-        dialect=None,
-        ismulti=None,
-    ):
+        statement: str,
+        params: "_ExecuteParams",
+        orig: DontWrapMixin,
+        dbapi_base_err: Type[Exception],
+        hide_parameters: bool = False,
+        connection_invalidated: bool = False,
+        dialect: Optional["Dialect"] = None,
+        ismulti: Optional[bool] = None,
+    ) -> DontWrapMixin:
+        ...
+
+    @overload
+    @classmethod
+    def instance(
+        cls,
+        statement: str,
+        params: "_ExecuteParams",
+        orig: Exception,
+        dbapi_base_err: Type[Exception],
+        hide_parameters: bool = False,
+        connection_invalidated: bool = False,
+        dialect: Optional["Dialect"] = None,
+        ismulti: Optional[bool] = None,
+    ) -> StatementError:
+        ...
+
+    @overload
+    @classmethod
+    def instance(
+        cls,
+        statement: str,
+        params: "_ExecuteParams",
+        orig: BaseException,
+        dbapi_base_err: Type[Exception],
+        hide_parameters: bool = False,
+        connection_invalidated: bool = False,
+        dialect: Optional["Dialect"] = None,
+        ismulti: Optional[bool] = None,
+    ) -> BaseException:
+        ...
+
+    @classmethod
+    def instance(
+        cls,
+        statement: str,
+        params: "_ExecuteParams",
+        orig: Union[BaseException, DontWrapMixin],
+        dbapi_base_err: Type[Exception],
+        hide_parameters: bool = False,
+        connection_invalidated: bool = False,
+        dialect: Optional["Dialect"] = None,
+        ismulti: Optional[bool] = None,
+    ) -> Union[BaseException, DontWrapMixin]:
         # Don't ever wrap these, just return them directly as if
         # DBAPIError didn't exist.
         if (
@@ -578,7 +661,7 @@ class DBAPIError(StatementError):
             ismulti=ismulti,
         )
 
-    def __reduce__(self):
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         return (
             self.__class__,
             (
@@ -595,13 +678,13 @@ class DBAPIError(StatementError):
 
     def __init__(
         self,
-        statement,
-        params,
-        orig,
-        hide_parameters=False,
-        connection_invalidated=False,
-        code=None,
-        ismulti=None,
+        statement: str,
+        params: "_ExecuteParams",
+        orig: BaseException,
+        hide_parameters: bool = False,
+        connection_invalidated: bool = False,
+        code: Optional[str] = None,
+        ismulti: Optional[bool] = None,
     ):
         try:
             text = str(orig)
@@ -684,7 +767,7 @@ class SATestSuiteWarning(Warning):
 class SADeprecationWarning(HasDescriptionCode, DeprecationWarning):
     """Issued for usage of deprecated APIs."""
 
-    deprecated_since = None
+    deprecated_since: Optional[str] = None
     "Indicates the version that started raising this deprecation warning"
 
 
@@ -700,10 +783,10 @@ class Base20DeprecationWarning(SADeprecationWarning):
 
     """
 
-    deprecated_since = "1.4"
+    deprecated_since: Optional[str] = "1.4"
     "Indicates the version that started raising this deprecation warning"
 
-    def __str__(self):
+    def __str__(self) -> str:
         return (
             super(Base20DeprecationWarning, self).__str__()
             + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)"
@@ -724,7 +807,7 @@ class SAPendingDeprecationWarning(PendingDeprecationWarning):
 
     """
 
-    deprecated_since = None
+    deprecated_since: Optional[str] = None
     "Indicates the version that started raising this deprecation warning"
 
 
index 7f9822d02e9e272a454a504bb49c3a6edba700d3..c6e9ca69af94d36a0120b680baa5bc41f9e81574 100644 (file)
@@ -28,15 +28,43 @@ tools which build on top of SQLAlchemy configurations to be constructed
 in a forwards-compatible way.
 
 """
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Generic
+from typing import overload
+from typing import Type
+from typing import TypeVar
+from typing import Union
 
 from . import exc
-from . import util
+from .util.typing import Literal
 
+_T = TypeVar("_T", bound=Any)
 
-_registrars = util.defaultdict(list)
+_registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {}
 
 
-def inspect(subject, raiseerr=True):
+class Inspectable(Generic[_T]):
+    """define a class as inspectable.
+
+    This allows typing to set up a linkage between an object that
+    can be inspected and the type of inspection it returns.
+
+    """
+
+
+@overload
+def inspect(subject: Inspectable[_T], raiseerr: bool = True) -> _T:
+    ...
+
+
+@overload
+def inspect(subject: Any, raiseerr: bool = True) -> Any:
+    ...
+
+
+def inspect(subject: Any, raiseerr: bool = True) -> Any:
     """Produce an inspection object for the given target.
 
     The returned value in some cases may be the
@@ -58,12 +86,14 @@ def inspect(subject, raiseerr=True):
     type_ = type(subject)
     for cls in type_.__mro__:
         if cls in _registrars:
-            reg = _registrars[cls]
-            if reg is True:
+            reg = _registrars.get(cls, None)
+            if reg is None:
+                continue
+            elif reg is True:
                 return subject
             ret = reg(subject)
             if ret is not None:
-                break
+                return ret
     else:
         reg = ret = None
 
@@ -75,8 +105,10 @@ def inspect(subject, raiseerr=True):
     return ret
 
 
-def _inspects(*types):
-    def decorate(fn_or_cls):
+def _inspects(
+    *types: type,
+) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
+    def decorate(fn_or_cls: Callable[[Any], Any]) -> Callable[[Any], Any]:
         for type_ in types:
             if type_ in _registrars:
                 raise AssertionError(
@@ -88,6 +120,8 @@ def _inspects(*types):
     return decorate
 
 
-def _self_inspects(cls):
-    _inspects(cls)(True)
+def _self_inspects(cls: Type[_T]) -> Type[_T]:
+    if cls in _registrars:
+        raise AssertionError("Type %s is already " "registered" % cls)
+    _registrars[cls] = True
     return cls
index 6431053a85d9967601a0fc80ad77ac50d89a68c3..e9ab8f423698a6f2132787f34785044266590f9e 100644 (file)
@@ -17,10 +17,21 @@ and :class:`_pool.Pool` objects, corresponds to a logger specific to that
 instance only.
 
 """
-
 import logging
 import sys
+from typing import Any
+from typing import Optional
+from typing import overload
+from typing import Set
+from typing import Type
+from typing import TypeVar
+from typing import Union
+
+from .util.typing import Literal
+
+_IT = TypeVar("_IT", bound="Identified")
 
+_EchoFlagType = Union[None, bool, Literal["debug"]]
 
 # set initial level to WARN.  This so that
 # log statements don't occur in the absence of explicit
@@ -30,7 +41,7 @@ if rootlogger.level == logging.NOTSET:
     rootlogger.setLevel(logging.WARN)
 
 
-def _add_default_handler(logger):
+def _add_default_handler(logger: logging.Logger) -> None:
     handler = logging.StreamHandler(sys.stdout)
     handler.setFormatter(
         logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
@@ -38,32 +49,40 @@ def _add_default_handler(logger):
     logger.addHandler(handler)
 
 
-_logged_classes = set()
+_logged_classes: Set[Type["Identified"]] = set()
 
 
-def _qual_logger_name_for_cls(cls):
+def _qual_logger_name_for_cls(cls: Type["Identified"]) -> str:
     return (
         getattr(cls, "_sqla_logger_namespace", None)
         or cls.__module__ + "." + cls.__name__
     )
 
 
-def class_logger(cls):
+def class_logger(cls: Type[_IT]) -> Type[_IT]:
     logger = logging.getLogger(_qual_logger_name_for_cls(cls))
-    cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
-    cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
+    cls._should_log_debug = lambda self: logger.isEnabledFor(  # type: ignore[assignment]  # noqa E501
+        logging.DEBUG
+    )
+    cls._should_log_info = lambda self: logger.isEnabledFor(  # type: ignore[assignment]  # noqa E501
+        logging.INFO
+    )
     cls.logger = logger
     _logged_classes.add(cls)
     return cls
 
 
 class Identified:
-    logging_name = None
+    logging_name: Optional[str] = None
 
-    def _should_log_debug(self):
+    logger: Union[logging.Logger, "InstanceLogger"]
+
+    _echo: _EchoFlagType
+
+    def _should_log_debug(self) -> bool:
         return self.logger.isEnabledFor(logging.DEBUG)
 
-    def _should_log_info(self):
+    def _should_log_info(self) -> bool:
         return self.logger.isEnabledFor(logging.INFO)
 
 
@@ -94,7 +113,9 @@ class InstanceLogger:
         "debug": logging.DEBUG,
     }
 
-    def __init__(self, echo, name):
+    _echo: _EchoFlagType
+
+    def __init__(self, echo: _EchoFlagType, name: str):
         self.echo = echo
         self.logger = logging.getLogger(name)
 
@@ -106,41 +127,41 @@ class InstanceLogger:
     #
     # Boilerplate convenience methods
     #
-    def debug(self, msg, *args, **kwargs):
+    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
         """Delegate a debug call to the underlying logger."""
 
         self.log(logging.DEBUG, msg, *args, **kwargs)
 
-    def info(self, msg, *args, **kwargs):
+    def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
         """Delegate an info call to the underlying logger."""
 
         self.log(logging.INFO, msg, *args, **kwargs)
 
-    def warning(self, msg, *args, **kwargs):
+    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
         """Delegate a warning call to the underlying logger."""
 
         self.log(logging.WARNING, msg, *args, **kwargs)
 
     warn = warning
 
-    def error(self, msg, *args, **kwargs):
+    def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
         """
         Delegate an error call to the underlying logger.
         """
         self.log(logging.ERROR, msg, *args, **kwargs)
 
-    def exception(self, msg, *args, **kwargs):
+    def exception(self, msg: str, *args: Any, **kwargs: Any) -> None:
         """Delegate an exception call to the underlying logger."""
 
         kwargs["exc_info"] = 1
         self.log(logging.ERROR, msg, *args, **kwargs)
 
-    def critical(self, msg, *args, **kwargs):
+    def critical(self, msg: str, *args: Any, **kwargs: Any) -> None:
         """Delegate a critical call to the underlying logger."""
 
         self.log(logging.CRITICAL, msg, *args, **kwargs)
 
-    def log(self, level, msg, *args, **kwargs):
+    def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None:
         """Delegate a log call to the underlying logger.
 
         The level here is determined by the echo
@@ -162,14 +183,14 @@ class InstanceLogger:
         if level >= selected_level:
             self.logger._log(level, msg, args, **kwargs)
 
-    def isEnabledFor(self, level):
+    def isEnabledFor(self, level: int) -> bool:
         """Is this logger enabled for level 'level'?"""
 
         if self.logger.manager.disable >= level:
             return False
         return level >= self.getEffectiveLevel()
 
-    def getEffectiveLevel(self):
+    def getEffectiveLevel(self) -> int:
         """What's the effective level for this logger?"""
 
         level = self._echo_map[self.echo]
@@ -178,7 +199,9 @@ class InstanceLogger:
         return level
 
 
-def instance_logger(instance, echoflag=None):
+def instance_logger(
+    instance: Identified, echoflag: _EchoFlagType = None
+) -> None:
     """create a logger for an instance that implements :class:`.Identified`."""
 
     if instance.logging_name:
@@ -191,6 +214,8 @@ def instance_logger(instance, echoflag=None):
 
     instance._echo = echoflag
 
+    logger: Union[logging.Logger, InstanceLogger]
+
     if echoflag in (False, None):
         # if no echo setting or False, return a Logger directly,
         # avoiding overhead of filtering
@@ -215,11 +240,25 @@ class echo_property:
     ``logging.DEBUG``.
     """
 
-    def __get__(self, instance, owner):
+    @overload
+    def __get__(
+        self, instance: "Literal[None]", owner: "echo_property"
+    ) -> "echo_property":
+        ...
+
+    @overload
+    def __get__(
+        self, instance: Identified, owner: "echo_property"
+    ) -> _EchoFlagType:
+        ...
+
+    def __get__(
+        self, instance: Optional[Identified], owner: "echo_property"
+    ) -> Union["echo_property", _EchoFlagType]:
         if instance is None:
             return self
         else:
             return instance._echo
 
-    def __set__(self, instance, value):
+    def __set__(self, instance: Identified, value: _EchoFlagType) -> None:
         instance_logger(instance, echoflag=value)
index 260ad1f9908a081aa2a4ed7dcdbdaf4ab41f35ed..75ce8216f60b0353564c104d8fa5defa30da03d3 100644 (file)
@@ -104,6 +104,7 @@ through the adapter, allowing for some very sophisticated behavior.
 """
 
 import operator
+import threading
 import weakref
 
 from sqlalchemy.util.compat import inspect_getfullargspec
@@ -122,7 +123,7 @@ __all__ = [
     "attribute_mapped_collection",
 ]
 
-__instrumentation_mutex = util.threading.Lock()
+__instrumentation_mutex = threading.Lock()
 
 
 class _PlainColumnGetter:
index 29a9c9edf737b4f11ffb3f4083e07a00c36b0f4e..59fabb9b6b7c2faae3fe34c43745bbbaa60fb98a 100644 (file)
@@ -42,6 +42,9 @@ from ..sql.selectable import FromClause
 from ..util import hybridmethod
 from ..util import hybridproperty
 
+if typing.TYPE_CHECKING:
+    from .state import InstanceState  # noqa
+
 _T = TypeVar("_T", bound=Any)
 
 
@@ -64,7 +67,9 @@ def has_inherited_table(cls):
     return False
 
 
-class DeclarativeAttributeIntercept(type):
+class DeclarativeAttributeIntercept(
+    type, inspection.Inspectable["Mapper[Any]"]
+):
     """Metaclass that may be used in conjunction with the
     :class:`_orm.DeclarativeBase` class to support addition of class
     attributes dynamically.
@@ -78,7 +83,7 @@ class DeclarativeAttributeIntercept(type):
         _del_attribute(cls, key)
 
 
-class DeclarativeMeta(type):
+class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]):
     def __init__(cls, classname, bases, dict_, **kw):
         # early-consume registry from the initial declarative base,
         # assign privately to not conflict with subclass attributes named
@@ -421,7 +426,7 @@ def _setup_declarative_base(cls):
         cls.metadata = cls.registry.metadata
 
 
-class DeclarativeBaseNoMeta:
+class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
     """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass
     to intercept new attributes.
 
@@ -451,7 +456,10 @@ class DeclarativeBaseNoMeta:
             cls._sa_registry.map_declaratively(cls)
 
 
-class DeclarativeBase(metaclass=DeclarativeAttributeIntercept):
+class DeclarativeBase(
+    inspection.Inspectable["InstanceState"],
+    metaclass=DeclarativeAttributeIntercept,
+):
     """Base class used for declarative class definitions.
 
     The :class:`_orm.DeclarativeBase` allows for the creation of new
index e9a89d102ba77704af2e452f461b775079415e9a..fdf065488a9e62cd06030fb8578a56157fb281e9 100644 (file)
@@ -18,6 +18,7 @@ from collections import deque
 from functools import reduce
 from itertools import chain
 import sys
+import threading
 from typing import Generic
 from typing import Type
 from typing import TypeVar
@@ -83,7 +84,7 @@ _already_compiling = False
 NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE")
 
 # lock used to synchronize the "mapper configure" step
-_CONFIGURE_MUTEX = util.threading.RLock()
+_CONFIGURE_MUTEX = threading.RLock()
 
 
 @inspection._self_inspects
@@ -93,6 +94,7 @@ class Mapper(
     ORMEntityColumnsClauseRole,
     sql_base.MemoizedHasCacheKey,
     InspectionAttr,
+    log.Identified,
     Generic[_MC],
 ):
     """Defines an association between a Python class and a database table or
@@ -2361,7 +2363,7 @@ class Mapper(
                 yield c
 
     @HasMemoized.memoized_attribute
-    def attrs(self):
+    def attrs(self) -> util.ImmutableProperties["MapperProperty"]:
         """A namespace of all :class:`.MapperProperty` objects
         associated this mapper.
 
index c7408b00b84675ea00fc51acf1c631492a1f94ac..7a422cd2ac97a0dedf80e028f55a0487b5cc28b2 100644 (file)
@@ -10,6 +10,7 @@
 
 """
 
+import threading
 import traceback
 import weakref
 
@@ -21,7 +22,6 @@ from .. import exc
 from .. import util
 from ..util import chop_traceback
 from ..util import queue as sqla_queue
-from ..util import threading
 
 
 class QueuePool(Pool):
index e9357bf7d835ed5afd9fc98b44806844f8e19556..594967a40b65fe67d99f7a8a3834b8b4015adb23 100644 (file)
@@ -5,8 +5,10 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+from typing import Dict
 
-class prefix_anon_map(dict):
+
+class prefix_anon_map(Dict[str, str]):
     """A map that creates new keys for missing key access.
 
     Considers keys of the form "<ident> <name>" to produce
@@ -27,7 +29,7 @@ class prefix_anon_map(dict):
         return value
 
 
-class cache_anon_map(dict):
+class cache_anon_map(Dict[int, str]):
     """A map that creates new keys for missing key access.
 
     Produces an incrementing sequence given a series of unique keys.
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
new file mode 100644 (file)
index 0000000..b5b0efb
--- /dev/null
@@ -0,0 +1,9 @@
+from typing import Any
+from typing import Mapping
+from typing import Sequence
+from typing import Union
+
+_SingleExecuteParams = Mapping[str, Any]
+_MultiExecuteParams = Sequence[_SingleExecuteParams]
+_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams]
+_ExecuteOptions = Mapping[str, Any]
index fa3bae83530d8cf9d2fcf2dd8c3d614ea6b1c349..c0de1902ff791723d3e9d7cc4cbfca1a749d3b04 100644 (file)
@@ -8,14 +8,20 @@
 """High level utilities which build upon other modules here.
 
 """
-
 from collections import deque
 from itertools import chain
+import typing
+from typing import Any
+from typing import cast
+from typing import Optional
 
 from . import coercions
 from . import operators
 from . import roles
 from . import visitors
+from ._typing import _ExecuteParams
+from ._typing import _MultiExecuteParams
+from ._typing import _SingleExecuteParams
 from .annotation import _deep_annotate  # noqa
 from .annotation import _deep_deannotate  # noqa
 from .annotation import _shallow_annotate  # noqa
@@ -45,6 +51,9 @@ from .selectable import TableClause
 from .. import exc
 from .. import util
 
+if typing.TYPE_CHECKING:
+    from ..engine.row import Row
+
 
 def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
     """Create a join condition between two tables or selectables.
@@ -488,13 +497,13 @@ def _quote_ddl_expr(element):
 
 
 class _repr_base:
-    _LIST = 0
-    _TUPLE = 1
-    _DICT = 2
+    _LIST: int = 0
+    _TUPLE: int = 1
+    _DICT: int = 2
 
     __slots__ = ("max_chars",)
 
-    def trunc(self, value):
+    def trunc(self, value: Any) -> str:
         rep = repr(value)
         lenrep = len(rep)
         if lenrep > self.max_chars:
@@ -515,11 +524,11 @@ class _repr_row(_repr_base):
 
     __slots__ = ("row",)
 
-    def __init__(self, row, max_chars=300):
+    def __init__(self, row: "Row", max_chars: int = 300):
         self.row = row
         self.max_chars = max_chars
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         trunc = self.trunc
         return "(%s%s)" % (
             ", ".join(trunc(value) for value in self.row),
@@ -537,13 +546,19 @@ class _repr_params(_repr_base):
 
     __slots__ = "params", "batches", "ismulti"
 
-    def __init__(self, params, batches, max_chars=300, ismulti=None):
-        self.params = params
+    def __init__(
+        self,
+        params: _ExecuteParams,
+        batches: int,
+        max_chars: int = 300,
+        ismulti: Optional[bool] = None,
+    ):
+        self.params: _ExecuteParams = params
         self.ismulti = ismulti
         self.batches = batches
         self.max_chars = max_chars
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if self.ismulti is None:
             return self.trunc(self.params)
 
@@ -557,23 +572,31 @@ class _repr_params(_repr_base):
         else:
             return self.trunc(self.params)
 
-        if self.ismulti and len(self.params) > self.batches:
-            msg = " ... displaying %i of %i total bound parameter sets ... "
-            return " ".join(
-                (
-                    self._repr_multi(self.params[: self.batches - 2], typ)[
-                        0:-1
-                    ],
-                    msg % (self.batches, len(self.params)),
-                    self._repr_multi(self.params[-2:], typ)[1:],
+        if self.ismulti:
+            multi_params = cast(_MultiExecuteParams, self.params)
+
+            if len(self.params) > self.batches:
+                msg = (
+                    " ... displaying %i of %i total bound parameter sets ... "
                 )
-            )
-        elif self.ismulti:
-            return self._repr_multi(self.params, typ)
+                return " ".join(
+                    (
+                        self._repr_multi(
+                            multi_params[: self.batches - 2],
+                            typ,
+                        )[0:-1],
+                        msg % (self.batches, len(self.params)),
+                        self._repr_multi(multi_params[-2:], typ)[1:],
+                    )
+                )
+            else:
+                return self._repr_multi(multi_params, typ)
         else:
-            return self._repr_params(self.params, typ)
+            return self._repr_params(
+                cast(_SingleExecuteParams, self.params), typ
+            )
 
-    def _repr_multi(self, multi_params, typ):
+    def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str:
         if multi_params:
             if isinstance(multi_params[0], list):
                 elem_type = self._LIST
@@ -597,7 +620,7 @@ class _repr_params(_repr_base):
         else:
             return "(%s)" % elements
 
-    def _repr_params(self, params, typ):
+    def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str:
         trunc = self.trunc
         if typ is self._DICT:
             return "{%s}" % (
index 203460c266ed788eb7774a24779222ae0a035f4f..91d15aae086ab31675d8432ec6f353848bdc4b29 100644 (file)
@@ -62,7 +62,6 @@ from .compat import osx
 from .compat import py38
 from .compat import py39
 from .compat import pypy
-from .compat import threading
 from .compat import win32
 from .concurrency import asyncio
 from .concurrency import await_fallback
index e53bc9c43a3e361ec0f6e37eb8fb5327ef09536a..3e4ef1310d6dd67dba19fb2dd0e6f69859098086 100644 (file)
@@ -8,26 +8,53 @@
 """Collection classes and helpers."""
 import collections.abc as collections_abc
 import operator
+import threading
 import types
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
+from typing import ValuesView
 import weakref
 
-from .compat import threading
+from ._has_cy import HAS_CYEXTENSION
+from .typing import Literal
 
-try:
-    from sqlalchemy.cyextension.immutabledict import ImmutableContainer
-    from sqlalchemy.cyextension.immutabledict import immutabledict
-    from sqlalchemy.cyextension.collections import IdentitySet
-    from sqlalchemy.cyextension.collections import OrderedSet
-    from sqlalchemy.cyextension.collections import unique_list  # noqa
-except ImportError:
+if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
     from ._py_collections import immutabledict
     from ._py_collections import IdentitySet
     from ._py_collections import ImmutableContainer
+    from ._py_collections import ImmutableDictBase
     from ._py_collections import OrderedSet
     from ._py_collections import unique_list  # noqa
+else:
+    from sqlalchemy.cyextension.immutabledict import ImmutableContainer
+    from sqlalchemy.cyextension.immutabledict import ImmutableDictBase
+    from sqlalchemy.cyextension.immutabledict import immutabledict
+    from sqlalchemy.cyextension.collections import IdentitySet
+    from sqlalchemy.cyextension.collections import OrderedSet
+    from sqlalchemy.cyextension.collections import unique_list  # noqa
+
 
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
 
-EMPTY_SET = frozenset()
+
+EMPTY_SET: FrozenSet[Any] = frozenset()
 
 
 def coerce_to_immutabledict(d):
@@ -39,14 +66,12 @@ def coerce_to_immutabledict(d):
         return immutabledict(d)
 
 
-EMPTY_DICT = immutabledict()
+EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
 
 
-class FacadeDict(ImmutableContainer, dict):
+class FacadeDict(ImmutableDictBase[Any, Any]):
     """A dictionary that is not publicly mutable."""
 
-    clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
-
     def __new__(cls, *args):
         new = dict.__new__(cls)
         return new
@@ -68,18 +93,23 @@ class FacadeDict(ImmutableContainer, dict):
         return "FacadeDict(%s)" % dict.__repr__(self)
 
 
-class Properties:
+_DT = TypeVar("_DT", bound=Any)
+
+
+class Properties(Generic[_T]):
     """Provide a __getattr__/__setattr__ interface over a dict."""
 
     __slots__ = ("_data",)
 
+    _data: Dict[str, _T]
+
     def __init__(self, data):
         object.__setattr__(self, "_data", data)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._data)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_T]:
         return iter(list(self._data.values()))
 
     def __dir__(self):
@@ -93,7 +123,7 @@ class Properties:
     def __setitem__(self, key, obj):
         self._data[key] = obj
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: str) -> _T:
         return self._data[key]
 
     def __delitem__(self, key):
@@ -108,16 +138,16 @@ class Properties:
     def __setstate__(self, state):
         object.__setattr__(self, "_data", state["_data"])
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> _T:
         try:
             return self._data[key]
         except KeyError:
             raise AttributeError(key)
 
-    def __contains__(self, key):
+    def __contains__(self, key: str) -> bool:
         return key in self._data
 
-    def as_immutable(self):
+    def as_immutable(self) -> "ImmutableProperties[_T]":
         """Return an immutable proxy for this :class:`.Properties`."""
 
         return ImmutableProperties(self._data)
@@ -125,29 +155,39 @@ class Properties:
     def update(self, value):
         self._data.update(value)
 
-    def get(self, key, default=None):
+    @overload
+    def get(self, key: str) -> Optional[_T]:
+        ...
+
+    @overload
+    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]:
+        ...
+
+    def get(
+        self, key: str, default: Optional[Union[_DT, _T]] = None
+    ) -> Optional[Union[_T, _DT]]:
         if key in self:
             return self[key]
         else:
             return default
 
-    def keys(self):
+    def keys(self) -> List[str]:
         return list(self._data)
 
-    def values(self):
+    def values(self) -> List[_T]:
         return list(self._data.values())
 
-    def items(self):
+    def items(self) -> List[Tuple[str, _T]]:
         return list(self._data.items())
 
-    def has_key(self, key):
+    def has_key(self, key: str) -> bool:
         return key in self._data
 
     def clear(self):
         self._data.clear()
 
 
-class OrderedProperties(Properties):
+class OrderedProperties(Properties[_T]):
     """Provide a __getattr__/__setattr__ interface with an OrderedDict
     as backing store."""
 
@@ -157,7 +197,7 @@ class OrderedProperties(Properties):
         Properties.__init__(self, OrderedDict())
 
 
-class ImmutableProperties(ImmutableContainer, Properties):
+class ImmutableProperties(ImmutableContainer, Properties[_T]):
     """Provide immutable dict/object attribute to an underlying dictionary."""
 
     __slots__ = ()
@@ -220,7 +260,7 @@ class OrderedIdentitySet(IdentitySet):
                 self.add(o)
 
 
-class PopulateDict(dict):
+class PopulateDict(Dict[_KT, _VT]):
     """A dict which populates missing values via a creation function.
 
     Note the creation function takes a key, unlike
@@ -228,26 +268,26 @@ class PopulateDict(dict):
 
     """
 
-    def __init__(self, creator):
+    def __init__(self, creator: Callable[[_KT], _VT]):
         self.creator = creator
 
-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         self[key] = val = self.creator(key)
         return val
 
 
-class WeakPopulateDict(dict):
+class WeakPopulateDict(Dict[_KT, _VT]):
     """Like PopulateDict, but assumes a self + a method and does not create
     a reference cycle.
 
     """
 
-    def __init__(self, creator_method):
+    def __init__(self, creator_method: types.MethodType):
         self.creator = creator_method.__func__
         weakself = creator_method.__self__
         self.weakself = weakref.ref(weakself)
 
-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         self[key] = val = self.creator(self.weakself(), key)
         return val
 
@@ -261,37 +301,40 @@ column_dict = dict
 ordered_column_set = OrderedSet
 
 
-_getters = PopulateDict(operator.itemgetter)
-
-_property_getters = PopulateDict(
-    lambda idx: property(operator.itemgetter(idx))
-)
-
-
-class UniqueAppender:
+class UniqueAppender(Generic[_T]):
     """Appends items to a collection ensuring uniqueness.
 
     Additional appends() of the same object are ignored.  Membership is
     determined by identity (``is a``) not equality (``==``).
     """
 
-    def __init__(self, data, via=None):
+    __slots__ = "data", "_data_appender", "_unique"
+
+    data: Union[Iterable[_T], Set[_T], List[_T]]
+    _data_appender: Callable[[_T], None]
+    _unique: Dict[int, Literal[True]]
+
+    def __init__(
+        self,
+        data: Union[Iterable[_T], Set[_T], List[_T]],
+        via: Optional[str] = None,
+    ):
         self.data = data
         self._unique = {}
         if via:
-            self._data_appender = getattr(data, via)
+            self._data_appender = getattr(data, via)  # type: ignore[assignment]  # noqa E501
         elif hasattr(data, "append"):
-            self._data_appender = data.append
+            self._data_appender = cast("List[_T]", data).append  # type: ignore[assignment]  # noqa E501
         elif hasattr(data, "add"):
-            self._data_appender = data.add
+            self._data_appender = cast("Set[_T]", data).add  # type: ignore[assignment]  # noqa E501
 
-    def append(self, item):
+    def append(self, item: _T) -> None:
         id_ = id(item)
         if id_ not in self._unique:
-            self._data_appender(item)
+            self._data_appender(item)  # type: ignore[call-arg]
             self._unique[id_] = True
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_T]:
         return iter(self.data)
 
 
@@ -302,13 +345,27 @@ def coerce_generator_arg(arg):
         return arg
 
 
-def to_list(x, default=None):
+@overload
+def to_list(x: Sequence[_T], default: Optional[List[_T]] = None) -> List[_T]:
+    ...
+
+
+@overload
+def to_list(
+    x: Optional[Sequence[_T]], default: Optional[List[_T]] = None
+) -> Optional[List[_T]]:
+    ...
+
+
+def to_list(
+    x: Optional[Sequence[_T]], default: Optional[List[_T]] = None
+) -> Optional[List[_T]]:
     if x is None:
         return default
     if not isinstance(x, collections_abc.Iterable) or isinstance(
         x, (str, bytes)
     ):
-        return [x]
+        return [cast(_T, x)]
     elif isinstance(x, list):
         return x
     else:
@@ -367,7 +424,7 @@ def flatten_iterator(x):
             yield elem
 
 
-class LRUCache(dict):
+class LRUCache(typing.MutableMapping[_KT, _VT]):
     """Dictionary with 'squishy' removal of least
     recently used items.
 
@@ -377,7 +434,18 @@ class LRUCache(dict):
 
     """
 
-    __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"
+    __slots__ = (
+        "capacity",
+        "threshold",
+        "size_alert",
+        "_data",
+        "_counter",
+        "_mutex",
+    )
+
+    capacity: int
+    threshold: float
+    size_alert: Callable[["LRUCache[_KT, _VT]"], None]
 
     def __init__(self, capacity=100, threshold=0.5, size_alert=None):
         self.capacity = capacity
@@ -385,48 +453,56 @@ class LRUCache(dict):
         self.size_alert = size_alert
         self._counter = 0
         self._mutex = threading.Lock()
+        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}
 
     def _inc_counter(self):
         self._counter += 1
         return self._counter
 
-    def get(self, key, default=None):
-        item = dict.get(self, key, default)
-        if item is not default:
-            item[2] = self._inc_counter()
+    @overload
+    def get(self, key: _KT) -> Optional[_VT]:
+        ...
+
+    @overload
+    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]:
+        ...
+
+    def get(
+        self, key: _KT, default: Optional[Union[_VT, _T]] = None
+    ) -> Optional[Union[_VT, _T]]:
+        item = self._data.get(key, default)
+        if item is not default and item is not None:
+            item[2][0] = self._inc_counter()
             return item[1]
         else:
             return default
 
-    def __getitem__(self, key):
-        item = dict.__getitem__(self, key)
-        item[2] = self._inc_counter()
+    def __getitem__(self, key: _KT) -> _VT:
+        item = self._data[key]
+        item[2][0] = self._inc_counter()
         return item[1]
 
-    def values(self):
-        return [i[1] for i in dict.values(self)]
+    def __iter__(self) -> Iterator[_KT]:
+        return iter(self._data)
 
-    def setdefault(self, key, value):
-        if key in self:
-            return self[key]
-        else:
-            self[key] = value
-            return value
-
-    def __setitem__(self, key, value):
-        item = dict.get(self, key)
-        if item is None:
-            item = [key, value, self._inc_counter()]
-            dict.__setitem__(self, key, item)
-        else:
-            item[1] = value
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def values(self) -> ValuesView[_VT]:
+        return typing.ValuesView({k: i[1] for k, i in self._data.items()})
+
+    def __setitem__(self, key: _KT, value: _VT) -> None:
+        self._data[key] = (key, value, [self._inc_counter()])
         self._manage_size()
 
+    def __delitem__(self, __v: _KT) -> None:
+        del self._data[__v]
+
     @property
-    def size_threshold(self):
+    def size_threshold(self) -> float:
         return self.capacity + self.capacity * self.threshold
 
-    def _manage_size(self):
+    def _manage_size(self) -> None:
         if not self._mutex.acquire(False):
             return
         try:
@@ -434,13 +510,15 @@ class LRUCache(dict):
             while len(self) > self.capacity + self.capacity * self.threshold:
                 if size_alert:
                     size_alert = False
-                    self.size_alert(self)
+                    self.size_alert(self)  # type: ignore
                 by_counter = sorted(
-                    dict.values(self), key=operator.itemgetter(2), reverse=True
+                    self._data.values(),
+                    key=operator.itemgetter(2),
+                    reverse=True,
                 )
                 for item in by_counter[self.capacity :]:
                     try:
-                        del self[item[0]]
+                        del self._data[item[0]]
                     except KeyError:
                         # deleted elsewhere; skip
                         continue
@@ -463,6 +541,8 @@ class ScopedRegistry:
       a callable that will return a key to store/retrieve an object.
     """
 
+    __slots__ = "createfunc", "scopefunc", "registry"
+
     def __init__(self, createfunc, scopefunc):
         """Construct a new :class:`.ScopedRegistry`.
 
@@ -529,7 +609,7 @@ class ThreadLocalRegistry(ScopedRegistry):
 
     def clear(self):
         try:
-            del self.registry.value
+            del self.registry.value  # type: ignore
         except AttributeError:
             pass
 
index ac678f8a983e0aa0302133fa6e5e4e314a462cbe..b9e58e68cd807ba67f7e9c3e81e895398565128e 100644 (file)
--- a/lib/sqlalchemy/util/_concurrency_py3k.py
+++ b/lib/sqlalchemy/util/_concurrency_py3k.py
@@ -8,23 +8,25 @@
 import asyncio
 from contextvars import copy_context as _copy_context
 import sys
+import typing
 from typing import Any
 from typing import Callable
 from typing import Coroutine
 
-import greenlet
+import greenlet  # type: ignore # noqa
 
 from .langhelpers import memoized_property
 from .. import exc
 
-try:
+if not typing.TYPE_CHECKING:
+    try:
 
-    # If greenlet.gr_context is present in current version of greenlet,
-    # it will be set with a copy of the current context on creation.
-    # Refs: https://github.com/python-greenlet/greenlet/pull/198
-    getattr(greenlet.greenlet, "gr_context")
-except (ImportError, AttributeError):
-    _copy_context = None  # noqa
+        # If greenlet.gr_context is present in current version of greenlet,
+        # it will be set with a copy of the current context on creation.
+        # Refs: https://github.com/python-greenlet/greenlet/pull/198
+        getattr(greenlet.greenlet, "gr_context")
+    except (ImportError, AttributeError):
+        _copy_context = None  # noqa
 
 
 def is_exit_exception(e):
@@ -40,7 +42,7 @@ def is_exit_exception(e):
 # Issue for context: https://github.com/python-greenlet/greenlet/issues/173
 
 
-class _AsyncIoGreenlet(greenlet.greenlet):
+class _AsyncIoGreenlet(greenlet.greenlet):  # type: ignore
     def __init__(self, fn, driver):
         greenlet.greenlet.__init__(self, fn, driver)
         self.driver = driver
@@ -48,7 +50,7 @@ class _AsyncIoGreenlet(greenlet.greenlet):
             self.gr_context = _copy_context()
 
 
-def await_only(awaitable: Coroutine) -> Any:
+def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
@@ -72,7 +74,7 @@ def await_only(awaitable: Coroutine) -> Any:
     return current.driver.switch(awaitable)
 
 
-def await_fallback(awaitable: Coroutine) -> Any:
+def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
@@ -97,7 +99,10 @@ def await_fallback(awaitable: Coroutine) -> Any:
 
 
 async def greenlet_spawn(
-    fn: Callable, *args, _require_await=False, **kwargs
+    fn: Callable[..., Any],
+    *args: Any,
+    _require_await: bool = False,
+    **kwargs: Any,
 ) -> Any:
     """Runs a sync function ``fn`` in a new greenlet.
 
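
For reference, the now fully annotated ``greenlet_spawn()`` / ``await_only()`` pair is used as in the following sketch; this is illustrative only, assumes greenlet is installed, and assumes both names remain re-exported from ``sqlalchemy.util``:

    import asyncio

    from sqlalchemy.util import await_only, greenlet_spawn

    async def fetch_value() -> int:
        await asyncio.sleep(0)
        return 42

    def sync_facade() -> int:
        # legal only while running inside a greenlet_spawn() context:
        # the coroutine is handed to the calling event loop to be awaited
        return await_only(fetch_value())

    async def main() -> None:
        print(await greenlet_spawn(sync_facade))  # 42

    asyncio.run(main())
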
index 9448ed33de71c01196e5e0aefdaed08ea3331fa4..b0f8ab444ae26dc845d9f27d69a18e556e1459d1 100644 (file)
--- a/lib/sqlalchemy/util/_preloaded.py
+++ b/lib/sqlalchemy/util/_preloaded.py
@@ -9,8 +9,14 @@
 runtime.
 
 """
-
 import sys
+from types import ModuleType
+import typing
+from typing import Any
+from typing import Callable
+from typing import TypeVar
+
+_FN = TypeVar("_FN", bound=Callable[..., Any])
 
 
 class _ModuleRegistry:
@@ -37,7 +43,7 @@ class _ModuleRegistry:
         self.module_registry = set()
         self.prefix = prefix
 
-    def preload_module(self, *deps):
+    def preload_module(self, *deps: str) -> Callable[[_FN], _FN]:
         """Adds the specified modules to the list to load.
 
         This method can be used both as a normal function and as a decorator.
@@ -46,7 +52,7 @@ class _ModuleRegistry:
         self.module_registry.update(deps)
         return lambda fn: fn
 
-    def import_prefix(self, path):
+    def import_prefix(self, path: str) -> None:
         """Resolve all the modules in the registry that start with the
         specified path.
         """
@@ -61,6 +67,11 @@ class _ModuleRegistry:
                 __import__(module, globals(), locals())
                 self.__dict__[key] = sys.modules[module]
 
+    if typing.TYPE_CHECKING:
+
+        def __getattr__(self, key: str) -> ModuleType:
+            ...
+
 
 preloaded = _ModuleRegistry()
 preload_module = preloaded.preload_module
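
The ``TYPE_CHECKING``-only ``__getattr__`` above is what lets call sites treat arbitrary attributes of the registry as ``ModuleType`` without needing stubs. A standalone sketch of the same convention (hypothetical ``LazyModules`` class, not SQLAlchemy code):

    import sys
    import typing
    from types import ModuleType
    from typing import Set

    class LazyModules:
        def __init__(self) -> None:
            self._names: Set[str] = set()

        def register(self, name: str) -> None:
            self._names.add(name)

        def resolve(self) -> None:
            for name in self._names:
                __import__(name)
                # flatten dotted paths into attribute names, "os.path" -> "os_path"
                setattr(self, name.replace(".", "_"), sys.modules[name])

        if typing.TYPE_CHECKING:
            # seen only by the type checker: every attribute looks like a
            # ModuleType, so registry.json.dumps(...) type-checks cleanly
            def __getattr__(self, key: str) -> ModuleType:
                ...

    registry = LazyModules()
    registry.register("json")
    registry.resolve()
    print(registry.json.dumps({"typed": True}))
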
index ff61f6ca9004d7f5ac8f7864921147c9fc1e4c71..a4e4b8b5db29f9dea2c007e03bd01d3cc9deaff2 100644 (file)
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -1,17 +1,52 @@
 from itertools import filterfalse
+from typing import Any
+from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import TypeVar
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
 
 
 class ImmutableContainer:
-    def _immutable(self, *arg, **kw):
+    def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
         raise TypeError("%s object is immutable" % self.__class__.__name__)
 
-    __delitem__ = __setitem__ = __setattr__ = _immutable
+    def __delitem__(self, key: Any) -> NoReturn:
+        self._immutable()
 
+    def __setitem__(self, key: Any, value: Any) -> NoReturn:
+        self._immutable()
 
-class immutabledict(ImmutableContainer, dict):
+    def __setattr__(self, key: str, value: Any) -> NoReturn:
+        self._immutable()
 
-    clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
 
+class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]):
+    def clear(self) -> NoReturn:
+        self._immutable()
+
+    def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+        self._immutable()
+
+    def popitem(self) -> NoReturn:
+        self._immutable()
+
+    def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+        self._immutable()
+
+    def update(self, *arg: Any, **kw: Any) -> NoReturn:
+        self._immutable()
+
+
+class immutabledict(ImmutableDictBase[_KT, _VT]):
     def __new__(cls, *args):
         new = dict.__new__(cls)
         dict.__init__(new, *args)
@@ -41,7 +76,7 @@ class immutabledict(ImmutableContainer, dict):
         dict.__init__(new, self)
         if __d:
             dict.update(new, __d)
-        dict.update(new, kw)
+        dict.update(new, kw)  # type: ignore
         return new
 
     def merge_with(self, *dicts):
@@ -61,110 +96,145 @@ class immutabledict(ImmutableContainer, dict):
         return "immutabledict(%s)" % dict.__repr__(self)
 
 
-class OrderedSet(set):
+class OrderedSet(Generic[_T]):
+    __slots__ = ("_list", "_set", "__weakref__")
+
+    _list: List[_T]
+    _set: Set[_T]
+
     def __init__(self, d=None):
-        set.__init__(self)
         if d is not None:
             self._list = unique_list(d)
-            set.update(self, self._list)
+            self._set = set(self._list)
         else:
             self._list = []
+            self._set = set()
+
+    def __reduce__(self):
+        return (OrderedSet, (self._list,))
 
-    def add(self, element):
+    def add(self, element: _T) -> None:
         if element not in self:
             self._list.append(element)
-        set.add(self, element)
+        self._set.add(element)
 
-    def remove(self, element):
-        set.remove(self, element)
+    def remove(self, element: _T) -> None:
+        self._set.remove(element)
         self._list.remove(element)
 
-    def insert(self, pos, element):
+    def insert(self, pos: int, element: _T) -> None:
         if element not in self:
             self._list.insert(pos, element)
-        set.add(self, element)
+        self._set.add(element)
 
-    def discard(self, element):
+    def discard(self, element: _T) -> None:
         if element in self:
             self._list.remove(element)
-            set.remove(self, element)
+            self._set.remove(element)
 
-    def clear(self):
-        set.clear(self)
+    def clear(self) -> None:
+        self._set.clear()
         self._list = []
 
-    def __getitem__(self, key):
+    def __len__(self) -> int:
+        return len(self._set)
+
+    def __eq__(self, other):
+        if not isinstance(other, OrderedSet):
+            return self._set == other
+        else:
+            return self._set == other._set
+
+    def __ne__(self, other):
+        if not isinstance(other, OrderedSet):
+            return self._set != other
+        else:
+            return self._set != other._set
+
+    def __contains__(self, element: Any) -> bool:
+        return element in self._set
+
+    def __getitem__(self, key: int) -> _T:
         return self._list[key]
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_T]:
         return iter(self._list)
 
-    def __add__(self, other):
+    def __add__(self, other: Iterator[_T]) -> "OrderedSet[_T]":
         return self.union(other)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(%r)" % (self.__class__.__name__, self._list)
 
     __str__ = __repr__
 
-    def update(self, iterable):
-        for e in iterable:
-            if e not in self:
-                self._list.append(e)
-                set.add(self, e)
-        return self
+    def update(self, *iterables: Iterable[_T]) -> None:
+        for iterable in iterables:
+            for e in iterable:
+                if e not in self:
+                    self._list.append(e)
+                    self._set.add(e)
 
-    __ior__ = update
+    def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        self.update(other)
+        return self
 
-    def union(self, other):
+    def union(self, other: Iterable[_T]) -> "OrderedSet[_T]":
         result = self.__class__(self)
         result.update(other)
         return result
 
-    __or__ = union
+    def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        return self.union(other)
 
-    def intersection(self, other):
+    def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]":
         other = other if isinstance(other, set) else set(other)
         return self.__class__(a for a in self if a in other)
 
-    __and__ = intersection
+    def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        return self.intersection(other)
 
-    def symmetric_difference(self, other):
+    def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
         other_set = other if isinstance(other, set) else set(other)
         result = self.__class__(a for a in self if a not in other_set)
         result.update(a for a in other if a not in self)
         return result
 
-    __xor__ = symmetric_difference
+    def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        return self.symmetric_difference(other)
 
-    def difference(self, other):
+    def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
         other = other if isinstance(other, set) else set(other)
         return self.__class__(a for a in self if a not in other)
 
-    __sub__ = difference
+    def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        return self.difference(other)
 
-    def intersection_update(self, other):
+    def intersection_update(self, other: Iterable[_T]) -> None:
         other = other if isinstance(other, set) else set(other)
-        set.intersection_update(self, other)
+        self._set.intersection_update(other)
         self._list = [a for a in self._list if a in other]
-        return self
 
-    __iand__ = intersection_update
+    def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        self.intersection_update(other)
+        return self
 
-    def symmetric_difference_update(self, other):
-        set.symmetric_difference_update(self, other)
+    def symmetric_difference_update(self, other: Iterable[_T]) -> None:
+        self._set.symmetric_difference_update(other)
         self._list = [a for a in self._list if a in self]
         self._list += [a for a in other if a in self]
-        return self
 
-    __ixor__ = symmetric_difference_update
+    def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        self.symmetric_difference_update(other)
+        return self
 
-    def difference_update(self, other):
-        set.difference_update(self, other)
+    def difference_update(self, other: Iterable[_T]) -> None:
+        self._set.difference_update(other)
         self._list = [a for a in self._list if a in self]
-        return self
 
-    __isub__ = difference_update
+    def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+        self.difference_update(other)
+        return self
 
 
 class IdentitySet:
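
Behavioral summary of the ``OrderedSet`` rewrite above: it no longer subclasses ``set`` but composes a backing ``_list`` (insertion order) with a ``_set`` (membership and equality), so comparisons ignore order while iteration preserves it. A small usage sketch, assuming ``OrderedSet`` remains importable from ``sqlalchemy.util``:

    from sqlalchemy.util import OrderedSet

    o = OrderedSet([3, 2, 4])
    o.add(5)
    print(list(o))             # [3, 2, 4, 5], _list keeps insertion order
    print(o == {2, 3, 4, 5})   # True, __eq__ compares only the backing _set
    print(list(o.union([6])))  # [3, 2, 4, 5, 6], set-style ops return OrderedSet
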
index 679df73c706aa163743bbbde449ef6acc0b6cbb9..0f4befbb1f92fea1997aecf15a11a5ae166ada49 100644 (file)
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -6,14 +6,23 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
 """Handle Python version/platform incompatibilities."""
+from __future__ import annotations
+
 import base64
-import collections
 import dataclasses
 import inspect
 import operator
 import platform
 import sys
 import typing
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
 
 
 py311 = sys.version_info >= (3, 11)
@@ -32,27 +41,18 @@ has_refcount_gc = bool(cpython)
 dottedgetter = operator.attrgetter
 next = next  # noqa
 
-FullArgSpec = collections.namedtuple(
-    "FullArgSpec",
-    [
-        "args",
-        "varargs",
-        "varkw",
-        "defaults",
-        "kwonlyargs",
-        "kwonlydefaults",
-        "annotations",
-    ],
-)
-
 
-try:
-    import threading
-except ImportError:
-    import dummy_threading as threading  # noqa
+class FullArgSpec(typing.NamedTuple):
+    args: List[str]
+    varargs: Optional[str]
+    varkw: Optional[str]
+    defaults: Optional[Tuple[Any, ...]]
+    kwonlyargs: List[str]
+    kwonlydefaults: Dict[str, Any]
+    annotations: Dict[str, Any]
 
 
-def inspect_getfullargspec(func):
+def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec:
     """Fully vendored version of getfullargspec from Python 3.3."""
 
     if inspect.ismethod(func):
@@ -90,13 +90,13 @@ def inspect_getfullargspec(func):
     )
 
 
-if py38:
+if typing.TYPE_CHECKING or py38:
     from importlib import metadata as importlib_metadata
 else:
     import importlib_metadata  # noqa
 
 
-if py39:
+if typing.TYPE_CHECKING or py39:
     # pep 584 dict union
     dict_union = operator.or_  # noqa
 else:
@@ -109,7 +109,7 @@ else:
 
 def importlib_metadata_get(group):
     ep = importlib_metadata.entry_points()
-    if hasattr(ep, "select"):
+    if not typing.TYPE_CHECKING and hasattr(ep, "select"):
         return ep.select(group=group)
     else:
         return ep.get(group, ())
@@ -119,15 +119,15 @@ def b(s):
     return s.encode("latin-1")
 
 
-def b64decode(x):
+def b64decode(x: str) -> bytes:
     return base64.b64decode(x.encode("ascii"))
 
 
-def b64encode(x):
+def b64encode(x: bytes) -> str:
     return base64.b64encode(x).decode("ascii")
 
 
-def decode_backslashreplace(text, encoding):
+def decode_backslashreplace(text: bytes, encoding: str) -> str:
     return text.decode(encoding, errors="backslashreplace")
 
 
@@ -150,20 +150,20 @@ def _formatannotation(annotation, base_module=None):
 
 
 def inspect_formatargspec(
-    args,
-    varargs=None,
-    varkw=None,
-    defaults=None,
-    kwonlyargs=(),
-    kwonlydefaults={},
-    annotations={},
-    formatarg=str,
-    formatvarargs=lambda name: "*" + name,
-    formatvarkw=lambda name: "**" + name,
-    formatvalue=lambda value: "=" + repr(value),
-    formatreturns=lambda text: " -> " + text,
-    formatannotation=_formatannotation,
-):
+    args: List[str],
+    varargs: Optional[str] = None,
+    varkw: Optional[str] = None,
+    defaults: Optional[Sequence[Any]] = None,
+    kwonlyargs: Optional[Sequence[str]] = (),
+    kwonlydefaults: Optional[Mapping[str, Any]] = {},
+    annotations: Mapping[str, Any] = {},
+    formatarg: Callable[[str], str] = str,
+    formatvarargs: Callable[[str], str] = lambda name: "*" + name,
+    formatvarkw: Callable[[str], str] = lambda name: "**" + name,
+    formatvalue: Callable[[Any], str] = lambda value: "=" + repr(value),
+    formatreturns: Callable[[Any], str] = lambda text: " -> " + str(text),
+    formatannotation: Callable[[Any], str] = _formatannotation,
+) -> str:
     """Copy formatargspec from python 3.7 standard library.
 
     Python 3 has deprecated formatargspec and requested that Signature
@@ -190,6 +190,9 @@ def inspect_formatargspec(
     specs = []
     if defaults:
         firstdefault = len(args) - len(defaults)
+    else:
+        firstdefault = -1
+
     for i, arg in enumerate(args):
         spec = formatargandannotation(arg)
         if defaults and i >= firstdefault:
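
With ``FullArgSpec`` now a ``typing.NamedTuple``, the vendored ``inspect_getfullargspec()`` keeps tuple-style indexing (``spec[0]`` and friends, as consumed by ``format_argspec_plus``) while gaining typed attribute access. A quick illustration, assuming the function stays importable from ``sqlalchemy.util.compat``:

    from sqlalchemy.util.compat import inspect_getfullargspec

    def example(a, b=2, *args, flag=False, **kw):
        return a

    spec = inspect_getfullargspec(example)
    print(spec.args)                 # ['a', 'b']
    print(spec[0])                   # same list, via NamedTuple indexing
    print(spec.varargs, spec.varkw)  # args kw
    print(spec.kwonlydefaults)       # {'flag': False}
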
index e5183a542a310b0f604d708a51c925a4a719cffe..57ef2300622a9a358244c9cb67d10c1db673bcd6 100644 (file)
--- a/lib/sqlalchemy/util/concurrency.py
+++ b/lib/sqlalchemy/util/concurrency.py
@@ -5,11 +5,12 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+import asyncio  # noqa
 
 have_greenlet = False
 greenlet_error = None
 try:
-    import greenlet  # noqa F401
+    import greenlet  # type: ignore # noqa F401
 except ImportError as e:
     greenlet_error = str(e)
     pass
@@ -24,12 +25,9 @@ else:
     from ._concurrency_py3k import (
         _util_async_run_coroutine_function,
     )  # noqa F401, E501
-    from ._concurrency_py3k import asyncio  # noqa F401
 
 if not have_greenlet:
 
-    asyncio = None  # noqa F811
-
     def _not_implemented():
         # this conditional is to prevent pylance from considering
         # greenlet_spawn() etc as "no return" and dimming out code below it
@@ -46,20 +44,20 @@ if not have_greenlet:
     def is_exit_exception(e):  # noqa F811
         return not isinstance(e, Exception)
 
-    def await_only(thing):  # noqa F811
+    def await_only(thing):  # type: ignore # noqa F811
         _not_implemented()
 
-    def await_fallback(thing):  # noqa F81
+    def await_fallback(thing):  # type: ignore # noqa F81
         return thing
 
-    def greenlet_spawn(fn, *args, **kw):  # noqa F81
+    def greenlet_spawn(fn, *args, **kw):  # type: ignore # noqa F81
         _not_implemented()
 
-    def AsyncAdaptedLock(*args, **kw):  # noqa F81
+    def AsyncAdaptedLock(*args, **kw):  # type: ignore # noqa F81
         _not_implemented()
 
-    def _util_async_run(fn, *arg, **kw):  # noqa F81
+    def _util_async_run(fn, *arg, **kw):  # type: ignore # noqa F81
         return fn(*arg, **kw)
 
-    def _util_async_run_coroutine_function(fn, *arg, **kw):  # noqa F81
+    def _util_async_run_coroutine_function(fn, *arg, **kw):  # type: ignore # noqa F81
         _not_implemented()
index e5d5d5461966ef0a3575fba87aad508d5ef6e029..565cbafe26840d42ca0dc874fc42681d49696e02 100644 (file)
--- a/lib/sqlalchemy/util/deprecations.py
+++ b/lib/sqlalchemy/util/deprecations.py
@@ -11,6 +11,8 @@ functionality."""
 import re
 from typing import Any
 from typing import Callable
+from typing import cast
+from typing import Optional
 from typing import TypeVar
 
 from . import compat
@@ -67,11 +69,11 @@ def deprecated_cls(version, message, constructor="__init__"):
 
 
 def deprecated_property(
-    version,
-    message=None,
-    add_deprecation_to_docstring=True,
-    warning=None,
-    enable_warnings=True,
+    version: str,
+    message: Optional[str] = None,
+    add_deprecation_to_docstring: bool = True,
+    warning: Optional[str] = None,
+    enable_warnings: bool = True,
 ) -> Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]]:
     """the @deprecated decorator with a @property.
 
@@ -99,14 +101,17 @@ def deprecated_property(
     great!   now it is.
 
     """
-    return lambda fn: property(
-        deprecated(
-            version,
-            message=message,
-            add_deprecation_to_docstring=add_deprecation_to_docstring,
-            warning=warning,
-            enable_warnings=enable_warnings,
-        )(fn)
+    return cast(
+        Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]],
+        lambda fn: property(
+            deprecated(
+                version,
+                message=message,
+                add_deprecation_to_docstring=add_deprecation_to_docstring,
+                warning=warning,
+                enable_warnings=enable_warnings,
+            )(fn)
+        ),
     )
 
 
@@ -325,11 +330,12 @@ def _decorate_cls_with_warning(
             )
         doc = inject_docstring_text(doc, docstring_header, 1)
 
+        constructor_fn = None
         if type(cls) is type:
             clsdict = dict(cls.__dict__)
             clsdict["__doc__"] = doc
             clsdict.pop("__dict__", None)
-            cls = type(cls.__name__, cls.__bases__, clsdict)
+            cls = type(cls.__name__, cls.__bases__, clsdict)  # type: ignore
             if constructor is not None:
                 constructor_fn = clsdict[constructor]
 
@@ -339,6 +345,7 @@ def _decorate_cls_with_warning(
                 constructor_fn = getattr(cls, constructor)
 
         if constructor is not None:
+            assert constructor_fn is not None
             setattr(
                 cls,
                 constructor,
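
The ``cast()`` above exists because mypy infers the lambda's result as a plain ``property``; casting pins the public signature to ``ReadOnlyInstanceDescriptor[_T]`` without changing runtime behavior. A hedged usage sketch of the now-typed decorator (illustrative only, not part of this change):

    from sqlalchemy.util.deprecations import deprecated_property

    class Thing:
        @deprecated_property("2.0", "Use .width instead")
        def size(self) -> int:
            """legacy size accessor"""
            return 10

    t = Thing()
    print(t.size)  # 10, after emitting the deprecation warning
                   # (subject to the active warning filters)
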
index 8b65fb4cf66d34b517768666123c251a9316de75..9401c249fe39016e1c50e0b6ea2860748a94bc94 100644 (file)
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -19,13 +19,23 @@ import operator
 import re
 import sys
 import textwrap
+import threading
 import types
 import typing
 from typing import Any
 from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
 from typing import Generic
+from typing import Iterator
+from typing import List
 from typing import Optional
 from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
 from typing import TypeVar
 from typing import Union
 import warnings
@@ -33,6 +43,7 @@ import warnings
 from . import _collections
 from . import compat
 from . import typing as compat_typing
+from ._has_cy import HAS_CYEXTENSION
 from .. import exc
 
 _T = TypeVar("_T")
@@ -43,7 +54,7 @@ _HP = TypeVar("_HP", bound="hybridproperty")
 _HM = TypeVar("_HM", bound="hybridmethod")
 
 
-def md5_hex(x):
+def md5_hex(x: Any) -> str:
     x = x.encode("utf-8")
     m = hashlib.md5()
     m.update(x)
@@ -70,26 +81,44 @@ class safe_reraise:
 
     __slots__ = ("warn_only", "_exc_info")
 
-    def __init__(self, warn_only=False):
+    _exc_info: Union[
+        None,
+        Tuple[
+            Type[BaseException],
+            BaseException,
+            types.TracebackType,
+        ],
+        Tuple[None, None, None],
+    ]
+
+    def __init__(self, warn_only: bool = False):
         self.warn_only = warn_only
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         self._exc_info = sys.exc_info()
 
-    def __exit__(self, type_, value, traceback):
+    def __exit__(
+        self,
+        type_: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
+        assert self._exc_info is not None
         # see #2703 for notes
         if type_ is None:
             exc_type, exc_value, exc_tb = self._exc_info
+            assert exc_value is not None
             self._exc_info = None  # remove potential circular references
             if not self.warn_only:
                 raise exc_value.with_traceback(exc_tb)
         else:
             self._exc_info = None  # remove potential circular references
+            assert value is not None
             raise value.with_traceback(traceback)
 
 
-def walk_subclasses(cls):
-    seen = set()
+def walk_subclasses(cls: type) -> Iterator[type]:
+    seen: Set[Any] = set()
 
     stack = [cls]
     while stack:
@@ -102,7 +131,7 @@ def walk_subclasses(cls):
         yield cls
 
 
-def string_or_unprintable(element):
+def string_or_unprintable(element: Any) -> str:
     if isinstance(element, str):
         return element
     else:
@@ -112,13 +141,15 @@ def string_or_unprintable(element):
             return "unprintable element %r" % element
 
 
-def clsname_as_plain_name(cls):
+def clsname_as_plain_name(cls: Type[Any]) -> str:
     return " ".join(
         n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__)
     )
 
 
-def method_is_overridden(instance_or_cls, against_method):
+def method_is_overridden(
+    instance_or_cls: Union[Type[Any], object], against_method: types.MethodType
+) -> bool:
     """Return True if the two class methods don't match."""
 
     if not isinstance(instance_or_cls, type):
@@ -128,18 +159,18 @@ def method_is_overridden(instance_or_cls, against_method):
 
     method_name = against_method.__name__
 
-    current_method = getattr(current_cls, method_name)
+    current_method: types.MethodType = getattr(current_cls, method_name)
 
     return current_method != against_method
 
 
-def decode_slice(slc):
+def decode_slice(slc: slice) -> Tuple[Any, ...]:
     """decode a slice object as sent to __getitem__.
 
     takes into account the 2.5 __index__() method, basically.
 
     """
-    ret = []
+    ret: List[Any] = []
     for x in slc.start, slc.stop, slc.step:
         if hasattr(x, "__index__"):
             x = x.__index__()
@@ -147,23 +178,23 @@ def decode_slice(slc):
     return tuple(ret)
 
 
-def _unique_symbols(used, *bases):
-    used = set(used)
+def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
+    used_set = set(used)
     for base in bases:
         pool = itertools.chain(
             (base,),
             map(lambda i: base + str(i), range(1000)),
         )
         for sym in pool:
-            if sym not in used:
-                used.add(sym)
+            if sym not in used_set:
+                used_set.add(sym)
                 yield sym
                 break
         else:
             raise NameError("exhausted namespace for symbol base %s" % base)
 
 
-def map_bits(fn, n):
+def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
     """Call the given function given each nonzero bit from n."""
 
     while n:
@@ -172,28 +203,34 @@ def map_bits(fn, n):
         n ^= b
 
 
-_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Fn = typing.TypeVar("_Fn", bound=typing.Callable[..., Any])
 _Args = compat_typing.ParamSpec("_Args")
 
 
 def decorator(
-    target: typing.Callable[compat_typing.Concatenate[_Fn, _Args], typing.Any]
+    target: typing.Callable[  # type: ignore
+        compat_typing.Concatenate[_Fn, _Args], typing.Any
+    ]
 ) -> _Fn:
     """A signature-matching decorator factory."""
 
-    def decorate(fn):
+    def decorate(fn: typing.Callable[..., Any]) -> typing.Callable[..., Any]:
         if not inspect.isfunction(fn) and not inspect.ismethod(fn):
             raise Exception("not a decoratable function")
 
         spec = compat.inspect_getfullargspec(fn)
-        env = {}
+        env: Dict[str, Any] = {}
 
         spec = _update_argspec_defaults_into_env(spec, env)
 
-        names = tuple(spec[0]) + spec[1:3] + (fn.__name__,)
+        names = (
+            tuple(cast("Tuple[str, ...]", spec[0]))
+            + cast("Tuple[str, ...]", spec[1:3])
+            + (fn.__name__,)
+        )
         targ_name, fn_name = _unique_symbols(names, "target", "fn")
 
-        metadata = dict(target=targ_name, fn=fn_name)
+        metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
         metadata.update(format_argspec_plus(spec, grouped=False))
         metadata["name"] = fn.__name__
         code = (
@@ -205,9 +242,15 @@ def %(name)s%(grouped_args)s:
         )
         env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
 
-        decorated = _exec_code_in_env(code, env, fn.__name__)
+        decorated = cast(
+            types.FunctionType,
+            _exec_code_in_env(code, env, fn.__name__),
+        )
         decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
-        decorated.__wrapped__ = fn
+
+        # claims to be fixed?
+        # https://github.com/python/mypy/issues/11896
+        decorated.__wrapped__ = fn  # type: ignore
         return update_wrapper(decorated, fn)
 
     return typing.cast(_Fn, update_wrapper(decorate, target))
@@ -303,7 +346,26 @@ def _inspect_func_args(fn):
         )
 
 
-def get_cls_kwargs(cls, _set=None):
+@overload
+def get_cls_kwargs(
+    cls: type,
+    *,
+    _set: Optional[Set[str]] = None,
+    raiseerr: compat_typing.Literal[True] = ...,
+) -> Set[str]:
+    ...
+
+
+@overload
+def get_cls_kwargs(
+    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
+) -> Optional[Set[str]]:
+    ...
+
+
+def get_cls_kwargs(
+    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
+) -> Optional[Set[str]]:
     r"""Return the full set of inherited kwargs for the given `cls`.
 
     Probes a class's __init__ method, collecting all named arguments.  If the
@@ -321,6 +383,7 @@ def get_cls_kwargs(cls, _set=None):
     toplevel = _set is None
     if toplevel:
         _set = set()
+    assert _set is not None
 
     ctr = cls.__dict__.get("__init__", False)
 
@@ -335,11 +398,18 @@ def get_cls_kwargs(cls, _set=None):
         _set.update(names)
 
         if not has_kw and not toplevel:
-            return None
+            if raiseerr:
+                raise TypeError(
+                    f"given cls {cls} doesn't have an __init__ method"
+                )
+            else:
+                return None
+    else:
+        has_kw = False
 
     if not has_init or has_kw:
         for c in cls.__bases__:
-            if get_cls_kwargs(c, _set) is None:
+            if get_cls_kwargs(c, _set=_set) is None:
                 break
 
     _set.discard("self")
@@ -411,7 +481,9 @@ def get_callable_argspec(fn, no_self=False, _is_init=False):
         raise TypeError("Can't inspect callable: %s" % fn)
 
 
-def format_argspec_plus(fn, grouped=True):
+def format_argspec_plus(
+    fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
+) -> Dict[str, Optional[str]]:
     """Returns a dictionary of formatted, introspected function arguments.
 
     A enhanced variant of inspect.formatargspec to support code generation.
@@ -474,11 +546,14 @@ def format_argspec_plus(fn, grouped=True):
 
     num_defaults = 0
     if spec[3]:
-        num_defaults += len(spec[3])
+        num_defaults += len(cast(Tuple[Any], spec[3]))
     if spec[4]:
         num_defaults += len(spec[4])
+
     name_args = spec[0] + spec[4]
 
+    defaulted_vals: Union[List[str], Tuple[()]]
+
     if num_defaults:
         defaulted_vals = name_args[0 - num_defaults :]
     else:
@@ -489,7 +564,7 @@ def format_argspec_plus(fn, grouped=True):
         spec[1],
         spec[2],
         defaulted_vals,
-        formatvalue=lambda x: "=" + x,
+        formatvalue=lambda x: "=" + str(x),
     )
 
     if spec[0]:
@@ -498,7 +573,7 @@ def format_argspec_plus(fn, grouped=True):
             spec[1],
             spec[2],
             defaulted_vals,
-            formatvalue=lambda x: "=" + x,
+            formatvalue=lambda x: "=" + str(x),
         )
     else:
         apply_kw_proxied = apply_kw
@@ -570,7 +645,7 @@ def create_proxy_methods(
 
     def decorate(cls):
         def instrument(name, clslevel=False):
-            fn = getattr(target_cls, name)
+            fn = cast(Callable[..., Any], getattr(target_cls, name))
             spec = compat.inspect_getfullargspec(fn)
             env = {"__name__": fn.__module__}
 
@@ -599,7 +674,9 @@ def create_proxy_methods(
                     % metadata
                 )
 
-            proxy_fn = _exec_code_in_env(code, env, fn.__name__)
+            proxy_fn = cast(
+                Callable[..., Any], _exec_code_in_env(code, env, fn.__name__)
+            )
             proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
             proxy_fn.__doc__ = inject_docstring_text(
                 fn.__doc__,
@@ -721,7 +798,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
         except TypeError:
             continue
         else:
-            default_len = spec.defaults and len(spec.defaults) or 0
+            default_len = len(spec.defaults) if spec.defaults else 0
             if i == 0:
                 if spec.varargs:
                     vargs = spec.varargs
@@ -735,6 +812,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
                 )
 
             if default_len:
+                assert spec.defaults
                 kw_args.update(
                     [
                         (arg, default)
@@ -811,9 +889,6 @@ def class_hierarchy(cls):
     class_hierarchy(class A(object)) returns (A, object), not A plus every
     class systemwide that derives from object.
 
-    Old-style classes are discarded and hierarchies rooted on them
-    will not be descended.
-
     """
 
     hier = {cls}
@@ -829,7 +904,15 @@ def class_hierarchy(cls):
         if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
             continue
 
-        for s in [_ for _ in c.__subclasses__() if _ not in hier]:
+        for s in [
+            _
+            for _ in (
+                c.__subclasses__()
+                if not issubclass(c, type)
+                else c.__subclasses__(c)
+            )
+            if _ not in hier
+        ]:
             process.append(s)
             hier.add(s)
     return list(hier)
@@ -886,10 +969,12 @@ def monkeypatch_proxied_specials(
 
     for method in dunders:
         try:
-            fn = getattr(from_cls, method)
-            if not hasattr(fn, "__call__"):
+            maybe_fn = getattr(from_cls, method)
+            if not hasattr(maybe_fn, "__call__"):
                 continue
-            fn = getattr(fn, "__func__", fn)
+            maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
+            fn = cast(Callable[..., Any], maybe_fn)
+
         except AttributeError:
             continue
         try:
@@ -1021,7 +1106,7 @@ class memoized_property(Generic[_T]):
     __name__: str
 
     def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
-        self.fget = fget
+        self.fget = fget  # type: ignore[assignment]
         self.__doc__ = doc or fget.__doc__
         self.__name__ = fget.__name__
 
@@ -1041,7 +1126,7 @@ class memoized_property(Generic[_T]):
         if obj is None:
             return self
         obj.__dict__[self.__name__] = result = self.fget(obj)
-        return result
+        return result  # type: ignore
 
     def _reset(self, obj):
         memoized_property.reset(obj, self.__name__)
@@ -1082,7 +1167,7 @@ class HasMemoized:
 
     __slots__ = ()
 
-    _memoized_keys = frozenset()
+    _memoized_keys: FrozenSet[str] = frozenset()
 
     def _reset_memoizations(self):
         for elem in self._memoized_keys:
@@ -1104,7 +1189,8 @@ class HasMemoized:
         __name__: str
 
         def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
-            self.fget = fget
+            # https://github.com/python/mypy/issues/708
+            self.fget = fget  # type: ignore
             self.__doc__ = doc or fget.__doc__
             self.__name__ = fget.__name__
 
@@ -1268,7 +1354,7 @@ def constructor_copy(obj, cls, *args, **kw):
 def counter():
     """Return a threadsafe counter function."""
 
-    lock = compat.threading.Lock()
+    lock = threading.Lock()
     counter = itertools.count(1)
 
     # avoid the 2to3 "next" transformation...
@@ -1362,12 +1448,14 @@ class classproperty(property):
 
     """
 
-    def __init__(self, fget, *arg, **kw):
+    fget: Callable[[Any], Any]
+
+    def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
         super(classproperty, self).__init__(fget, *arg, **kw)
         self.__doc__ = fget.__doc__
 
-    def __get__(desc, self, cls):
-        return desc.fget(cls)
+    def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
+        return self.fget(cls)  # type: ignore
 
 
 class hybridproperty:
@@ -1406,7 +1494,9 @@ class hybridmethod:
 
 
 class _symbol(int):
-    def __new__(self, name, doc=None, canonical=None):
+    name: str
+
+    def __new__(cls, name, doc=None, canonical=None):
         """Construct a new named symbol."""
         assert isinstance(name, str)
         if canonical is None:
@@ -1452,8 +1542,8 @@ class symbol:
 
     """
 
-    symbols = {}
-    _lock = compat.threading.Lock()
+    symbols: Dict[str, "_symbol"] = {}
+    _lock = threading.Lock()
 
     def __new__(cls, name, doc=None, canonical=None):
         with cls._lock:
@@ -1546,6 +1636,8 @@ class _hash_limit_string(str):
 
     """
 
+    _hash: int
+
     def __new__(cls, value, num, args):
         interpolated = (value % args) + (
             " (this warning may be suppressed after %d occurrences)" % num
@@ -1731,8 +1823,8 @@ class EnsureKWArg:
         super().__init_subclass__()
 
     @classmethod
-    def _wrap_w_kw(cls, fn):
-        def wrap(*arg, **kw):
+    def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
+        def wrap(*arg: Any, **kw: Any) -> Any:
             return fn(*arg)
 
         return update_wrapper(wrap, fn)
@@ -1910,15 +2002,12 @@ def repr_tuple_names(names):
 
 
 def has_compiled_ext(raise_=False):
-    try:
-        from sqlalchemy.cyextension import collections  # noqa F401
-        from sqlalchemy.cyextension import immutabledict  # noqa F401
-        from sqlalchemy.cyextension import processors  # noqa F401
-        from sqlalchemy.cyextension import resultproxy  # noqa F401
-        from sqlalchemy.cyextension import util  # noqa F401
-
+    if HAS_CYEXTENSION:
         return True
-    except ImportError:
-        if raise_:
-            raise
+    elif raise_:
+        raise ImportError(
+            "cython extensions were expected to be installed, "
+            "but are not present"
+        )
+    else:
         return False
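
The ``get_cls_kwargs()`` overloads above establish a convention reused in this change: a ``Literal[True]`` overload gives callers that pass ``raiseerr=True`` a non-``Optional`` return, while all other callers keep the ``Optional`` result. A standalone sketch of the pattern (hypothetical ``collect_kwargs`` helper, not SQLAlchemy code; ``typing.Literal`` needs Python 3.8+):

    from typing import Literal, Optional, Set, overload

    @overload
    def collect_kwargs(cls: type, *, raiseerr: Literal[True] = ...) -> Set[str]:
        ...

    @overload
    def collect_kwargs(cls: type, *, raiseerr: bool = False) -> Optional[Set[str]]:
        ...

    def collect_kwargs(cls: type, *, raiseerr: bool = False) -> Optional[Set[str]]:
        init = cls.__dict__.get("__init__")
        if init is None:
            if raiseerr:
                raise TypeError(f"{cls} defines no __init__ of its own")
            return None
        names = init.__code__.co_varnames[: init.__code__.co_argcount]
        return set(names) - {"self"}

    class Widget:
        def __init__(self, width: int, height: int = 1):
            self.width, self.height = width, height

    print(collect_kwargs(Widget, raiseerr=True))  # {'width', 'height'}
    print(collect_kwargs(type("Plain", (), {})))  # None
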
index d2cd0a1a7114c62d3b75c16f59db8f230bca2dba..3062d9d8ab00081b279a5a2ebce9c5724e229488 100644 (file)
--- a/lib/sqlalchemy/util/queue.py
+++ b/lib/sqlalchemy/util/queue.py
@@ -17,12 +17,11 @@ producing a ``put()`` inside the ``get()`` and therefore a reentrant
 condition.
 
 """
-
+import asyncio
 from collections import deque
+import threading
 from time import time as _time
 
-from .compat import threading
-from .concurrency import asyncio
 from .concurrency import await_fallback
 from .concurrency import await_only
 from .langhelpers import memoized_property
index 5767d258b0a27ad55de57dc3be1bb2384223bfa3..62a9f6c8a8dc8f642959faa623a56590ba3d1812 100644 (file)
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -1,3 +1,4 @@
+import typing
 from typing import Any
 from typing import Callable  # noqa
 from typing import Generic
@@ -12,27 +13,29 @@ from . import compat
 
 _T = TypeVar("_T", bound=Any)
 
-if compat.py38:
-    from typing import Literal
-    from typing import Protocol
-    from typing import TypedDict
+if typing.TYPE_CHECKING or not compat.py38:
+    from typing_extensions import Literal  # noqa F401
+    from typing_extensions import Protocol  # noqa F401
+    from typing_extensions import TypedDict  # noqa F401
 else:
-    from typing_extensions import Literal  # noqa
-    from typing_extensions import Protocol  # noqa
-    from typing_extensions import TypedDict  # noqa
+    from typing import Literal  # noqa F401
+    from typing import Protocol  # noqa F401
+    from typing import TypedDict  # noqa F401
 
-if compat.py310:
-    from typing import Concatenate
-    from typing import ParamSpec
+if typing.TYPE_CHECKING or not compat.py310:
+    from typing_extensions import Concatenate  # noqa F401
+    from typing_extensions import ParamSpec  # noqa F401
 else:
-    from typing_extensions import Concatenate  # noqa
-    from typing_extensions import ParamSpec  # noqa
+    from typing import Concatenate  # noqa F401
+    from typing import ParamSpec  # noqa F401
 
 
-_T = TypeVar("_T")
+class _TypeToInstance(Generic[_T]):
+    """describe a variable that moves between a class and an instance of
+    that class.
 
+    """
 
-class _TypeToInstance(Generic[_T]):
     @overload
     def __get__(self, instance: None, owner: Any) -> Type[_T]:
         ...
@@ -41,6 +44,9 @@ class _TypeToInstance(Generic[_T]):
     def __get__(self, instance: object, owner: Any) -> _T:
         ...
 
+    def __get__(self, instance: object, owner: Any) -> Union[Type[_T], _T]:
+        ...
+
     @overload
     def __set__(self, instance: None, value: Type[_T]) -> None:
         ...
@@ -49,6 +55,9 @@ class _TypeToInstance(Generic[_T]):
     def __set__(self, instance: object, value: _T) -> None:
         ...
 
+    def __set__(self, instance: object, value: Union[Type[_T], _T]) -> None:
+        ...
+
 
 class ReadOnlyInstanceDescriptor(Protocol[_T]):
     """protocol representing an instance-only descriptor"""
index 2707bea97c9438de6be3ca7bdcfa8ef8eed65a75..036892d45bf255ce230f5ad83c40ae07b820da07 100644 (file)
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,3 +23,67 @@ filterwarnings = [
     "error::DeprecationWarning:test",
     "error::DeprecationWarning:sqlalchemy"
 ]
+
+
+[tool.pyright]
+include = [
+    "lib/sqlalchemy/events.py",
+    "lib/sqlalchemy/exc.py",
+    "lib/sqlalchemy/log.py",
+    "lib/sqlalchemy/inspection.py",
+    "lib/sqlalchemy/schema.py",
+    "lib/sqlalchemy/types.py",
+    "lib/sqlalchemy/util/",
+]
+
+
+
+[tool.mypy]
+mypy_path = "./lib/"
+show_error_codes = true
+strict = false
+incremental = true
+
+# disabled checking
+[[tool.mypy.overrides]]
+module="sqlalchemy.*"
+ignore_errors = true
+warn_unused_ignores = false
+
+strict = true
+
+# https://github.com/python/mypy/issues/8754
+# we are a pep-561 package, so implicit-reexport should be
+# enabled
+implicit_reexport = true
+
+# individual packages or even modules should be listed here
+# with strictness-specificity set up.  there's no way we are going to get
+# the whole library 100% strictly typed, so we have to tune this based on
+# the type of module or package we are dealing with
+
+# strict checking
+[[tool.mypy.overrides]]
+module = [
+    "sqlalchemy.events",
+    "sqlalchemy.events",
+    "sqlalchemy.exc",
+    "sqlalchemy.inspection",
+    "sqlalchemy.schema",
+    "sqlalchemy.types",
+]
+ignore_errors = false
+strict = true
+
+# partial checking, internals can be untyped
+[[tool.mypy.overrides]]
+module="sqlalchemy.util.*"
+ignore_errors = false
+
+# util is for internal use so we can get by without everything
+# being typed
+allow_untyped_defs = true
+check_untyped_defs = false
+allow_untyped_calls = true
+
+
index 3f903eb62f6619fab6df13ef15e89077420064c4..2eceb0b816b2e41dd27a7ef6e049d7e65ab87614 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -114,18 +114,6 @@ per-file-ignores =
     lib/sqlalchemy/types.py:F401
     lib/sqlalchemy/sql/expression.py:F401
 
-[mypy]
-mypy_path = ./lib/
-strict = True
-incremental = True
-#plugins = sqlalchemy.ext.mypy.plugin
-
-[mypy-sqlalchemy.*]
-ignore_errors = True
-
-[mypy-sqlalchemy.ext.mypy.*]
-ignore_errors = False
-
 [sqla_testing]
 requirement_cls = test.requirements:DefaultRequirements
 profile_file = test/profiles.txt
index a88b7c56c5ce8d20020371817b29e125f02ecef2..dc02c37cb0006f2be57c689aaf88cc6ad5c8aef0 100644 (file)
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -162,7 +162,20 @@ class OrderedSetTest(fixtures.TestBase):
 
         eq_(o.difference(iter([3, 4])), util.OrderedSet([2, 5]))
         eq_(o.intersection(iter([3, 4, 6])), util.OrderedSet([3, 4]))
-        eq_(o.union(iter([3, 4, 6])), util.OrderedSet([2, 3, 4, 5, 6]))
+        eq_(o.union(iter([3, 4, 6])), util.OrderedSet([3, 2, 4, 5, 6]))
+
+    def test_len(self):
+        eq_(len(util.OrderedSet([1, 2, 3])), 3)
+
+    def test_eq_no_insert_order(self):
+        eq_(util.OrderedSet([3, 2, 4, 5]), util.OrderedSet([2, 3, 4, 5]))
+
+        ne_(util.OrderedSet([3, 2, 4, 5]), util.OrderedSet([3, 2, 4, 5, 6]))
+
+    def test_eq_non_ordered_set(self):
+        eq_(util.OrderedSet([3, 2, 4, 5]), {2, 3, 4, 5})
+
+        ne_(util.OrderedSet([3, 2, 4, 5]), {3, 2, 4, 5, 6})
 
     def test_repr(self):
         o = util.OrderedSet([])
@@ -295,7 +308,7 @@ class ImmutableTest(fixtures.TestBase):
             lambda: d.update({2: 4}),
         )
         if hasattr(d, "pop"):
-            calls += (d.pop, d.popitem)
+            calls += (lambda: d.pop(2), d.popitem)
         for m in calls:
             with expect_raises_message(TypeError, "object is immutable"):
                 m()
index 59ebc87e2bbf131b18c0ed61306d038ddf7746c0..59bc4863fb0440ff09bf153b618a401c88b1c989 100644 (file)
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -3333,6 +3333,8 @@ class OnConnectTest(fixtures.TestBase):
         cls_ = testing.db.dialect.__class__
 
         class SomeDialect(cls_):
+            supports_statement_cache = True
+
             def initialize(self, connection):
                 super(SomeDialect, self).initialize(connection)
                 m1.append("initialize")
index dc04d4da65f6fd1abdedf63fac04eced696118aa..a83ca4194751d4588f873062ff9925010d142de4 100644 (file)
--- a/test/orm/test_merge.py
+++ b/test/orm/test_merge.py
@@ -31,20 +31,9 @@ from sqlalchemy.testing import not_in
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
-from sqlalchemy.util import has_compiled_ext
-from sqlalchemy.util import OrderedSet
 from test.orm import _fixtures
 
 
-if has_compiled_ext():
-    # cython ordered set is immutable, subclass it with a python
-    # class so that its method can be replaced
-    _OrderedSet = OrderedSet
-
-    class OrderedSet(_OrderedSet):
-        pass
-
-
 class MergeTest(_fixtures.FixtureTest):
     """Session.merge() functionality"""
 
@@ -167,7 +156,7 @@ class MergeTest(_fixtures.FixtureTest):
             users,
             properties={
                 "addresses": relationship(
-                    Address, backref="user", collection_class=OrderedSet
+                    Address, backref="user", collection_class=set
                 )
             },
         )
@@ -178,12 +167,10 @@ class MergeTest(_fixtures.FixtureTest):
         u = User(
             id=7,
             name="fred",
-            addresses=OrderedSet(
-                [
-                    Address(id=1, email_address="fred1"),
-                    Address(id=2, email_address="fred2"),
-                ]
-            ),
+            addresses={
+                Address(id=1, email_address="fred1"),
+                Address(id=2, email_address="fred2"),
+            },
         )
         eq_(load.called, 0)
 
@@ -203,12 +190,10 @@ class MergeTest(_fixtures.FixtureTest):
             User(
                 id=7,
                 name="fred",
-                addresses=OrderedSet(
-                    [
-                        Address(id=1, email_address="fred1"),
-                        Address(id=2, email_address="fred2"),
-                    ]
-                ),
+                addresses={
+                    Address(id=1, email_address="fred1"),
+                    Address(id=2, email_address="fred2"),
+                },
             ),
         )
 
@@ -258,7 +243,7 @@ class MergeTest(_fixtures.FixtureTest):
             users,
             properties={
                 "addresses": relationship(
-                    Address, backref="user", collection_class=OrderedSet
+                    Address, backref="user", collection_class=set
                 )
             },
         )
@@ -269,12 +254,10 @@ class MergeTest(_fixtures.FixtureTest):
         u = User(
             id=None,
             name="fred",
-            addresses=OrderedSet(
-                [
-                    Address(id=None, email_address="fred1"),
-                    Address(id=None, email_address="fred2"),
-                ]
-            ),
+            addresses={
+                Address(id=None, email_address="fred1"),
+                Address(id=None, email_address="fred2"),
+            },
         )
         eq_(load.called, 0)
 
@@ -293,12 +276,10 @@ class MergeTest(_fixtures.FixtureTest):
             sess.query(User).one(),
             User(
                 name="fred",
-                addresses=OrderedSet(
-                    [
-                        Address(email_address="fred1"),
-                        Address(email_address="fred2"),
-                    ]
-                ),
+                addresses={
+                    Address(email_address="fred1"),
+                    Address(email_address="fred2"),
+                },
             ),
         )
 
@@ -341,8 +322,7 @@ class MergeTest(_fixtures.FixtureTest):
                 "addresses": relationship(
                     Address,
                     backref="user",
-                    collection_class=OrderedSet,
-                    order_by=addresses.c.id,
+                    collection_class=set,
                     cascade="all, delete-orphan",
                 )
             },
@@ -355,12 +335,10 @@ class MergeTest(_fixtures.FixtureTest):
         u = User(
             id=7,
             name="fred",
-            addresses=OrderedSet(
-                [
-                    Address(id=1, email_address="fred1"),
-                    Address(id=2, email_address="fred2"),
-                ]
-            ),
+            addresses={
+                Address(id=1, email_address="fred1"),
+                Address(id=2, email_address="fred2"),
+            },
         )
         sess = fixture_session()
         sess.add(u)
@@ -372,12 +350,10 @@ class MergeTest(_fixtures.FixtureTest):
         u = User(
             id=7,
             name="fred",
-            addresses=OrderedSet(
-                [
-                    Address(id=3, email_address="fred3"),
-                    Address(id=4, email_address="fred4"),
-                ]
-            ),
+            addresses={
+                Address(id=3, email_address="fred3"),
+                Address(id=4, email_address="fred4"),
+            },
         )
 
         u = sess.merge(u)
@@ -393,12 +369,10 @@ class MergeTest(_fixtures.FixtureTest):
             User(
                 id=7,
                 name="fred",
-                addresses=OrderedSet(
-                    [
-                        Address(id=3, email_address="fred3"),
-                        Address(id=4, email_address="fred4"),
-                    ]
-                ),
+                addresses={
+                    Address(id=3, email_address="fred3"),
+                    Address(id=4, email_address="fred4"),
+                },
             ),
         )
         sess.flush()
@@ -408,12 +382,10 @@ class MergeTest(_fixtures.FixtureTest):
             User(
                 id=7,
                 name="fred",
-                addresses=OrderedSet(
-                    [
-                        Address(id=3, email_address="fred3"),
-                        Address(id=4, email_address="fred4"),
-                    ]
-                ),
+                addresses={
+                    Address(id=3, email_address="fred3"),
+                    Address(id=4, email_address="fred4"),
+                },
             ),
         )
 
@@ -433,7 +405,7 @@ class MergeTest(_fixtures.FixtureTest):
                     Address,
                     backref="user",
                     order_by=addresses.c.id,
-                    collection_class=OrderedSet,
+                    collection_class=set,
                 )
             },
         )
@@ -445,7 +417,7 @@ class MergeTest(_fixtures.FixtureTest):
         u = User(
             id=7,
             name="fred",
-            addresses=OrderedSet([a, Address(id=2, email_address="fred2")]),
+            addresses={a, Address(id=2, email_address="fred2")},
         )
         sess = fixture_session()
         sess.add(u)
@@ -467,12 +439,10 @@ class MergeTest(_fixtures.FixtureTest):
             User(
                 id=7,
                 name="fred jones",
-                addresses=OrderedSet(
-                    [
-                        Address(id=2, email_address="fred2"),
-                        Address(id=3, email_address="fred3"),
-                    ]
-                ),
+                addresses={
+                    Address(id=2, email_address="fred2"),
+                    Address(id=3, email_address="fred3"),
+                },
             ),
         )
 
diff --git a/tox.ini b/tox.ini
index e55a43cbbeef2c7fba97f870553ea076315c2612..2100aa507e367ce99ddcbbdb156c65f6b0e34f66 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -128,6 +128,16 @@ commands=
   oracle,mssql,sqlite_file: python reap_dbs.py db_idents.txt
 
 
+[testenv:pep484]
+deps=
+     greenlet != 0.4.17
+     importlib_metadata; python_version < '3.8'
+     mypy
+     pyright
+commands =
+    mypy  ./lib/sqlalchemy
+    pyright
+
 [testenv:mypy]
 deps=
      pytest>=7.0.0rc1,<8
@@ -158,6 +168,7 @@ commands =
      flake8 ./lib/ ./test/ ./examples/ setup.py doc/build/conf.py {posargs}
      black --check ./lib/ ./test/ ./examples/ setup.py doc/build/conf.py
 
+
 # command run in the github action when cext are active.
 [testenv:github-cext]
 deps = {[testenv]deps}