]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep-484: sqlalchemy.sql pass one
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Mar 2022 22:14:41 +0000 (17:14 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 13 Mar 2022 19:29:20 +0000 (15:29 -0400)
sqlalchemy.sql will require many passes to get all
modules even gradually typed.  Will have to pick and
choose what modules can be strictly typed vs. which
can be gradual.

in this patch, emphasis is on visitors.py, cache_key.py,
annotations.py for strict typing, compiler.py is on gradual
typing but has much more structure, in particular where it
connects with the outside world.

The work within compiler.py also reached back out to
engine/cursor.py , default.py quite a bit.

References: #6810
Change-Id: I6e8a29f6013fd216e43d45091bc193f8be0368fd

25 files changed:
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/sql/_py_util.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/cache_key.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
pyproject.toml

index 821c0cb8e3334b2bd23d410a58b8891956e10588..f776e597538bb01355a60a77056546c8450a182f 100644 (file)
@@ -604,18 +604,20 @@ class CursorResultMetaData(ResultMetaData):
         cls,
         result_columns: List[ResultColumnsEntry],
         loose_column_name_matching: bool = False,
-    ) -> Dict[Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]]:
+    ) -> Dict[
+        Union[str, object], Tuple[str, Tuple[Any, ...], TypeEngine[Any], int]
+    ]:
         """when matching cursor.description to a set of names that are present
         in a Compiled object, as is the case with TextualSelect, get all the
         names we expect might match those in cursor.description.
         """
 
         d: Dict[
-            Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]
+            Union[str, object],
+            Tuple[str, Tuple[Any, ...], TypeEngine[Any], int],
         ] = {}
         for ridx, elem in enumerate(result_columns):
             key = elem[RM_RENDERED_NAME]
-
             if key in d:
                 # conflicting keyname - just add the column-linked objects
                 # to the existing record.  if there is a duplicate column
index 2579f573c5f2bb90b3e22a38d462b29d926afb53..c9fb1ebf2c7bc080fb8962c042f866899d5b76f2 100644 (file)
@@ -46,6 +46,7 @@ from .interfaces import ExecutionContext
 from .. import event
 from .. import exc
 from .. import pool
+from .. import TupleType
 from .. import types as sqltypes
 from .. import util
 from ..sql import compiler
@@ -76,6 +77,8 @@ if typing.TYPE_CHECKING:
     from ..sql.compiler import Compiled
     from ..sql.compiler import ResultColumnsEntry
     from ..sql.compiler import TypeCompiler
+    from ..sql.dml import DMLState
+    from ..sql.elements import BindParameter
     from ..sql.schema import Column
     from ..sql.type_api import TypeEngine
 
@@ -820,7 +823,7 @@ class DefaultExecutionContext(ExecutionContext):
     cursor: DBAPICursor
     compiled_parameters: List[_MutableCoreSingleExecuteParams]
     parameters: _DBAPIMultiExecuteParams
-    extracted_parameters: _CoreSingleExecuteParams
+    extracted_parameters: Optional[Sequence[BindParameter[Any]]]
 
     _empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT)
 
@@ -878,7 +881,7 @@ class DefaultExecutionContext(ExecutionContext):
         compiled: SQLCompiler,
         parameters: _CoreMultiExecuteParams,
         invoked_statement: Executable,
-        extracted_parameters: _CoreSingleExecuteParams,
+        extracted_parameters: Optional[Sequence[BindParameter[Any]]],
         cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
     ) -> ExecutionContext:
         """Initialize execution context for a Compiled construct."""
@@ -1513,9 +1516,10 @@ class DefaultExecutionContext(ExecutionContext):
                 inputsizes, self.cursor, self.statement, self.parameters, self
             )
 
-        has_escaped_names = bool(compiled.escaped_bind_names)
-        if has_escaped_names:
+        if compiled.escaped_bind_names:
             escaped_bind_names = compiled.escaped_bind_names
+        else:
+            escaped_bind_names = None
 
         if dialect.positional:
             items = [
@@ -1535,17 +1539,18 @@ class DefaultExecutionContext(ExecutionContext):
 
             if key in self._expanded_parameters:
                 if bindparam.type._is_tuple_type:
-                    num = len(bindparam.type.types)
+                    tup_type = cast(TupleType, bindparam.type)
+                    num = len(tup_type.types)
                     dbtypes = inputsizes[bindparam]
                     generic_inputsizes.extend(
                         (
                             (
                                 escaped_bind_names.get(paramname, paramname)
-                                if has_escaped_names
+                                if escaped_bind_names is not None
                                 else paramname
                             ),
                             dbtypes[idx % num],
-                            bindparam.type.types[idx % num],
+                            tup_type.types[idx % num],
                         )
                         for idx, paramname in enumerate(
                             self._expanded_parameters[key]
@@ -1557,7 +1562,7 @@ class DefaultExecutionContext(ExecutionContext):
                         (
                             (
                                 escaped_bind_names.get(paramname, paramname)
-                                if has_escaped_names
+                                if escaped_bind_names is not None
                                 else paramname
                             ),
                             dbtype,
@@ -1570,7 +1575,7 @@ class DefaultExecutionContext(ExecutionContext):
 
                 escaped_name = (
                     escaped_bind_names.get(key, key)
-                    if has_escaped_names
+                    if escaped_bind_names is not None
                     else key
                 )
 
@@ -1702,7 +1707,9 @@ class DefaultExecutionContext(ExecutionContext):
         else:
             assert column is not None
             assert parameters is not None
-        compile_state = cast(SQLCompiler, self.compiled).compile_state
+        compile_state = cast(
+            "DMLState", cast(SQLCompiler, self.compiled).compile_state
+        )
         assert compile_state is not None
         if (
             isolate_multiinsert_groups
@@ -1715,6 +1722,7 @@ class DefaultExecutionContext(ExecutionContext):
             else:
                 d = {column.key: parameters[column.key]}
                 index = 0
+            assert compile_state._dict_parameters is not None
             keys = compile_state._dict_parameters.keys()
             d.update(
                 (key, parameters["%s_m%d" % (key, index)]) for key in keys
index e65546eb777ffa55af9c65ef9c5dcfcd9b1dd44f..e13295d6d62522bb0989ec770cfe1656385fcc89 100644 (file)
@@ -54,9 +54,10 @@ if TYPE_CHECKING:
     from ..sql.compiler import IdentifierPreparer
     from ..sql.compiler import Linting
     from ..sql.compiler import SQLCompiler
+    from ..sql.elements import BindParameter
     from ..sql.elements import ClauseElement
     from ..sql.schema import Column
-    from ..sql.schema import ColumnDefault
+    from ..sql.schema import DefaultGenerator
     from ..sql.schema import Sequence as Sequence_SchemaItem
     from ..sql.sqltypes import Integer
     from ..sql.type_api import TypeEngine
@@ -813,6 +814,9 @@ class Dialect(EventTarget):
 
     """
 
+    _supports_statement_cache: bool
+    """internal evaluation for supports_statement_cache"""
+
     bind_typing = BindTyping.NONE
     """define a means of passing typing information to the database and/or
     driver for bound parameters.
@@ -2269,7 +2273,7 @@ class ExecutionContext:
         compiled: SQLCompiler,
         parameters: _CoreMultiExecuteParams,
         invoked_statement: Executable,
-        extracted_parameters: _CoreSingleExecuteParams,
+        extracted_parameters: Optional[Sequence[BindParameter[Any]]],
         cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
     ) -> ExecutionContext:
         raise NotImplementedError()
@@ -2299,7 +2303,7 @@ class ExecutionContext:
     def _exec_default(
         self,
         column: Optional[Column[Any]],
-        default: ColumnDefault,
+        default: DefaultGenerator,
         type_: Optional[TypeEngine[Any]],
     ) -> Any:
         raise NotImplementedError()
index 87d3cac1c779fe7fff852fdad2e5fdd33d4b25e8..d428b8a9d4361846e54e252e792c7fff41ff45d1 100644 (file)
@@ -17,6 +17,7 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import NoReturn
@@ -326,7 +327,7 @@ class SimpleResultMetaData(ResultMetaData):
 
 def result_tuple(
     fields: Sequence[str], extra: Optional[Any] = None
-) -> Callable[[_RawRowType], Row]:
+) -> Callable[[Iterable[Any]], Row]:
     parent = SimpleResultMetaData(fields, extra)
     return functools.partial(
         Row, parent, parent._processors, parent._keymap, Row._default_key_style
index cc78e0971ce835214b2182f082363a16a8767d49..8f4b963eba8ab6dd8b856c668a65455e3dbead62 100644 (file)
@@ -33,6 +33,7 @@ if typing.TYPE_CHECKING:
     from .engine.interfaces import _DBAPIAnyExecuteParams
     from .engine.interfaces import Dialect
     from .sql.compiler import Compiled
+    from .sql.compiler import TypeCompiler
     from .sql.elements import ClauseElement
 
 if typing.TYPE_CHECKING:
@@ -221,8 +222,8 @@ class UnsupportedCompilationError(CompileError):
 
     def __init__(
         self,
-        compiler: "Compiled",
-        element_type: Type["ClauseElement"],
+        compiler: Union[Compiled, TypeCompiler],
+        element_type: Type[ClauseElement],
         message: Optional[str] = None,
     ):
         super(UnsupportedCompilationError, self).__init__(
index 709c13c1468e33bee1bd289d277a77dbb03454b8..e490a4f03d8b30c5469f154812c6d27a9efb4fd5 100644 (file)
@@ -59,6 +59,7 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 from ..util.typing import Self
 from ..util.typing import SupportsIndex
+from ..util.typing import SupportsKeysAndGetItem
 
 if typing.TYPE_CHECKING:
     from ..orm.attributes import InstrumentedAttribute
@@ -1660,7 +1661,9 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]):
         return (item[0], self._get(item[1]))
 
     @overload
-    def update(self, __m: Mapping[_KT, _VT], **kwargs: _VT) -> None:
+    def update(
+        self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT
+    ) -> None:
         ...
 
     @overload
index 96e8f6b2c796cb1892bc54b17f431398c08507f3..9f18b882d707e5f3cd95705a47156798e0d85756 100644 (file)
@@ -7,7 +7,16 @@
 
 from __future__ import annotations
 
+import typing
+from typing import Any
 from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+    from .cache_key import CacheConst
 
 
 class prefix_anon_map(Dict[str, str]):
@@ -22,16 +31,18 @@ class prefix_anon_map(Dict[str, str]):
 
     """
 
-    def __missing__(self, key):
+    def __missing__(self, key: str) -> str:
         (ident, derived) = key.split(" ", 1)
         anonymous_counter = self.get(derived, 1)
-        self[derived] = anonymous_counter + 1
+        self[derived] = anonymous_counter + 1  # type: ignore
         value = f"{derived}_{anonymous_counter}"
         self[key] = value
         return value
 
 
-class cache_anon_map(Dict[int, str]):
+class cache_anon_map(
+    Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]]
+):
     """A map that creates new keys for missing key access.
 
     Produces an incrementing sequence given a series of unique keys.
@@ -45,11 +56,13 @@ class cache_anon_map(Dict[int, str]):
 
     _index = 0
 
-    def get_anon(self, object_):
+    def get_anon(self, object_: Any) -> Tuple[str, bool]:
 
         idself = id(object_)
         if idself in self:
-            return self[idself], True
+            s_val = self[idself]
+            assert s_val is not True
+            return s_val, True
         else:
             # inline of __missing__
             self[idself] = id_ = str(self._index)
@@ -57,7 +70,7 @@ class cache_anon_map(Dict[int, str]):
 
             return id_, False
 
-    def __missing__(self, key):
+    def __missing__(self, key: int) -> str:
         self[key] = val = str(self._index)
         self._index += 1
         return val
index b76393ad6bd365092f676e04b001da2a73b90b53..7afc2de9776b5e8ba37108e0c83e5f2e78e4912c 100644 (file)
@@ -13,22 +13,77 @@ associations.
 
 from __future__ import annotations
 
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Mapping
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+
 from . import operators
-from .base import HasCacheKey
-from .traversals import anon_map
+from .cache_key import HasCacheKey
+from .visitors import anon_map
+from .visitors import ExternallyTraversible
 from .visitors import InternalTraversal
 from .. import util
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+    from .visitors import _TraverseInternalsType
+    from ..util.typing import Self
+
+_AnnotationDict = Mapping[str, Any]
+
+EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT
+
 
-EMPTY_ANNOTATIONS = util.immutabledict()
+SelfSupportsAnnotations = TypeVar(
+    "SelfSupportsAnnotations", bound="SupportsAnnotations"
+)
 
 
-class SupportsAnnotations:
+class SupportsAnnotations(ExternallyTraversible):
     __slots__ = ()
 
-    _annotations = EMPTY_ANNOTATIONS
+    _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS
+    proxy_set: Set[SupportsAnnotations]
+    _is_immutable: bool
+
+    def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations:
+        raise NotImplementedError()
+
+    @overload
+    def _deannotate(
+        self: SelfSupportsAnnotations,
+        values: Literal[None] = ...,
+        clone: bool = ...,
+    ) -> SelfSupportsAnnotations:
+        ...
+
+    @overload
+    def _deannotate(
+        self,
+        values: Sequence[str] = ...,
+        clone: bool = ...,
+    ) -> SupportsAnnotations:
+        ...
+
+    def _deannotate(
+        self,
+        values: Optional[Sequence[str]] = None,
+        clone: bool = False,
+    ) -> SupportsAnnotations:
+        raise NotImplementedError()
 
     @util.memoized_property
-    def _annotations_cache_key(self):
+    def _annotations_cache_key(self) -> Tuple[Any, ...]:
         anon_map_ = anon_map()
         return (
             "_annotations",
@@ -47,14 +102,22 @@ class SupportsAnnotations:
         )
 
 
+SelfSupportsCloneAnnotations = TypeVar(
+    "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations"
+)
+
+
 class SupportsCloneAnnotations(SupportsAnnotations):
-    __slots__ = ()
+    if not typing.TYPE_CHECKING:
+        __slots__ = ()
 
-    _clone_annotations_traverse_internals = [
+    _clone_annotations_traverse_internals: _TraverseInternalsType = [
         ("_annotations", InternalTraversal.dp_annotations_key)
     ]
 
-    def _annotate(self, values):
+    def _annotate(
+        self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+    ) -> SelfSupportsCloneAnnotations:
         """return a copy of this ClauseElement with annotations
         updated by the given dictionary.
 
@@ -65,7 +128,9 @@ class SupportsCloneAnnotations(SupportsAnnotations):
         new.__dict__.pop("_generate_cache_key", None)
         return new
 
-    def _with_annotations(self, values):
+    def _with_annotations(
+        self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+    ) -> SelfSupportsCloneAnnotations:
         """return a copy of this ClauseElement with annotations
         replaced by the given dictionary.
 
@@ -76,7 +141,27 @@ class SupportsCloneAnnotations(SupportsAnnotations):
         new.__dict__.pop("_generate_cache_key", None)
         return new
 
-    def _deannotate(self, values=None, clone=False):
+    @overload
+    def _deannotate(
+        self: SelfSupportsAnnotations,
+        values: Literal[None] = ...,
+        clone: bool = ...,
+    ) -> SelfSupportsAnnotations:
+        ...
+
+    @overload
+    def _deannotate(
+        self,
+        values: Sequence[str] = ...,
+        clone: bool = ...,
+    ) -> SupportsAnnotations:
+        ...
+
+    def _deannotate(
+        self,
+        values: Optional[Sequence[str]] = None,
+        clone: bool = False,
+    ) -> SupportsAnnotations:
         """return a copy of this :class:`_expression.ClauseElement`
         with annotations
         removed.
@@ -96,24 +181,52 @@ class SupportsCloneAnnotations(SupportsAnnotations):
             return self
 
 
+SelfSupportsWrappingAnnotations = TypeVar(
+    "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations"
+)
+
+
 class SupportsWrappingAnnotations(SupportsAnnotations):
     __slots__ = ()
 
-    def _annotate(self, values):
+    _constructor: Callable[..., SupportsWrappingAnnotations]
+    entity_namespace: Mapping[str, Any]
+
+    def _annotate(self, values: _AnnotationDict) -> Annotated:
         """return a copy of this ClauseElement with annotations
         updated by the given dictionary.
 
         """
-        return Annotated(self, values)
+        return Annotated._as_annotated_instance(self, values)
 
-    def _with_annotations(self, values):
+    def _with_annotations(self, values: _AnnotationDict) -> Annotated:
         """return a copy of this ClauseElement with annotations
         replaced by the given dictionary.
 
         """
-        return Annotated(self, values)
-
-    def _deannotate(self, values=None, clone=False):
+        return Annotated._as_annotated_instance(self, values)
+
+    @overload
+    def _deannotate(
+        self: SelfSupportsAnnotations,
+        values: Literal[None] = ...,
+        clone: bool = ...,
+    ) -> SelfSupportsAnnotations:
+        ...
+
+    @overload
+    def _deannotate(
+        self,
+        values: Sequence[str] = ...,
+        clone: bool = ...,
+    ) -> SupportsAnnotations:
+        ...
+
+    def _deannotate(
+        self,
+        values: Optional[Sequence[str]] = None,
+        clone: bool = False,
+    ) -> SupportsAnnotations:
         """return a copy of this :class:`_expression.ClauseElement`
         with annotations
         removed.
@@ -129,8 +242,11 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
             return self
 
 
-class Annotated:
-    """clones a SupportsAnnotated and applies an 'annotations' dictionary.
+SelfAnnotated = TypeVar("SelfAnnotated", bound="Annotated")
+
+
+class Annotated(SupportsAnnotations):
+    """clones a SupportsAnnotations and applies an 'annotations' dictionary.
 
     Unlike regular clones, this clone also mimics __hash__() and
     __cmp__() of the original element so that it takes its place
@@ -151,21 +267,26 @@ class Annotated:
 
     _is_column_operators = False
 
-    def __new__(cls, *args):
-        if not args:
-            # clone constructor
-            return object.__new__(cls)
-        else:
-            element, values = args
-            # pull appropriate subclass from registry of annotated
-            # classes
-            try:
-                cls = annotated_classes[element.__class__]
-            except KeyError:
-                cls = _new_annotation_type(element.__class__, cls)
-            return object.__new__(cls)
-
-    def __init__(self, element, values):
+    @classmethod
+    def _as_annotated_instance(
+        cls, element: SupportsWrappingAnnotations, values: _AnnotationDict
+    ) -> Annotated:
+        try:
+            cls = annotated_classes[element.__class__]
+        except KeyError:
+            cls = _new_annotation_type(element.__class__, cls)
+        return cls(element, values)
+
+    _annotations: util.immutabledict[str, Any]
+    __element: SupportsWrappingAnnotations
+    _hash: int
+
+    def __new__(cls: Type[SelfAnnotated], *args: Any) -> SelfAnnotated:
+        return object.__new__(cls)
+
+    def __init__(
+        self, element: SupportsWrappingAnnotations, values: _AnnotationDict
+    ):
         self.__dict__ = element.__dict__.copy()
         self.__dict__.pop("_annotations_cache_key", None)
         self.__dict__.pop("_generate_cache_key", None)
@@ -173,11 +294,15 @@ class Annotated:
         self._annotations = util.immutabledict(values)
         self._hash = hash(element)
 
-    def _annotate(self, values):
+    def _annotate(
+        self: SelfAnnotated, values: _AnnotationDict
+    ) -> SelfAnnotated:
         _values = self._annotations.union(values)
         return self._with_annotations(_values)
 
-    def _with_annotations(self, values):
+    def _with_annotations(
+        self: SelfAnnotated, values: util.immutabledict[str, Any]
+    ) -> SelfAnnotated:
         clone = self.__class__.__new__(self.__class__)
         clone.__dict__ = self.__dict__.copy()
         clone.__dict__.pop("_annotations_cache_key", None)
@@ -185,7 +310,27 @@ class Annotated:
         clone._annotations = values
         return clone
 
-    def _deannotate(self, values=None, clone=True):
+    @overload
+    def _deannotate(
+        self: SelfAnnotated,
+        values: Literal[None] = ...,
+        clone: bool = ...,
+    ) -> SelfAnnotated:
+        ...
+
+    @overload
+    def _deannotate(
+        self,
+        values: Sequence[str] = ...,
+        clone: bool = ...,
+    ) -> Annotated:
+        ...
+
+    def _deannotate(
+        self,
+        values: Optional[Sequence[str]] = None,
+        clone: bool = True,
+    ) -> SupportsAnnotations:
         if values is None:
             return self.__element
         else:
@@ -199,14 +344,18 @@ class Annotated:
                 )
             )
 
-    def _compiler_dispatch(self, visitor, **kw):
-        return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
+    if not typing.TYPE_CHECKING:
+        # manually proxy some methods that need extra attention
+        def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any:
+            return self.__element.__class__._compiler_dispatch(
+                self, visitor, **kw
+            )
 
-    @property
-    def _constructor(self):
-        return self.__element._constructor
+        @property
+        def _constructor(self):
+            return self.__element._constructor
 
-    def _clone(self, **kw):
+    def _clone(self: SelfAnnotated, **kw: Any) -> SelfAnnotated:
         clone = self.__element._clone(**kw)
         if clone is self.__element:
             # detect immutable, don't change anything
@@ -217,22 +366,25 @@ class Annotated:
             clone.__dict__.update(self.__dict__)
             return self.__class__(clone, self._annotations)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]:
         return self.__class__, (self.__element, self._annotations)
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return self._hash
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         if self._is_column_operators:
             return self.__element.__class__.__eq__(self, other)
         else:
             return hash(other) == hash(self)
 
     @property
-    def entity_namespace(self):
+    def entity_namespace(self) -> Mapping[str, Any]:
         if "entity_namespace" in self._annotations:
-            return self._annotations["entity_namespace"].entity_namespace
+            return cast(
+                SupportsWrappingAnnotations,
+                self._annotations["entity_namespace"],
+            ).entity_namespace
         else:
             return self.__element.entity_namespace
 
@@ -242,12 +394,19 @@ class Annotated:
 # so that the resulting objects are pickleable; additionally, other
 # decisions can be made up front about the type of object being annotated
 # just once per class rather than per-instance.
-annotated_classes = {}
+annotated_classes: Dict[
+    Type[SupportsWrappingAnnotations], Type[Annotated]
+] = {}
+
+_SA = TypeVar("_SA", bound="SupportsAnnotations")
 
 
 def _deep_annotate(
-    element, annotations, exclude=None, detect_subquery_cols=False
-):
+    element: _SA,
+    annotations: _AnnotationDict,
+    exclude: Optional[Sequence[SupportsAnnotations]] = None,
+    detect_subquery_cols: bool = False,
+) -> _SA:
     """Deep copy the given ClauseElement, annotating each element
     with the given annotations dictionary.
 
@@ -258,9 +417,9 @@ def _deep_annotate(
     # annotated objects hack the __hash__() method so if we want to
     # uniquely process them we have to use id()
 
-    cloned_ids = {}
+    cloned_ids: Dict[int, SupportsAnnotations] = {}
 
-    def clone(elem, **kw):
+    def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
         kw["detect_subquery_cols"] = detect_subquery_cols
         id_ = id(elem)
 
@@ -285,17 +444,20 @@ def _deep_annotate(
         return newelem
 
     if element is not None:
-        element = clone(element)
-    clone = None  # remove gc cycles
+        element = cast(_SA, clone(element))
+    clone = None  # type: ignore  # remove gc cycles
     return element
 
 
-def _deep_deannotate(element, values=None):
+def _deep_deannotate(
+    element: _SA, values: Optional[Sequence[str]] = None
+) -> _SA:
     """Deep copy the given element, removing annotations."""
 
-    cloned = {}
+    cloned: Dict[Any, SupportsAnnotations] = {}
 
-    def clone(elem, **kw):
+    def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
+        key: Any
         if values:
             key = id(elem)
         else:
@@ -310,12 +472,14 @@ def _deep_deannotate(element, values=None):
             return cloned[key]
 
     if element is not None:
-        element = clone(element)
-    clone = None  # remove gc cycles
+        element = cast(_SA, clone(element))
+    clone = None  # type: ignore  # remove gc cycles
     return element
 
 
-def _shallow_annotate(element, annotations):
+def _shallow_annotate(
+    element: SupportsAnnotations, annotations: _AnnotationDict
+) -> SupportsAnnotations:
     """Annotate the given ClauseElement and copy its internals so that
     internal objects refer to the new annotated object.
 
@@ -328,7 +492,13 @@ def _shallow_annotate(element, annotations):
     return element
 
 
-def _new_annotation_type(cls, base_cls):
+def _new_annotation_type(
+    cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated]
+) -> Type[Annotated]:
+    """Generates a new class that subclasses Annotated and proxies a given
+    element type.
+
+    """
     if issubclass(cls, Annotated):
         return cls
     elif cls in annotated_classes:
@@ -342,8 +512,9 @@ def _new_annotation_type(cls, base_cls):
             base_cls = annotated_classes[super_]
             break
 
-    annotated_classes[cls] = anno_cls = type(
-        "Annotated%s" % cls.__name__, (base_cls, cls), {}
+    annotated_classes[cls] = anno_cls = cast(
+        Type[Annotated],
+        type("Annotated%s" % cls.__name__, (base_cls, cls), {}),
     )
     globals()["Annotated%s" % cls.__name__] = anno_cls
 
@@ -359,13 +530,15 @@ def _new_annotation_type(cls, base_cls):
     # some classes include this even if they have traverse_internals
     # e.g. BindParameter, add it if present.
     if cls.__dict__.get("inherit_cache", False):
-        anno_cls.inherit_cache = True
+        anno_cls.inherit_cache = True  # type: ignore
 
     anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
 
     return anno_cls
 
 
-def _prepare_annotations(target_hierarchy, base_cls):
+def _prepare_annotations(
+    target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated]
+) -> None:
     for cls in util.walk_subclasses(target_hierarchy):
         _new_annotation_type(cls, base_cls)
index a94590da1c0c357ac773af6754fb3b2e4479eae6..a408a010a08ff00a4b5b67c13363759626932276 100644 (file)
@@ -19,8 +19,10 @@ from itertools import zip_longest
 import operator
 import re
 import typing
+from typing import MutableMapping
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import TypeVar
 
 from . import roles
@@ -36,14 +38,9 @@ from .. import util
 from ..util import HasMemoized as HasMemoized
 from ..util import hybridmethod
 from ..util import typing as compat_typing
-from ..util._has_cy import HAS_CYEXTENSION
-
-if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
-    from ._py_util import prefix_anon_map  # noqa
-else:
-    from sqlalchemy.cyextension.util import prefix_anon_map  # noqa
 
 if typing.TYPE_CHECKING:
+    from .elements import ColumnElement
     from ..engine import Connection
     from ..engine import Result
     from ..engine.interfaces import _CoreMultiExecuteParams
@@ -63,6 +60,8 @@ NO_ARG = util.symbol("NO_ARG")
 # symbols, mypy reports: "error: _Fn? not callable"
 _Fn = typing.TypeVar("_Fn", bound=typing.Callable)
 
+_AmbiguousTableNameMap = MutableMapping[str, str]
+
 
 class Immutable:
     """mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -87,6 +86,10 @@ class SingletonConstant(Immutable):
 
     _is_singleton_constant = True
 
+    _singleton: SingletonConstant
+
+    proxy_set: Set[ColumnElement]
+
     def __new__(cls, *arg, **kw):
         return cls._singleton
 
@@ -519,6 +522,8 @@ class CompileState:
 
     plugins = {}
 
+    _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
+
     @classmethod
     def create_for_statement(cls, statement, compiler, **kw):
         # factory construction.
index ff659b77de31fa566b239bbf5e9d3416fdcd2e2a..fca58f98e53cb5a494ff796bc0bd2d171876ce9f 100644 (file)
@@ -11,21 +11,41 @@ import enum
 from itertools import zip_longest
 import typing
 from typing import Any
-from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
 from typing import NamedTuple
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
 from typing import Union
 
 from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import HasTraversalDispatch
+from .visitors import HasTraverseInternals
 from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
 from .. import util
 from ..inspection import inspect
 from ..util import HasMemoized
 from ..util.typing import Literal
-
+from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
     from .elements import BindParameter
+    from .elements import ClauseElement
+    from .visitors import _TraverseInternalsType
+    from ..engine.base import _CompiledCacheType
+    from ..engine.interfaces import _CoreSingleExecuteParams
+
+
+class _CacheKeyTraversalDispatchType(Protocol):
+    def __call__(
+        s, self: HasCacheKey, visitor: _CacheKeyTraversal
+    ) -> CacheKey:
+        ...
 
 
 class CacheConst(enum.Enum):
@@ -70,7 +90,9 @@ class HasCacheKey:
 
     __slots__ = ()
 
-    _cache_key_traversal = NO_CACHE
+    _cache_key_traversal: Union[
+        _TraverseInternalsType, Literal[CacheConst.NO_CACHE]
+    ] = NO_CACHE
 
     _is_has_cache_key = True
 
@@ -83,7 +105,7 @@ class HasCacheKey:
 
     """
 
-    inherit_cache = None
+    inherit_cache: Optional[bool] = None
     """Indicate if this :class:`.HasCacheKey` instance should make use of the
     cache key generation scheme used by its immediate superclass.
 
@@ -106,8 +128,12 @@ class HasCacheKey:
 
     __slots__ = ()
 
+    _generated_cache_key_traversal: Any
+
     @classmethod
-    def _generate_cache_attrs(cls):
+    def _generate_cache_attrs(
+        cls,
+    ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]:
         """generate cache key dispatcher for a new class.
 
         This sets the _generated_cache_key_traversal attribute once called
@@ -121,8 +147,11 @@ class HasCacheKey:
             _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
             if _cache_key_traversal is None:
                 try:
-                    # this would be HasTraverseInternals
-                    _cache_key_traversal = cls._traverse_internals
+                    # check for _traverse_internals, which is part of
+                    # HasTraverseInternals
+                    _cache_key_traversal = cast(
+                        "Type[HasTraverseInternals]", cls
+                    )._traverse_internals
                 except AttributeError:
                     cls._generated_cache_key_traversal = NO_CACHE
                     return NO_CACHE
@@ -138,7 +167,9 @@ class HasCacheKey:
             # more complicated, so for the moment this is a little less
             # efficient on startup but simpler.
             return _cache_key_traversal_visitor.generate_dispatch(
-                cls, _cache_key_traversal, "_generated_cache_key_traversal"
+                cls,
+                _cache_key_traversal,
+                "_generated_cache_key_traversal",
             )
         else:
             _cache_key_traversal = cls.__dict__.get(
@@ -170,11 +201,15 @@ class HasCacheKey:
                     return NO_CACHE
 
             return _cache_key_traversal_visitor.generate_dispatch(
-                cls, _cache_key_traversal, "_generated_cache_key_traversal"
+                cls,
+                _cache_key_traversal,
+                "_generated_cache_key_traversal",
             )
 
     @util.preload_module("sqlalchemy.sql.elements")
-    def _gen_cache_key(self, anon_map, bindparams):
+    def _gen_cache_key(
+        self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+    ) -> Optional[Tuple[Any, ...]]:
         """return an optional cache key.
 
         The cache key is a tuple which can contain any series of
@@ -202,15 +237,15 @@ class HasCacheKey:
 
         dispatcher: Union[
             Literal[CacheConst.NO_CACHE],
-            Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"],
+            _CacheKeyTraversalDispatchType,
         ]
 
         try:
             dispatcher = cls.__dict__["_generated_cache_key_traversal"]
         except KeyError:
-            # most of the dispatchers are generated up front
-            # in sqlalchemy/sql/__init__.py ->
-            # traversals.py-> _preconfigure_traversals().
+            # traversals.py -> _preconfigure_traversals()
+            # may be used to run these ahead of time, but
+            # is not enabled right now.
             # this block will generate any remaining dispatchers.
             dispatcher = cls._generate_cache_attrs()
 
@@ -218,7 +253,7 @@ class HasCacheKey:
             anon_map[NO_CACHE] = True
             return None
 
-        result = (id_, cls)
+        result: Tuple[Any, ...] = (id_, cls)
 
         # inline of _cache_key_traversal_visitor.run_generated_dispatch()
 
@@ -268,7 +303,7 @@ class HasCacheKey:
                         # Columns, this should be long lived.   For select()
                         # statements, not so much, but they usually won't have
                         # annotations.
-                        result += self._annotations_cache_key
+                        result += self._annotations_cache_key  # type: ignore
                     elif (
                         meth is InternalTraversal.dp_clauseelement_list
                         or meth is InternalTraversal.dp_clauseelement_tuple
@@ -290,7 +325,7 @@ class HasCacheKey:
                         )
         return result
 
-    def _generate_cache_key(self):
+    def _generate_cache_key(self) -> Optional[CacheKey]:
         """return a cache key.
 
         The cache key is a tuple which can contain any series of
@@ -322,32 +357,40 @@ class HasCacheKey:
 
         """
 
-        bindparams = []
+        bindparams: List[BindParameter[Any]] = []
 
         _anon_map = anon_map()
         key = self._gen_cache_key(_anon_map, bindparams)
         if NO_CACHE in _anon_map:
             return None
         else:
+            assert key is not None
             return CacheKey(key, bindparams)
 
     @classmethod
-    def _generate_cache_key_for_object(cls, obj):
-        bindparams = []
+    def _generate_cache_key_for_object(
+        cls, obj: HasCacheKey
+    ) -> Optional[CacheKey]:
+        bindparams: List[BindParameter[Any]] = []
 
         _anon_map = anon_map()
         key = obj._gen_cache_key(_anon_map, bindparams)
         if NO_CACHE in _anon_map:
             return None
         else:
+            assert key is not None
             return CacheKey(key, bindparams)
 
 
+class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey):
+    pass
+
+
 class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
     __slots__ = ()
 
     @HasMemoized.memoized_instancemethod
-    def _generate_cache_key(self):
+    def _generate_cache_key(self) -> Optional[CacheKey]:
         return HasCacheKey._generate_cache_key(self)
 
 
@@ -362,14 +405,22 @@ class CacheKey(NamedTuple):
     """
 
     key: Tuple[Any, ...]
-    bindparams: Sequence[BindParameter]
+    bindparams: Sequence[BindParameter[Any]]
 
-    def __hash__(self):
+    # can't set __hash__ attribute because it interferes
+    # with namedtuple
+    # can't use "if not TYPE_CHECKING" because mypy rejects it
+    # inside of a NamedTuple
+    def __hash__(self) -> Optional[int]:  # type: ignore
         """CacheKey itself is not hashable - hash the .key portion"""
-
         return None
 
-    def to_offline_string(self, statement_cache, statement, parameters):
+    def to_offline_string(
+        self,
+        statement_cache: _CompiledCacheType,
+        statement: ClauseElement,
+        parameters: _CoreSingleExecuteParams,
+    ) -> str:
         """Generate an "offline string" form of this :class:`.CacheKey`
 
         The "offline string" is basically the string SQL for the
@@ -400,21 +451,21 @@ class CacheKey(NamedTuple):
 
         return repr((sql_str, param_tuple))
 
-    def __eq__(self, other):
-        return self.key == other.key
+    def __eq__(self, other: Any) -> bool:
+        return bool(self.key == other.key)
 
     @classmethod
-    def _diff_tuples(cls, left, right):
+    def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
         ck1 = CacheKey(left, [])
         ck2 = CacheKey(right, [])
         return ck1._diff(ck2)
 
-    def _whats_different(self, other):
+    def _whats_different(self, other: CacheKey) -> Iterator[str]:
 
         k1 = self.key
         k2 = other.key
 
-        stack = []
+        stack: List[int] = []
         pickup_index = 0
         while True:
             s1, s2 = k1, k2
@@ -440,11 +491,11 @@ class CacheKey(NamedTuple):
                 pickup_index = stack.pop(-1)
                 break
 
-    def _diff(self, other):
+    def _diff(self, other: CacheKey) -> str:
         return ", ".join(self._whats_different(other))
 
-    def __str__(self):
-        stack = [self.key]
+    def __str__(self) -> str:
+        stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key]
 
         output = []
         sentinel = object()
@@ -473,15 +524,15 @@ class CacheKey(NamedTuple):
 
         return "CacheKey(key=%s)" % ("\n".join(output),)
 
-    def _generate_param_dict(self):
+    def _generate_param_dict(self) -> Dict[str, Any]:
         """used for testing"""
 
-        from .compiler import prefix_anon_map
-
         _anon_map = prefix_anon_map()
         return {b.key % _anon_map: b.effective_value for b in self.bindparams}
 
-    def _apply_params_to_element(self, original_cache_key, target_element):
+    def _apply_params_to_element(
+        self, original_cache_key: CacheKey, target_element: ClauseElement
+    ) -> ClauseElement:
         translate = {
             k.key: v.value
             for k, v in zip(original_cache_key.bindparams, self.bindparams)
@@ -490,7 +541,7 @@ class CacheKey(NamedTuple):
         return target_element.params(translate)
 
 
-class _CacheKeyTraversal(ExtendedInternalTraversal):
+class _CacheKeyTraversal(HasTraversalDispatch):
     # very common elements are inlined into the main _get_cache_key() method
     # to produce a dramatic savings in Python function call overhead
 
@@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
     visit_propagate_attrs = PROPAGATE_ATTRS
 
     def visit_with_context_options(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return tuple((fn.__code__, c_key) for fn, c_key in obj)
 
-    def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+    def visit_inspectable(
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
 
-    def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+    def visit_string_list(
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return tuple(obj)
 
-    def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+    def visit_multi(
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (
             attrname,
             obj._gen_cache_key(anon_map, bindparams)
@@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
             else obj,
         )
 
-    def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+    def visit_multi_list(
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (
             attrname,
             tuple(
@@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_has_cache_key_tuples(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         if not obj:
             return ()
         return (
@@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_has_cache_key_list(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         if not obj:
             return ()
         return (
@@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_executable_options(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         if not obj:
             return ()
         return (
@@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_inspectable_list(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return self.visit_has_cache_key_list(
             attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
         )
 
     def visit_clauseelement_tuples(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return self.visit_has_cache_key_tuples(
             attrname, obj, parent, anon_map, bindparams
         )
 
     def visit_fromclause_ordered_set(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         if not obj:
             return ()
         return (
@@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_clauseelement_unordered_set(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         if not obj:
             return ()
         cache_keys = [
@@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_named_ddl_element(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (attrname, obj.name)
 
     def visit_prefix_sequence(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         if not obj:
             return ()
 
@@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_setup_join_tuple(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return tuple(
             (
                 target._gen_cache_key(anon_map, bindparams),
@@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_table_hint_list(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         if not obj:
             return ()
 
@@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
             ),
         )
 
-    def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+    def visit_plain_dict(
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
 
     def visit_dialect_options(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (
             attrname,
             tuple(
@@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_string_clauseelement_dict(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (
             attrname,
             tuple(
@@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_string_multi_dict(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (
             attrname,
             tuple(
@@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_fromclause_canonical_column_collection(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         # inlining into the internals of ColumnCollection
         return (
             attrname,
@@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_unknown_structure(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         anon_map[NO_CACHE] = True
         return ()
 
     def visit_dml_ordered_values(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         return (
             attrname,
             tuple(
@@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
             ),
         )
 
-    def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+    def visit_dml_values(
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         # in py37 we can assume two dictionaries created in the same
         # insert ordering will retain that sorting
         return (
@@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
         )
 
     def visit_dml_multi_values(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
         # multivalues are simply not cacheable right now
         anon_map[NO_CACHE] = True
         return ()
index d616417ab3facc7f3bd2b76baba4cf1b2c88dbb6..834bfb75d459da8ae28e359f4d722aa5697c325d 100644 (file)
@@ -13,6 +13,9 @@ import re
 import typing
 from typing import Any
 from typing import Any as TODO_Any
+from typing import Dict
+from typing import List
+from typing import NoReturn
 from typing import Optional
 from typing import Type
 from typing import TypeVar
@@ -42,6 +45,7 @@ if typing.TYPE_CHECKING:
     from . import selectable
     from . import traversals
     from .elements import ClauseElement
+    from .elements import ColumnClause
 
 _SR = TypeVar("_SR", bound=roles.SQLRole)
 _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
@@ -252,7 +256,7 @@ def expect_col_expression_collection(role, expressions):
         if isinstance(resolved, str):
             strname = resolved = expr
         else:
-            cols = []
+            cols: List[ColumnClause[Any]] = []
             visitors.traverse(resolved, {}, {"column": cols.append})
             if cols:
                 column = cols[0]
@@ -266,7 +270,7 @@ class RoleImpl:
     def _literal_coercion(self, element, **kw):
         raise NotImplementedError()
 
-    _post_coercion = None
+    _post_coercion: Any = None
     _resolve_literal_only = False
     _skip_clauseelement_for_target_match = False
 
@@ -276,19 +280,24 @@ class RoleImpl:
         self._use_inspection = issubclass(role_class, roles.UsesInspection)
 
     def _implicit_coercions(
-        self, element, resolved, argname=None, **kw
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
     ) -> Any:
         self._raise_for_expected(element, argname, resolved)
 
     def _raise_for_expected(
         self,
-        element,
-        argname=None,
-        resolved=None,
-        advice=None,
-        code=None,
-        err=None,
-    ):
+        element: Any,
+        argname: Optional[str] = None,
+        resolved: Optional[Any] = None,
+        advice: Optional[str] = None,
+        code: Optional[str] = None,
+        err: Optional[Exception] = None,
+        **kw: Any,
+    ) -> NoReturn:
         if resolved is not None and resolved is not element:
             got = "%r object resolved from %r object" % (resolved, element)
         else:
@@ -324,22 +333,20 @@ class _StringOnly:
     _resolve_literal_only = True
 
 
-class _ReturnsStringKey:
+class _ReturnsStringKey(RoleImpl):
     __slots__ = ()
 
-    def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
-        if isinstance(original_element, str):
-            return original_element
+    def _implicit_coercions(self, element, resolved, argname=None, **kw):
+        if isinstance(element, str):
+            return element
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
     def _literal_coercion(self, element, **kw):
         return element
 
 
-class _ColumnCoercions:
+class _ColumnCoercions(RoleImpl):
     __slots__ = ()
 
     def _warn_for_scalar_subquery_coercion(self):
@@ -368,8 +375,12 @@ class _ColumnCoercions:
 
 
 def _no_text_coercion(
-    element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None
-):
+    element: Any,
+    argname: Optional[str] = None,
+    exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError,
+    extra: Optional[str] = None,
+    err: Optional[Exception] = None,
+) -> NoReturn:
     raise exc_cls(
         "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
         "explicitly declared as text(%(expr)r)"
@@ -381,7 +392,7 @@ def _no_text_coercion(
     ) from err
 
 
-class _NoTextCoercion:
+class _NoTextCoercion(RoleImpl):
     __slots__ = ()
 
     def _literal_coercion(self, element, argname=None, **kw):
@@ -393,7 +404,7 @@ class _NoTextCoercion:
             self._raise_for_expected(element, argname)
 
 
-class _CoerceLiterals:
+class _CoerceLiterals(RoleImpl):
     __slots__ = ()
     _coerce_consts = False
     _coerce_star = False
@@ -440,12 +451,19 @@ class LiteralValueImpl(RoleImpl):
         return element
 
 
-class _SelectIsNotFrom:
+class _SelectIsNotFrom(RoleImpl):
     __slots__ = ()
 
     def _raise_for_expected(
-        self, element, argname=None, resolved=None, advice=None, **kw
-    ):
+        self,
+        element: Any,
+        argname: Optional[str] = None,
+        resolved: Optional[Any] = None,
+        advice: Optional[str] = None,
+        code: Optional[str] = None,
+        err: Optional[Exception] = None,
+        **kw: Any,
+    ) -> NoReturn:
         if (
             not advice
             and isinstance(element, roles.SelectStatementRole)
@@ -460,26 +478,33 @@ class _SelectIsNotFrom:
         else:
             code = None
 
-        return super(_SelectIsNotFrom, self)._raise_for_expected(
+        super()._raise_for_expected(
             element,
             argname=argname,
             resolved=resolved,
             advice=advice,
             code=code,
+            err=err,
             **kw,
         )
+        # never reached
+        assert False
 
 
 class HasCacheKeyImpl(RoleImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
-        if isinstance(original_element, traversals.HasCacheKey):
-            return original_element
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
+        if isinstance(element, HasCacheKey):
+            return element
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
     def _literal_coercion(self, element, **kw):
         return element
@@ -489,12 +514,16 @@ class ExecutableOptionImpl(RoleImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
-        if isinstance(original_element, ExecutableOption):
-            return original_element
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
+        if isinstance(element, ExecutableOption):
+            return element
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
     def _literal_coercion(self, element, **kw):
         return element
@@ -560,8 +589,12 @@ class InElementImpl(RoleImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
         if resolved._is_from_clause:
             if (
                 isinstance(resolved, selectable.Alias)
@@ -573,7 +606,7 @@ class InElementImpl(RoleImpl):
                 self._warn_for_implicit_coercion(resolved)
                 return self._post_coercion(resolved.select(), **kw)
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
     def _warn_for_implicit_coercion(self, elem):
         util.warn(
@@ -586,12 +619,16 @@ class InElementImpl(RoleImpl):
         if isinstance(element, collections_abc.Iterable) and not isinstance(
             element, str
         ):
-            non_literal_expressions = {}
+            non_literal_expressions: Dict[
+                Optional[operators.ColumnOperators[Any]],
+                operators.ColumnOperators[Any],
+            ] = {}
             element = list(element)
             for o in element:
                 if not _is_literal(o):
                     if not isinstance(o, operators.ColumnOperators):
                         self._raise_for_expected(element, **kw)
+
                     else:
                         non_literal_expressions[o] = o
                 elif o is None:
@@ -712,8 +749,12 @@ class GroupByImpl(ByOfImpl, RoleImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
         if isinstance(resolved, roles.StrictFromClauseRole):
             return elements.ClauseList(*resolved.c)
         else:
@@ -748,12 +789,16 @@ class TruncatedLabelImpl(_StringOnly, RoleImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
-        if isinstance(original_element, str):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
+        if isinstance(element, str):
             return resolved
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
     def _literal_coercion(self, element, argname=None, **kw):
         """coerce the given value to :class:`._truncated_label`.
@@ -794,7 +839,13 @@ class DDLReferredColumnImpl(DDLConstraintColumnImpl):
 class LimitOffsetImpl(RoleImpl):
     __slots__ = ()
 
-    def _implicit_coercions(self, element, resolved, argname=None, **kw):
+    def _implicit_coercions(
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
         if resolved is None:
             return None
         else:
@@ -814,18 +865,22 @@ class LabeledColumnExprImpl(ExpressionElementImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
         if isinstance(resolved, roles.ExpressionElementRole):
             return resolved.label(None)
         else:
             new = super(LabeledColumnExprImpl, self)._implicit_coercions(
-                original_element, resolved, argname=argname, **kw
+                element, resolved, argname=argname, **kw
             )
             if isinstance(new, roles.ExpressionElementRole):
                 return new.label(None)
             else:
-                self._raise_for_expected(original_element, argname, resolved)
+                self._raise_for_expected(element, argname, resolved)
 
 
 class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
@@ -899,13 +954,17 @@ class StatementImpl(_CoerceLiterals, RoleImpl):
         return resolved
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
         if resolved._is_lambda_element:
             return resolved
         else:
-            return super(StatementImpl, self)._implicit_coercions(
-                original_element, resolved, argname=argname, **kw
+            return super()._implicit_coercions(
+                element, resolved, argname=argname, **kw
             )
 
 
@@ -913,12 +972,16 @@ class SelectStatementImpl(_NoTextCoercion, RoleImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
         if resolved._is_text_clause:
             return resolved.columns()
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
 
 class HasCTEImpl(ReturnsRowsImpl):
@@ -938,13 +1001,18 @@ class JoinTargetImpl(RoleImpl):
         self._raise_for_expected(element, argname)
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, legacy=False, **kw
-    ):
-        if isinstance(original_element, roles.JoinTargetRole):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        legacy: bool = False,
+        **kw: Any,
+    ) -> Any:
+        if isinstance(element, roles.JoinTargetRole):
             # note that this codepath no longer occurs as of
             # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
             # were set to False.
-            return original_element
+            return element
         elif legacy and resolved._is_select_statement:
             util.warn_deprecated(
                 "Implicit coercion of SELECT and textual SELECT "
@@ -959,7 +1027,7 @@ class JoinTargetImpl(RoleImpl):
             # in _ORMJoin->Join
             return resolved
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
 
 class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
@@ -967,13 +1035,13 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
 
     def _implicit_coercions(
         self,
-        original_element,
-        resolved,
-        argname=None,
-        explicit_subquery=False,
-        allow_select=True,
-        **kw,
-    ):
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        explicit_subquery: bool = False,
+        allow_select: bool = True,
+        **kw: Any,
+    ) -> Any:
         if resolved._is_select_statement:
             if explicit_subquery:
                 return resolved.subquery()
@@ -989,7 +1057,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
         elif resolved._is_text_clause:
             return resolved
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
     def _post_coercion(self, element, deannotate=False, **kw):
         if deannotate:
@@ -1003,12 +1071,13 @@ class StrictFromClauseImpl(FromClauseImpl):
 
     def _implicit_coercions(
         self,
-        original_element,
-        resolved,
-        argname=None,
-        allow_select=False,
-        **kw,
-    ):
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        explicit_subquery: bool = False,
+        allow_select: bool = False,
+        **kw: Any,
+    ) -> Any:
         if resolved._is_select_statement and allow_select:
             util.warn_deprecated(
                 "Implicit coercion of SELECT and textual SELECT constructs "
@@ -1019,7 +1088,7 @@ class StrictFromClauseImpl(FromClauseImpl):
             )
             return resolved._implicit_subquery
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
 
 class AnonymizedFromClauseImpl(StrictFromClauseImpl):
@@ -1045,8 +1114,12 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
     __slots__ = ()
 
     def _implicit_coercions(
-        self, original_element, resolved, argname=None, **kw
-    ):
+        self,
+        element: Any,
+        resolved: Any,
+        argname: Optional[str] = None,
+        **kw: Any,
+    ) -> Any:
         if resolved._is_from_clause:
             if (
                 isinstance(resolved, selectable.Alias)
@@ -1056,7 +1129,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
             else:
                 return resolved.select()
         else:
-            self._raise_for_expected(original_element, argname, resolved)
+            self._raise_for_expected(element, argname, resolved)
 
 
 class CompoundElementImpl(_NoTextCoercion, RoleImpl):
index 423c3d446e3f09c485cbe92edfa38ed373e3ff5a..f28dceefcee16137a4b724e2aef44bddb88e305f 100644 (file)
@@ -35,14 +35,19 @@ from time import perf_counter
 import typing
 from typing import Any
 from typing import Callable
+from typing import cast
 from typing import Dict
+from typing import FrozenSet
+from typing import Iterable
 from typing import List
 from typing import Mapping
 from typing import MutableMapping
 from typing import NamedTuple
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import Tuple
+from typing import Type
 from typing import Union
 
 from . import base
@@ -54,19 +59,42 @@ from . import operators
 from . import schema
 from . import selectable
 from . import sqltypes
+from .base import _from_objects
 from .base import NO_ARG
-from .base import prefix_anon_map
 from .elements import quoted_name
 from .schema import Column
+from .sqltypes import TupleType
 from .type_api import TypeEngine
+from .visitors import prefix_anon_map
 from .. import exc
 from .. import util
 from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
 
 if typing.TYPE_CHECKING:
+    from .annotation import _AnnotationDict
+    from .base import _AmbiguousTableNameMap
+    from .base import CompileState
+    from .cache_key import CacheKey
+    from .elements import BindParameter
+    from .elements import ColumnClause
+    from .elements import Label
+    from .functions import Function
+    from .selectable import Alias
+    from .selectable import AliasedReturnsRows
+    from .selectable import CompoundSelectState
     from .selectable import CTE
     from .selectable import FromClause
+    from .selectable import NamedFromClause
+    from .selectable import ReturnsRows
+    from .selectable import Select
+    from .selectable import SelectState
+    from ..engine.cursor import CursorResultMetaData
     from ..engine.interfaces import _CoreSingleExecuteParams
+    from ..engine.interfaces import _ExecuteOptions
+    from ..engine.interfaces import _MutableCoreSingleExecuteParams
+    from ..engine.interfaces import _SchemaTranslateMapType
     from ..engine.result import _ProcessorType
 
 _FromHintsType = Dict["FromClause", str]
@@ -236,7 +264,7 @@ OPERATORS = {
     operators.nulls_last_op: " NULLS LAST",
 }
 
-FUNCTIONS = {
+FUNCTIONS: Dict[Type[Function], str] = {
     functions.coalesce: "coalesce",
     functions.current_date: "CURRENT_DATE",
     functions.current_time: "CURRENT_TIME",
@@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple):
     name: str
     """column name, may be labeled"""
 
-    objects: List[Any]
-    """list of objects that should be able to locate this column
+    objects: Tuple[Any, ...]
+    """sequence of objects that should be able to locate this column
     in a RowMapping.  This is typically string names and aliases
     as well as Column objects.
 
@@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple):
     """
 
 
+class _ResultMapAppender(Protocol):
+    def __call__(
+        self,
+        keyname: str,
+        name: str,
+        objects: Sequence[Any],
+        type_: TypeEngine[Any],
+    ) -> None:
+        ...
+
+
 # integer indexes into ResultColumnsEntry used by cursor.py.
 # some profiling showed integer access faster than named tuple
 RM_RENDERED_NAME: Literal[0] = 0
@@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2
 RM_TYPE: Literal[3] = 3
 
 
+class _BaseCompilerStackEntry(TypedDict):
+    asfrom_froms: Set[FromClause]
+    correlate_froms: Set[FromClause]
+    selectable: ReturnsRows
+
+
+class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
+    compile_state: CompileState
+    need_result_map_for_nested: bool
+    need_result_map_for_compound: bool
+    select_0: ReturnsRows
+    insert_from_select: Select
+
+
 class ExpandedState(NamedTuple):
     statement: str
     additional_parameters: _CoreSingleExecuteParams
@@ -427,21 +480,23 @@ class Compiled:
     defaults.
     """
 
-    _cached_metadata = None
+    _cached_metadata: Optional[CursorResultMetaData] = None
 
     _result_columns: Optional[List[ResultColumnsEntry]] = None
 
-    schema_translate_map = None
+    schema_translate_map: Optional[_SchemaTranslateMapType] = None
 
-    execution_options = util.EMPTY_DICT
+    execution_options: _ExecuteOptions = util.EMPTY_DICT
     """
     Execution options propagated from the statement.   In some cases,
     sub-elements of the statement can modify these.
     """
 
-    _annotations = util.EMPTY_DICT
+    preparer: IdentifierPreparer
+
+    _annotations: _AnnotationDict = util.EMPTY_DICT
 
-    compile_state = None
+    compile_state: Optional[CompileState] = None
     """Optional :class:`.CompileState` object that maintains additional
     state used by the compiler.
 
@@ -457,9 +512,21 @@ class Compiled:
 
     """
 
-    cache_key = None
+    cache_key: Optional[CacheKey] = None
+    """The :class:`.CacheKey` that was generated ahead of creating this
+    :class:`.Compiled` object.
+
+    This is used for routines that need access to the original
+    :class:`.CacheKey` instance generated when the :class:`.Compiled`
+    instance was first cached, typically in order to reconcile
+    the original list of :class:`.BindParameter` objects with a
+    per-statement list that's generated on each call.
+
+    """
 
     _gen_time: float
+    """Generation time of this :class:`.Compiled`, used for reporting
+    cache stats."""
 
     def __init__(
         self,
@@ -543,7 +610,11 @@ class Compiled:
 
         return self.string or ""
 
-    def construct_params(self, params=None, extracted_parameters=None):
+    def construct_params(
+        self,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+    ) -> Optional[_MutableCoreSingleExecuteParams]:
         """Return the bind params for this compiled object.
 
         :param params: a dict of string/object pairs whose values will
@@ -646,6 +717,17 @@ class SQLCompiler(Compiled):
 
     isplaintext: bool = False
 
+    binds: Dict[str, BindParameter[Any]]
+    """a dictionary of bind parameter keys to BindParameter instances."""
+
+    bind_names: Dict[BindParameter[Any], str]
+    """a dictionary of BindParameter instances to "compiled" names
+    that are actually present in the generated SQL"""
+
+    stack: List[_CompilerStackEntry]
+    """major statements such as SELECT, INSERT, UPDATE, DELETE are
+    tracked in this stack using an entry format."""
+
     result_columns: List[ResultColumnsEntry]
     """relates label names in the final SQL to a tuple of local
     column/label name, ColumnElement object (if any) and
@@ -709,7 +791,7 @@ class SQLCompiler(Compiled):
 
     """
 
-    insert_single_values_expr = None
+    insert_single_values_expr: Optional[str] = None
     """When an INSERT is compiled with a single set of parameters inside
     a VALUES expression, the string is assigned here, where it can be
     used for insert batching schemes to rewrite the VALUES expression.
@@ -718,19 +800,19 @@ class SQLCompiler(Compiled):
 
     """
 
-    literal_execute_params = frozenset()
+    literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
     """bindparameter objects that are rendered as literal values at statement
     execution time.
 
     """
 
-    post_compile_params = frozenset()
+    post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
     """bindparameter objects that are rendered as bound parameter placeholders
     at statement execution time.
 
     """
 
-    escaped_bind_names = util.EMPTY_DICT
+    escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
     """Late escaping of bound parameter names that has to be converted
     to the original name when looking in the parameter dictionary.
 
@@ -744,14 +826,25 @@ class SQLCompiler(Compiled):
     """if True, and this in insert, use cursor.lastrowid to populate
     result.inserted_primary_key. """
 
-    _cache_key_bind_match = None
+    _cache_key_bind_match: Optional[
+        Tuple[
+            Dict[
+                BindParameter[Any],
+                List[BindParameter[Any]],
+            ],
+            Dict[
+                str,
+                BindParameter[Any],
+            ],
+        ]
+    ] = None
     """a mapping that will relate the BindParameter object we compile
     to those that are part of the extracted collection of parameters
     in the cache key, if we were given a cache key.
 
     """
 
-    positiontup: Optional[Sequence[str]] = None
+    positiontup: Optional[List[str]] = None
     """for a compiled construct that uses a positional paramstyle, will be
     a sequence of strings, indicating the names of bound parameters in order.
 
@@ -768,6 +861,19 @@ class SQLCompiler(Compiled):
 
     inline: bool = False
 
+    ctes: Optional[MutableMapping[CTE, str]]
+
+    # Detect same CTE references - Dict[(level, name), cte]
+    # Level is required for supporting nesting
+    ctes_by_level_name: Dict[Tuple[int, str], CTE]
+
+    # To retrieve key/level in ctes_by_level_name -
+    # Dict[cte_reference, (level, cte_name, cte_opts)]
+    level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
+
+    ctes_recursive: bool
+    cte_positional: Dict[CTE, List[str]]
+
     def __init__(
         self,
         dialect,
@@ -804,10 +910,9 @@ class SQLCompiler(Compiled):
         self.cache_key = cache_key
 
         if cache_key:
-            self._cache_key_bind_match = ckbm = {
-                b.key: b for b in cache_key[1]
-            }
-            ckbm.update({b: [b] for b in cache_key[1]})
+            cksm = {b.key: b for b in cache_key[1]}
+            ckbm = {b: [b] for b in cache_key[1]}
+            self._cache_key_bind_match = (ckbm, cksm)
 
         # compile INSERT/UPDATE defaults/sequences to expect executemany
         # style execution, which may mean no pre-execute of defaults,
@@ -911,14 +1016,14 @@ class SQLCompiler(Compiled):
 
     @property
     def prefetch(self):
-        return list(self.insert_prefetch + self.update_prefetch)
+        return list(self.insert_prefetch) + list(self.update_prefetch)
 
     @util.memoized_property
     def _global_attributes(self):
         return {}
 
     @util.memoized_instancemethod
-    def _init_cte_state(self) -> None:
+    def _init_cte_state(self) -> MutableMapping[CTE, str]:
         """Initialize collections related to CTEs only if
         a CTE is located, to save on the overhead of
         these collections otherwise.
@@ -926,21 +1031,22 @@ class SQLCompiler(Compiled):
         """
         # collect CTEs to tack on top of a SELECT
         # To store the query to print - Dict[cte, text_query]
-        self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
+        ctes: MutableMapping[CTE, str] = util.OrderedDict()
+        self.ctes = ctes
 
         # Detect same CTE references - Dict[(level, name), cte]
         # Level is required for supporting nesting
-        self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
+        self.ctes_by_level_name = {}
 
         # To retrieve key/level in ctes_by_level_name -
         # Dict[cte_reference, (level, cte_name, cte_opts)]
-        self.level_name_by_cte: Dict[
-            CTE, Tuple[int, str, selectable._CTEOpts]
-        ] = {}
+        self.level_name_by_cte = {}
 
-        self.ctes_recursive: bool = False
+        self.ctes_recursive = False
         if self.positional:
-            self.cte_positional: Dict[CTE, List[str]] = {}
+            self.cte_positional = {}
+
+        return ctes
 
     @contextlib.contextmanager
     def _nested_result(self):
@@ -985,7 +1091,7 @@ class SQLCompiler(Compiled):
                     if not bindparam.type._is_tuple_type
                     else tuple(
                         elem_type._cached_bind_processor(self.dialect)
-                        for elem_type in bindparam.type.types
+                        for elem_type in cast(TupleType, bindparam.type).types
                     ),
                 )
                 for bindparam in self.bind_names
@@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled):
 
     def construct_params(
         self,
-        params=None,
-        _group_number=None,
-        _check=True,
-        extracted_parameters=None,
-    ):
+        params: Optional[_CoreSingleExecuteParams] = None,
+        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+        _group_number: Optional[int] = None,
+        _check: bool = True,
+    ) -> _MutableCoreSingleExecuteParams:
         """return a dictionary of bind parameter keys and values"""
 
         has_escaped_names = bool(self.escaped_bind_names)
@@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled):
             # way.   The parameters present in self.bind_names may be clones of
             # these original cache key params in the case of DML but the .key
             # will be guaranteed to match.
-            try:
-                orig_extracted = self.cache_key[1]
-            except TypeError as err:
+            if self.cache_key is None:
                 raise exc.CompileError(
                     "This compiled object has no original cache key; "
                     "can't pass extracted_parameters to construct_params"
-                ) from err
+                )
+            else:
+                orig_extracted = self.cache_key[1]
 
-            ckbm = self._cache_key_bind_match
+            ckbm_tuple = self._cache_key_bind_match
+            assert ckbm_tuple is not None
+            ckbm, _ = ckbm_tuple
             resolved_extracted = {
                 bind: extracted
                 for b, extracted in zip(orig_extracted, extracted_parameters)
@@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled):
 
             if bindparam.type._is_tuple_type:
                 inputsizes[bindparam] = [
-                    lookup_type(typ) for typ in bindparam.type.types
+                    lookup_type(typ)
+                    for typ in cast(TupleType, bindparam.type).types
                 ]
             else:
                 inputsizes[bindparam] = lookup_type(bindparam.type)
@@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled):
 
     def _process_parameters_for_postcompile(
         self,
-        parameters: Optional[_CoreSingleExecuteParams] = None,
+        parameters: Optional[_MutableCoreSingleExecuteParams] = None,
         _populate_self: bool = False,
     ) -> ExpandedState:
         """handle special post compile parameters.
@@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled):
             parameters = self.construct_params()
 
         expanded_parameters = {}
+        positiontup: Optional[List[str]]
+
         if self.positional:
             positiontup = []
         else:
             positiontup = None
 
         processors = self._bind_processors
+        single_processors = cast("Mapping[str, _ProcessorType]", processors)
+        tuple_processors = cast(
+            "Mapping[str, Sequence[_ProcessorType]]", processors
+        )
 
-        new_processors = {}
+        new_processors: Dict[str, _ProcessorType] = {}
 
         if self.positional and self._numeric_binds:
             # I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled):
                 "the 'numeric' paramstyle at this time."
             )
 
-        replacement_expressions = {}
-        to_update_sets = {}
+        replacement_expressions: Dict[str, Any] = {}
+        to_update_sets: Dict[str, Any] = {}
 
         # notes:
         # *unescaped* parameter names in:
@@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled):
         # *escaped* parameter names in:
         # construct_params(), replacement_expressions
 
-        for name in (
-            self.positiontup if self.positional else self.bind_names.values()
-        ):
+        if self.positional and self.positiontup is not None:
+            names: Iterable[str] = self.positiontup
+        else:
+            names = self.bind_names.values()
+
+        for name in names:
             escaped_name = (
                 self.escaped_bind_names.get(name, name)
                 if self.escaped_bind_names
@@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled):
             if parameter in self.post_compile_params:
                 if escaped_name in replacement_expressions:
                     to_update = to_update_sets[escaped_name]
+                    values = None
                 else:
                     # we are removing the parameter from parameters
                     # because it is a list value, which is not expected by
@@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled):
                 if not parameter.literal_execute:
                     parameters.update(to_update)
                     if parameter.type._is_tuple_type:
+                        assert values is not None
                         new_processors.update(
                             (
                                 "%s_%s_%s" % (name, i, j),
-                                processors[name][j - 1],
+                                tuple_processors[name][j - 1],
                             )
                             for i, tuple_element in enumerate(values, 1)
-                            for j, value in enumerate(tuple_element, 1)
-                            if name in processors
-                            and processors[name][j - 1] is not None
+                            for j, _ in enumerate(tuple_element, 1)
+                            if name in tuple_processors
+                            and tuple_processors[name][j - 1] is not None
                         )
                     else:
                         new_processors.update(
-                            (key, processors[name])
-                            for key, value in to_update
-                            if name in processors
+                            (key, single_processors[name])
+                            for key, _ in to_update
+                            if name in single_processors
                         )
-                    if self.positional:
-                        positiontup.extend(name for name, value in to_update)
+                    if positiontup is not None:
+                        positiontup.extend(name for name, _ in to_update)
                     expanded_parameters[name] = [
-                        expand_key for expand_key, value in to_update
+                        expand_key for expand_key, _ in to_update
                     ]
-            elif self.positional:
+            elif positiontup is not None:
                 positiontup.append(name)
 
         def process_expanding(m):
@@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled):
             # special use cases.
             self.string = expanded_state.statement
             self._bind_processors.update(expanded_state.processors)
-            self.positiontup = expanded_state.positiontup
+            self.positiontup = list(expanded_state.positiontup or ())
             self.post_compile_params = frozenset()
             for key in expanded_state.parameter_expansion:
                 bind = self.binds.pop(key)
@@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled):
             self._result_columns
         )
 
+    _key_getters_for_crud_column: Tuple[
+        Callable[[Union[str, Column[Any]]], str],
+        Callable[[Column[Any]], str],
+        Callable[[Column[Any]], str],
+    ]
+
     @util.memoized_property
     def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
         getter = self._key_getters_for_crud_column[2]
@@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled):
     @util.memoized_property
     @util.preload_module("sqlalchemy.engine.result")
     def _inserted_primary_key_from_returning_getter(self):
-        result = util.preloaded.engine_result
+        if typing.TYPE_CHECKING:
+            from ..engine import result
+        else:
+            result = util.preloaded.engine_result
 
         param_key_getter = self._within_exec_param_key_getter
         table = self.statement.table
 
-        ret = {col: idx for idx, col in enumerate(self.returning)}
+        returning = self.returning
+        assert returning is not None
+        ret = {col: idx for idx, col in enumerate(returning)}
 
-        getters = [
-            (operator.itemgetter(ret[col]), True)
-            if col in ret
-            else (
-                operator.methodcaller("get", param_key_getter(col), None),
-                False,
-            )
-            for col in table.primary_key
-        ]
+        getters = cast(
+            "List[Tuple[Callable[[Any], Any], bool]]",
+            [
+                (operator.itemgetter(ret[col]), True)
+                if col in ret
+                else (
+                    operator.methodcaller("get", param_key_getter(col), None),
+                    False,
+                )
+                for col in table.primary_key
+            ],
+        )
 
         row_fn = result.result_tuple([col.key for col in table.primary_key])
 
@@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled):
         self, element, within_columns_clause=False, **kwargs
     ):
         if self.stack and self.dialect.supports_simple_order_by_label:
-            compile_state = self.stack[-1]["compile_state"]
+            try:
+                compile_state = cast(
+                    "Union[SelectState, CompoundSelectState]",
+                    self.stack[-1]["compile_state"],
+                )
+            except KeyError as ke:
+                raise exc.CompileError(
+                    "Can't resolve label reference for ORDER BY / "
+                    "GROUP BY / DISTINCT etc."
+                ) from ke
 
             (
                 with_cols,
@@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled):
             # compiling the element outside of the context of a SELECT
             return self.process(element._text_clause)
 
-        compile_state = self.stack[-1]["compile_state"]
+        try:
+            compile_state = cast(
+                "Union[SelectState, CompoundSelectState]",
+                self.stack[-1]["compile_state"],
+            )
+        except KeyError as ke:
+            coercions._no_text_coercion(
+                element.element,
+                extra=(
+                    "Can't resolve label reference for ORDER BY / "
+                    "GROUP BY / DISTINCT etc."
+                ),
+                exc_cls=exc.CompileError,
+                err=ke,
+            )
+
         with_cols, only_froms, only_cols = compile_state._label_resolve_dict
         try:
             if within_columns_clause:
@@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled):
 
     def visit_column(
         self,
-        column,
-        add_to_result_map=None,
-        include_table=True,
-        result_map_targets=(),
-        ambiguous_table_name_map=None,
-        **kwargs,
-    ):
+        column: ColumnClause[Any],
+        add_to_result_map: Optional[_ResultMapAppender] = None,
+        include_table: bool = True,
+        result_map_targets: Tuple[Any, ...] = (),
+        ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+        **kwargs: Any,
+    ) -> str:
         name = orig_name = column.name
         if name is None:
             name = self._fallback_column_name(column)
@@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled):
                 )
             else:
                 schema_prefix = ""
-            tablename = table.name
+
+            tablename = cast("NamedFromClause", table).name
 
             if (
                 not effective_schema
@@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
-        new_entry = {
+        new_entry: _CompilerStackEntry = {
             "correlate_froms": set(),
             "asfrom_froms": set(),
             "selectable": taf,
@@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled):
         compiled_col = self.visit_column(element, **kw)
         return "(%s).%s" % (compiled_fn, compiled_col)
 
-    def visit_function(self, func, add_to_result_map=None, **kwargs):
+    def visit_function(
+        self,
+        func: Function,
+        add_to_result_map: Optional[_ResultMapAppender] = None,
+        **kwargs: Any,
+    ) -> str:
         if add_to_result_map is not None:
             add_to_result_map(func.name, func.name, (), func.type)
 
         disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+
+        text: str
+
         if disp:
             text = disp(func, **kwargs)
         else:
@@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled):
         if compound_stmt._independent_ctes:
             self._dispatch_independent_ctes(compound_stmt, kwargs)
 
-        keyword = self.compound_keywords.get(cs.keyword)
+        keyword = self.compound_keywords[cs.keyword]
 
         text = (" " + keyword + " ").join(
             (
@@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled):
         # a different set of parameter values.   here, we accommodate for
         # parameters that may have been cloned both before and after the cache
         # key was been generated.
-        ckbm = self._cache_key_bind_match
-        if ckbm:
+        ckbm_tuple = self._cache_key_bind_match
+
+        if ckbm_tuple:
+            ckbm, cksm = ckbm_tuple
             for bp in bindparam._cloned_set:
-                if bp.key in ckbm:
-                    cb = ckbm[bp.key]
+                if bp.key in cksm:
+                    cb = cksm[bp.key]
                     ckbm[cb].append(bindparam)
 
         if bindparam.isoutparam:
@@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled):
             if positional_names is not None:
                 positional_names.append(name)
             else:
-                self.positiontup.append(name)
+                self.positiontup.append(name)  # type: ignore[union-attr]
         elif not escaped_from:
 
             if _BIND_TRANSLATE_RE.search(name):
@@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled):
                 name = new_name
 
         if escaped_from:
-            if not self.escaped_bind_names:
-                self.escaped_bind_names = {}
-            self.escaped_bind_names[escaped_from] = name
+            self.escaped_bind_names = self.escaped_bind_names.union(
+                {escaped_from: name}
+            )
         if post_compile:
             return "__[POSTCOMPILE_%s]" % name
 
@@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled):
         cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
         **kwargs: Any,
     ) -> Optional[str]:
-        self._init_cte_state()
+        self_ctes = self._init_cte_state()
+        assert self_ctes is self.ctes
 
         kwargs["visiting_cte"] = cte
 
@@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled):
                 # we've generated a same-named CTE that is
                 # enclosed in us - we take precedence, so
                 # discard the text for the "inner".
-                del self.ctes[existing_cte]
+                del self_ctes[existing_cte]
 
                 existing_cte_reference_cte = existing_cte._get_reference_cte()
 
@@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled):
             if pre_alias_cte not in self.ctes:
                 self.visit_cte(pre_alias_cte, **kwargs)
 
-            if not cte_pre_alias_name and cte not in self.ctes:
+            if not cte_pre_alias_name and cte not in self_ctes:
                 if cte.recursive:
                     self.ctes_recursive = True
                 text = self.preparer.format_alias(cte, cte_name)
@@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled):
                         cte, cte._suffixes, **kwargs
                     )
 
-                self.ctes[cte] = text
+                self_ctes[cte] = text
 
         if asfrom:
             if from_linter:
                 from_linter.froms[cte] = cte_name
 
             if not is_new_cte and embedded_in_current_named_cte:
-                return self.preparer.format_alias(cte, cte_name)
+                return self.preparer.format_alias(cte, cte_name)  # type: ignore[no-any-return]  # noqa: E501
 
             if cte_pre_alias_name:
                 text = self.preparer.format_alias(cte, cte_pre_alias_name)
@@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled):
             else:
                 return self.preparer.format_alias(cte, cte_name)
 
+        return None
+
     def visit_table_valued_alias(self, element, **kw):
         if element._is_lateral:
             return self.visit_lateral(element, **kw)
@@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled):
         self,
         keyname: str,
         name: str,
-        objects: List[Any],
+        objects: Tuple[Any, ...],
         type_: TypeEngine[Any],
     ) -> None:
         if keyname is None or keyname == "*":
@@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled):
     def get_statement_hint_text(self, hint_texts):
         return " ".join(hint_texts)
 
-    _default_stack_entry = util.immutabledict(
-        [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
-    )
+    _default_stack_entry: _CompilerStackEntry
+
+    if not typing.TYPE_CHECKING:
+        _default_stack_entry = util.immutabledict(
+            [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+        )
 
     def _display_froms_for_select(
         self, select_stmt, asfrom, lateral=False, **kw
@@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled):
             )
         return froms
 
-    translate_select_structure = None
+    translate_select_structure: Any = None
     """if not ``None``, should be a callable which accepts ``(select_stmt,
     **kw)`` and returns a select object.   this is used for structural changes
     mostly to accommodate for LIMIT/OFFSET schemes
@@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled):
             )
 
             self._result_columns = [
-                (key, name, tuple(translate.get(o, o) for o in obj), type_)
+                ResultColumnsEntry(
+                    key, name, tuple(translate.get(o, o) for o in obj), type_
+                )
                 for key, name, obj, type_ in self._result_columns
             ]
 
@@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled):
                 implicit_correlate_froms=asfrom_froms,
             )
 
-        new_correlate_froms = set(selectable._from_objects(*froms))
+        new_correlate_froms = set(_from_objects(*froms))
         all_correlate_froms = new_correlate_froms.union(correlate_froms)
 
-        new_entry = {
+        new_entry: _CompilerStackEntry = {
             "asfrom_froms": new_correlate_froms,
             "correlate_froms": all_correlate_froms,
             "selectable": select,
@@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled):
                 text += " \nWHERE " + t
 
         if warn_linting:
+            assert from_linter is not None
             from_linter.warn()
 
         if select._group_by_clauses:
@@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled):
         if not self.ctes:
             return ""
 
+        ctes: MutableMapping[CTE, str]
+
         if nesting_level and nesting_level > 1:
             ctes = util.OrderedDict()
             for cte in list(self.ctes.keys()):
@@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled):
         ctes_recursive = any([cte.recursive for cte in ctes])
 
         if self.positional:
+            assert self.positiontup is not None
             self.positiontup = (
-                sum([self.cte_positional[cte] for cte in ctes], [])
+                list(
+                    itertools.chain.from_iterable(
+                        self.cte_positional[cte] for cte in ctes
+                    )
+                )
                 + self.positiontup
             )
+
         cte_text = self.get_cte_preamble(ctes_recursive) + " "
         cte_text += ", \n".join([txt for txt in ctes.values()])
         cte_text += "\n "
@@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled):
 
         if is_multitable:
             # main table might be a JOIN
-            main_froms = set(selectable._from_objects(update_stmt.table))
+            main_froms = set(_from_objects(update_stmt.table))
             render_extra_froms = [
                 f for f in extra_froms if f not in main_froms
             ]
@@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled):
     def type_compiler(self):
         return self.dialect.type_compiler
 
-    def construct_params(self, params=None, extracted_parameters=None):
+    def construct_params(
+        self,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+    ) -> Optional[_MutableCoreSingleExecuteParams]:
         return None
 
     def visit_ddl(self, ddl, **kwargs):
@@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler):
             return get_col_spec(**kw)
 
 
+class _SchemaForObjectCallable(Protocol):
+    def __call__(self, obj: Any) -> str:
+        ...
+
+
 class IdentifierPreparer:
 
     """Handle quoting and case-folding of identifiers based on options."""
@@ -5209,7 +5404,13 @@ class IdentifierPreparer:
 
     illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
 
-    schema_for_object = operator.attrgetter("schema")
+    initial_quote: str
+
+    final_quote: str
+
+    _strings: MutableMapping[str, str]
+
+    schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
     """Return the .schema attribute for an object.
 
     For the default IdentifierPreparer, the schema for an object is always
@@ -5297,7 +5498,7 @@ class IdentifierPreparer:
 
         return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
 
-    def _escape_identifier(self, value):
+    def _escape_identifier(self, value: str) -> str:
         """Escape an identifier.
 
         Subclasses should override this to provide database-dependent
@@ -5309,7 +5510,7 @@ class IdentifierPreparer:
             value = value.replace("%", "%%")
         return value
 
-    def _unescape_identifier(self, value):
+    def _unescape_identifier(self, value: str) -> str:
         """Canonicalize an escaped identifier.
 
         Subclasses should override this to provide database-dependent
@@ -5336,7 +5537,7 @@ class IdentifierPreparer:
             )
         return element
 
-    def quote_identifier(self, value):
+    def quote_identifier(self, value: str) -> str:
         """Quote an identifier.
 
         Subclasses should override this to provide database-dependent
@@ -5349,7 +5550,7 @@ class IdentifierPreparer:
             + self.final_quote
         )
 
-    def _requires_quotes(self, value):
+    def _requires_quotes(self, value: str) -> bool:
         """Return True if the given identifier requires quoting."""
         lc_value = value.lower()
         return (
@@ -5364,7 +5565,7 @@ class IdentifierPreparer:
         not taking case convention into account."""
         return not self.legal_characters.match(str(value))
 
-    def quote_schema(self, schema, force=None):
+    def quote_schema(self, schema: str, force: Any = None) -> str:
         """Conditionally quote a schema name.
 
 
@@ -5403,7 +5604,7 @@ class IdentifierPreparer:
 
         return self.quote(schema)
 
-    def quote(self, ident, force=None):
+    def quote(self, ident: str, force: Any = None) -> str:
         """Conditionally quote an identifier.
 
         The identifier is quoted if it is a reserved word, contains
@@ -5474,11 +5675,19 @@ class IdentifierPreparer:
             name = self.quote_schema(effective_schema) + "." + name
         return name
 
-    def format_label(self, label, name=None):
+    def format_label(
+        self, label: Label[Any], name: Optional[str] = None
+    ) -> str:
         return self.quote(name or label.name)
 
-    def format_alias(self, alias, name=None):
-        return self.quote(name or alias.name)
+    def format_alias(
+        self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
+    ) -> str:
+        if name is None:
+            assert alias is not None
+            return self.quote(alias.name)
+        else:
+            return self.quote(name)
 
     def format_savepoint(self, savepoint, name=None):
         # Running the savepoint name through quoting is unnecessary
index 5aded307b68fd86b1f054bbbc053483883d1a438..96e90b0ea183e9d4f14d3062c2cca85b2a6eb4b2 100644 (file)
@@ -13,6 +13,10 @@ from __future__ import annotations
 
 import collections.abc as collections_abc
 import typing
+from typing import Any
+from typing import List
+from typing import MutableMapping
+from typing import Optional
 
 from . import coercions
 from . import roles
@@ -40,8 +44,8 @@ from .. import util
 
 class DMLState(CompileState):
     _no_parameters = True
-    _dict_parameters = None
-    _multi_parameters = None
+    _dict_parameters: Optional[MutableMapping[str, Any]] = None
+    _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None
     _ordered_values = None
     _parameter_ordering = None
     _has_multi_parameters = False
index 168da17ccc007fa4df644e146c8ac3a4aa8d28be..08d632afd944e53c1f9cfb0678984cac51891b25 100644 (file)
@@ -18,7 +18,9 @@ import re
 import typing
 from typing import Any
 from typing import Callable
+from typing import Dict
 from typing import Generic
+from typing import List
 from typing import Optional
 from typing import overload
 from typing import Sequence
@@ -47,6 +49,7 @@ from .coercions import _document_text_coercion  # noqa
 from .operators import ColumnOperators
 from .traversals import HasCopyInternals
 from .visitors import cloned_traverse
+from .visitors import ExternallyTraversible
 from .visitors import InternalTraversal
 from .visitors import traverse
 from .visitors import Visitable
@@ -68,6 +71,8 @@ if typing.TYPE_CHECKING:
     from ..engine import Connection
     from ..engine import Dialect
     from ..engine import Engine
+    from ..engine.base import _CompiledCacheType
+    from ..engine.base import _SchemaTranslateMapType
 
 
 _NUMERIC = Union[complex, "Decimal"]
@@ -238,6 +243,7 @@ class ClauseElement(
     SupportsWrappingAnnotations,
     MemoizedHasCacheKey,
     HasCopyInternals,
+    ExternallyTraversible,
     CompilerElement,
 ):
     """Base class for elements of a programmatically constructed SQL
@@ -398,7 +404,9 @@ class ClauseElement(
         """
         return self._replace_params(True, optionaldict, kwargs)
 
-    def params(self, *optionaldict, **kwargs):
+    def params(
+        self, *optionaldict: Dict[str, Any], **kwargs: Any
+    ) -> ClauseElement:
         """Return a copy with :func:`_expression.bindparam` elements
         replaced.
 
@@ -415,7 +423,12 @@ class ClauseElement(
         """
         return self._replace_params(False, optionaldict, kwargs)
 
-    def _replace_params(self, unique, optionaldict, kwargs):
+    def _replace_params(
+        self,
+        unique: bool,
+        optionaldict: Optional[Dict[str, Any]],
+        kwargs: Dict[str, Any],
+    ) -> ClauseElement:
 
         if len(optionaldict) == 1:
             kwargs.update(optionaldict[0])
@@ -487,12 +500,12 @@ class ClauseElement(
 
     def _compile_w_cache(
         self,
-        dialect,
-        compiled_cache=None,
-        column_keys=None,
-        for_executemany=False,
-        schema_translate_map=None,
-        **kw,
+        dialect: Dialect,
+        compiled_cache: Optional[_CompiledCacheType] = None,
+        column_keys: Optional[List[str]] = None,
+        for_executemany: bool = False,
+        schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+        **kw: Any,
     ):
         if compiled_cache is not None and dialect._supports_statement_cache:
             elem_cache_key = self._generate_cache_key()
@@ -1383,7 +1396,7 @@ class ColumnElement(
         """
         return Cast(self, type_)
 
-    def label(self, name):
+    def label(self, name: Optional[str]) -> Label[_T]:
         """Produce a column label, i.e. ``<columnname> AS <name>``.
 
         This is a shortcut to the :func:`_expression.label` function.
@@ -1608,6 +1621,9 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
         ("value", InternalTraversal.dp_plain_obj),
     ]
 
+    key: str
+    type: TypeEngine
+
     _is_crud = False
     _is_bind_parameter = True
     _key_is_anon = False
index eb3d17ee468284f88f3fd846270a21c26fcde8a1..6e5eec12712934021ff847608dbbca8866ab621c 100644 (file)
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Sequence
 from typing import TypeVar
 
 from . import annotation
@@ -839,6 +840,8 @@ class Function(FunctionElement):
 
     identifier: str
 
+    packagenames: Sequence[str]
+
     type: TypeEngine = sqltypes.NULLTYPE
     """A :class:`_types.TypeEngine` object which refers to the SQL return
     type represented by this SQL function.
index 64bd4b951b9a8512444ef9af38594ce1591c46f2..1a7a5f4d4bf581965e65f81dc21da409a0fe8e73 100644 (file)
@@ -7,14 +7,22 @@
 from __future__ import annotations
 
 import typing
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
 
-from sqlalchemy.util.langhelpers import TypingOnly
 from .. import util
-
+from ..util import TypingOnly
+from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
+    from .base import ColumnCollection
     from .elements import ClauseElement
+    from .elements import Label
     from .selectable import FromClause
+    from .selectable import Subquery
 
 
 class SQLRole:
@@ -35,7 +43,7 @@ class SQLRole:
 
 class UsesInspection:
     __slots__ = ()
-    _post_inspect = None
+    _post_inspect: Literal[None] = None
     uses_inspection = True
 
 
@@ -96,7 +104,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
     _role_name = "Column expression or FROM clause"
 
     @property
-    def _select_iterable(self):
+    def _select_iterable(self) -> Sequence[ColumnsClauseRole]:
         raise NotImplementedError()
 
 
@@ -150,6 +158,9 @@ class ExpressionElementRole(SQLRole):
     __slots__ = ()
     _role_name = "SQL expression element"
 
+    def label(self, name: Optional[str]) -> Label[Any]:
+        raise NotImplementedError()
+
 
 class ConstExprRole(ExpressionElementRole):
     __slots__ = ()
@@ -187,7 +198,7 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
     _is_subquery = False
 
     @property
-    def _hide_froms(self):
+    def _hide_froms(self) -> Iterable[FromClause]:
         raise NotImplementedError()
 
 
@@ -195,8 +206,10 @@ class StrictFromClauseRole(FromClauseRole):
     __slots__ = ()
     # does not allow text() or select() objects
 
+    c: ColumnCollection
+
     @property
-    def description(self):
+    def description(self) -> str:
         raise NotImplementedError()
 
 
@@ -204,7 +217,9 @@ class AnonymizedFromClauseRole(StrictFromClauseRole):
     __slots__ = ()
     # calls .alias() as a post processor
 
-    def _anonymous_fromclause(self, name=None, flat=False):
+    def _anonymous_fromclause(
+        self, name: Optional[str] = None, flat: bool = False
+    ) -> FromClause:
         raise NotImplementedError()
 
 
@@ -220,14 +235,14 @@ class StatementRole(SQLRole):
     __slots__ = ()
     _role_name = "Executable SQL or text() construct"
 
-    _propagate_attrs = util.immutabledict()
+    _propagate_attrs: Mapping[str, Any] = util.immutabledict()
 
 
 class SelectStatementRole(StatementRole, ReturnsRowsRole):
     __slots__ = ()
     _role_name = "SELECT construct or equivalent text() construct"
 
-    def subquery(self):
+    def subquery(self) -> Subquery:
         raise NotImplementedError(
             "All SelectStatementRole objects should implement a "
             ".subquery() method."
index c270e15648b846dfc5ca557330b9dc1864e43348..33e300bf61044a07a0dba0b5e3512efd374cffd3 100644 (file)
@@ -51,7 +51,7 @@ from . import visitors
 from .base import DedupeColumnCollection
 from .base import DialectKWArgs
 from .base import Executable
-from .base import SchemaEventTarget
+from .base import SchemaEventTarget as SchemaEventTarget
 from .coercions import _document_text_coercion
 from .elements import ClauseElement
 from .elements import ColumnClause
@@ -2676,6 +2676,10 @@ class DefaultGenerator(Executable, SchemaItem):
     def __init__(self, for_update=False):
         self.for_update = for_update
 
+    @util.memoized_property
+    def is_callable(self):
+        raise NotImplementedError()
+
     def _set_parent(self, column, **kw):
         self.column = column
         if self.for_update:
index e5c2bef6860ffc86d01f5644e304f3e42039c6f5..09befb07868973be33a0016410c703f63ac0e60e 100644 (file)
@@ -53,7 +53,6 @@ from .base import Generative
 from .base import HasCompileState
 from .base import HasMemoized
 from .base import Immutable
-from .base import prefix_anon_map
 from .coercions import _document_text_coercion
 from .elements import _anonymous_label
 from .elements import BindParameter
@@ -69,10 +68,10 @@ from .elements import literal_column
 from .elements import TableValuedColumn
 from .elements import UnaryExpression
 from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
 from .. import exc
 from .. import util
 
-
 and_ = BooleanClauseList.and_
 
 _T = TypeVar("_T", bound=Any)
@@ -855,6 +854,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         return self.alias(name=name)
 
 
+class NamedFromClause(FromClause):
+    named_with_column = True
+
+    name: str
+
+
 class SelectLabelStyle(Enum):
     """Label style constants that may be passed to
     :meth:`_sql.Select.set_label_style`."""
@@ -1317,15 +1322,16 @@ class NoInit:
 #        -> Lateral -> FromClause, but we accept SelectBase
 #           w/ non-deprecated coercion
 #        -> TableSample -> only for FromClause
-class AliasedReturnsRows(NoInit, FromClause):
+class AliasedReturnsRows(NoInit, NamedFromClause):
     """Base class of aliases against tables, subqueries, and other
     selectables."""
 
     _is_from_container = True
-    named_with_column = True
 
     _supports_derived_columns = False
 
+    element: ClauseElement
+
     _traverse_internals = [
         ("element", InternalTraversal.dp_clauseelement),
         ("name", InternalTraversal.dp_anon_name),
@@ -1423,6 +1429,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
 
     inherit_cache = True
 
+    element: FromClause
+
     @classmethod
     def _factory(cls, selectable, name=None, flat=False):
         return coercions.expect(
@@ -1689,6 +1697,8 @@ class CTE(
         + HasSuffixes._has_suffixes_traverse_internals
     )
 
+    element: HasCTE
+
     @classmethod
     def _factory(cls, selectable, name=None, recursive=False):
         r"""Return a new :class:`_expression.CTE`,
@@ -1819,7 +1829,7 @@ class _CTEOpts(NamedTuple):
     nesting: bool
 
 
-class HasCTE(roles.HasCTERole):
+class HasCTE(roles.HasCTERole, ClauseElement):
     """Mixin that declares a class to include CTE support.
 
     .. versionadded:: 1.1
@@ -2247,6 +2257,8 @@ class Subquery(AliasedReturnsRows):
 
     inherit_cache = True
 
+    element: Select
+
     @classmethod
     def _factory(cls, selectable, name=None):
         """Return a :class:`.Subquery` object."""
@@ -2331,7 +2343,7 @@ class FromGrouping(GroupedElement, FromClause):
         self.element = state["element"]
 
 
-class TableClause(roles.DMLTableRole, Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
     """Represents a minimal "table" construct.
 
     This is a lightweight table object that has only a name, a
@@ -2371,8 +2383,6 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause):
         ("name", InternalTraversal.dp_string),
     ]
 
-    named_with_column = True
-
     _is_table = True
 
     implicit_returning = False
@@ -2542,7 +2552,7 @@ class ForUpdateArg(ClauseElement):
 SelfValues = typing.TypeVar("SelfValues", bound="Values")
 
 
-class Values(Generative, FromClause):
+class Values(Generative, NamedFromClause):
     """Represent a ``VALUES`` construct that can be used as a FROM element
     in a statement.
 
@@ -2553,7 +2563,6 @@ class Values(Generative, FromClause):
 
     """
 
-    named_with_column = True
     __visit_name__ = "values"
 
     _data = ()
index 7d21f1262419309f8c94b8927c4ad9c2d407d118..b2b1d9bc29b09c03c5776db807325b0b75cc3455 100644 (file)
@@ -35,13 +35,13 @@ from .elements import _NONE_NAME
 from .elements import quoted_name
 from .elements import Slice
 from .elements import TypeCoerce as type_coerce  # noqa
-from .traversals import InternalTraversal
 from .type_api import Emulated
 from .type_api import NativeForEmulated  # noqa
 from .type_api import to_instance
 from .type_api import TypeDecorator
 from .type_api import TypeEngine
 from .type_api import Variant  # noqa
+from .visitors import InternalTraversal
 from .. import event
 from .. import exc
 from .. import inspection
index 4fa23d370cab657230949bcc428b5208f59bb4a3..cf9487f939b5767ef8258a33791a537d7e9e8228 100644 (file)
@@ -15,7 +15,10 @@ import operator
 import typing
 from typing import Any
 from typing import Callable
+from typing import Deque
 from typing import Dict
+from typing import Set
+from typing import Tuple
 from typing import Type
 from typing import TypeVar
 
@@ -23,9 +26,9 @@ from . import operators
 from .cache_key import HasCacheKey
 from .visitors import _TraverseInternalsType
 from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
+from .visitors import HasTraversalDispatch
 from .visitors import HasTraverseInternals
-from .visitors import InternalTraversal
 from .. import util
 from ..util import langhelpers
 
@@ -35,6 +38,7 @@ COMPARE_SUCCEEDED = True
 
 
 def compare(obj1, obj2, **kw):
+    strategy: TraversalComparatorStrategy
     if kw.get("use_proxies", False):
         strategy = ColIdentityComparatorStrategy()
     else:
@@ -45,16 +49,18 @@ def compare(obj1, obj2, **kw):
 
 def _preconfigure_traversals(target_hierarchy):
     for cls in util.walk_subclasses(target_hierarchy):
-        if hasattr(cls, "_traverse_internals"):
-            cls._generate_cache_attrs()
+        if hasattr(cls, "_generate_cache_attrs") and hasattr(
+            cls, "_traverse_internals"
+        ):
+            cls._generate_cache_attrs()  # type: ignore
             _copy_internals.generate_dispatch(
-                cls,
-                cls._traverse_internals,
+                cls,  # type: ignore
+                cls._traverse_internals,  # type: ignore
                 "_generated_copy_internals_traversal",
             )
             _get_children.generate_dispatch(
-                cls,
-                cls._traverse_internals,
+                cls,  # type: ignore
+                cls._traverse_internals,  # type: ignore
                 "_generated_get_children_traversal",
             )
 
@@ -125,54 +131,58 @@ class HasShallowCopy(HasTraverseInternals):
         meth_text = f"def {method_name}(self, d):\n{code}\n"
         return langhelpers._exec_code_in_env(meth_text, {}, method_name)
 
-    def _shallow_from_dict(self, d: Dict) -> None:
+    def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
         cls = self.__class__
 
+        shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
         try:
             shallow_from_dict = cls.__dict__[
                 "_generated_shallow_from_dict_traversal"
             ]
         except KeyError:
-            shallow_from_dict = (
-                cls._generated_shallow_from_dict_traversal  # type: ignore
-            ) = self._generate_shallow_from_dict(
+            shallow_from_dict = self._generate_shallow_from_dict(
                 cls._traverse_internals,
                 "_generated_shallow_from_dict_traversal",
             )
 
+            cls._generated_shallow_from_dict_traversal = shallow_from_dict  # type: ignore  # noqa E501
+
         shallow_from_dict(self, d)
 
     def _shallow_to_dict(self) -> Dict[str, Any]:
         cls = self.__class__
 
+        shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]
+
         try:
             shallow_to_dict = cls.__dict__[
                 "_generated_shallow_to_dict_traversal"
             ]
         except KeyError:
-            shallow_to_dict = (
-                cls._generated_shallow_to_dict_traversal  # type: ignore
-            ) = self._generate_shallow_to_dict(
+            shallow_to_dict = self._generate_shallow_to_dict(
                 cls._traverse_internals, "_generated_shallow_to_dict_traversal"
             )
 
+            cls._generated_shallow_to_dict_traversal = shallow_to_dict  # type: ignore  # noqa E501
         return shallow_to_dict(self)
 
-    def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy):
+    def _shallow_copy_to(
+        self: SelfHasShallowCopy, other: SelfHasShallowCopy
+    ) -> None:
         cls = self.__class__
 
+        shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]
         try:
             shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
         except KeyError:
-            shallow_copy = (
-                cls._generated_shallow_copy_traversal  # type: ignore
-            ) = self._generate_shallow_copy(
+            shallow_copy = self._generate_shallow_copy(
                 cls._traverse_internals, "_generated_shallow_copy_traversal"
             )
 
+            cls._generated_shallow_copy_traversal = shallow_copy  # type: ignore  # noqa: E501
         shallow_copy(self, other)
 
-    def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy:
+    def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy:
         """Create a shallow copy"""
         c = self.__class__.__new__(self.__class__)
         self._shallow_copy_to(c)
@@ -246,7 +256,7 @@ class HasCopyInternals(HasTraverseInternals):
                     setattr(self, attrname, result)
 
 
-class _CopyInternalsTraversal(InternalTraversal):
+class _CopyInternalsTraversal(HasTraversalDispatch):
     """Generate a _copy_internals internal traversal dispatch for classes
     with a _traverse_internals collection."""
 
@@ -381,7 +391,7 @@ def _flatten_clauseelement(element):
     return element
 
 
-class _GetChildrenTraversal(InternalTraversal):
+class _GetChildrenTraversal(HasTraversalDispatch):
     """Generate a _children_traversal internal traversal dispatch for classes
     with a _traverse_internals collection."""
 
@@ -463,13 +473,13 @@ def _resolve_name_for_compare(element, name, anon_map, **kw):
     return name
 
 
-class TraversalComparatorStrategy(
-    ExtendedInternalTraversal, util.MemoizedSlots
-):
+class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
     __slots__ = "stack", "cache", "anon_map"
 
     def __init__(self):
-        self.stack = deque()
+        self.stack: Deque[
+            Tuple[ExternallyTraversible, ExternallyTraversible]
+        ] = deque()
         self.cache = set()
 
     def _memoized_attr_anon_map(self):
@@ -653,7 +663,7 @@ class TraversalComparatorStrategy(
         if seq1 is None:
             return seq2 is None
 
-        completed = set()
+        completed: Set[object] = set()
         for clause in seq1:
             for other_clause in set(seq2).difference(completed):
                 if self.compare_inner(clause, other_clause, **kw):
index e0248adf0df4515ca11dbf39fdfafaebc683bce1..5114a2431da84abe72d8259fdbfe8733b806a3d0 100644 (file)
@@ -21,9 +21,9 @@ from . import coercions
 from . import operators
 from . import roles
 from . import visitors
-from .annotation import _deep_annotate  # noqa
-from .annotation import _deep_deannotate  # noqa
-from .annotation import _shallow_annotate  # noqa
+from .annotation import _deep_annotate as _deep_annotate
+from .annotation import _deep_deannotate as _deep_deannotate
+from .annotation import _shallow_annotate as _shallow_annotate
 from .base import _expand_cloned
 from .base import _from_objects
 from .base import ColumnSet
index 111ecd32ef37bef87281c952d115ac27eb1facc5..0c41e440efa266123b723c75a3730a8486c38dd2 100644 (file)
@@ -7,43 +7,46 @@
 
 """Visitor/traversal interface and library functions.
 
-SQLAlchemy schema and expression constructs rely on a Python-centric
-version of the classic "visitor" pattern as the primary way in which
-they apply functionality.  The most common use of this pattern
-is statement compilation, where individual expression classes match
-up to rendering methods that produce a string result.   Beyond this,
-the visitor system is also used to inspect expressions for various
-information and patterns, as well as for the purposes of applying
-transformations to expressions.
-
-Examples of how the visit system is used can be seen in the source code
-of for example the ``sqlalchemy.sql.util`` and the ``sqlalchemy.sql.compiler``
-modules.  Some background on clause adaption is also at
-https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
 
 """
 
 from __future__ import annotations
 
 from collections import deque
+from enum import Enum
 import itertools
 import operator
 import typing
 from typing import Any
+from typing import Callable
+from typing import cast
+from typing import ClassVar
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
 from typing import List
+from typing import Mapping
+from typing import Optional
 from typing import Tuple
+from typing import Type
+from typing import TypeVar
+from typing import Union
 
 from .. import exc
 from .. import util
 from ..util import langhelpers
-from ..util import symbol
 from ..util._has_cy import HAS_CYEXTENSION
-from ..util.langhelpers import _symbol
+from ..util.typing import Protocol
+from ..util.typing import Self
 
 if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
-    from ._py_util import cache_anon_map as anon_map  # noqa
+    from ._py_util import prefix_anon_map as prefix_anon_map
+    from ._py_util import cache_anon_map as anon_map
 else:
-    from sqlalchemy.cyextension.util import cache_anon_map as anon_map  # noqa
+    from sqlalchemy.cyextension.util import prefix_anon_map as prefix_anon_map
+    from sqlalchemy.cyextension.util import cache_anon_map as anon_map
+
 
 __all__ = [
     "iterate",
@@ -54,57 +57,23 @@ __all__ = [
     "Visitable",
     "ExternalTraversal",
     "InternalTraversal",
+    "anon_map",
 ]
 
-_TraverseInternalsType = List[Tuple[str, _symbol]]
-
-
-class HasTraverseInternals:
-    """base for classes that have a "traverse internals" element,
-    which defines all kinds of ways of traversing the elements of an object.
-
-    """
-
-    __slots__ = ()
-
-    _traverse_internals: _TraverseInternalsType
-
-    @util.preload_module("sqlalchemy.sql.traversals")
-    def get_children(self, omit_attrs=(), **kw):
-        r"""Return immediate child :class:`.visitors.Visitable`
-        elements of this :class:`.visitors.Visitable`.
-
-        This is used for visit traversal.
-
-        \**kw may contain flags that change the collection that is
-        returned, for example to return a subset of items in order to
-        cut down on larger traversals, or to return child items from a
-        different context (such as schema-level collections instead of
-        clause-level).
-
-        """
-
-        traversals = util.preloaded.sql_traversals
-
-        try:
-            traverse_internals = self._traverse_internals
-        except AttributeError:
-            # user-defined classes may not have a _traverse_internals
-            return []
 
-        dispatch = traversals._get_children.run_generated_dispatch
-        return itertools.chain.from_iterable(
-            meth(obj, **kw)
-            for attrname, obj, meth in dispatch(
-                self, traverse_internals, "_generated_get_children_traversal"
-            )
-            if attrname not in omit_attrs and obj is not None
-        )
+class _CompilerDispatchType(Protocol):
+    def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any:
+        ...
 
 
 class Visitable:
     """Base class for visitable objects.
 
+    :class:`.Visitable` is used to implement the SQL compiler dispatch
+    functions.    Other forms of traversal such as for cache key generation
+    are implemented separately using the :class:`.HasTraverseInternals`
+    interface.
+
     .. versionchanged:: 2.0  The :class:`.Visitable` class was named
        :class:`.Traversible` in the 1.4 series; the name is changed back
        to :class:`.Visitable` in 2.0 which is what it was prior to 1.4.
@@ -117,32 +86,20 @@ class Visitable:
 
     __visit_name__: str
 
+    _original_compiler_dispatch: _CompilerDispatchType
+
+    if typing.TYPE_CHECKING:
+
+        def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str:
+            ...
+
     def __init_subclass__(cls) -> None:
         if "__visit_name__" in cls.__dict__:
             cls._generate_compiler_dispatch()
         super().__init_subclass__()
 
     @classmethod
-    def _generate_compiler_dispatch(cls):
-        """Assign dispatch attributes to various kinds of
-        "visitable" classes.
-
-        Attributes include:
-
-        * The ``_compiler_dispatch`` method, corresponding to
-          ``__visit_name__``. This is called "external traversal" because the
-          caller of each visit() method is responsible for sub-traversing the
-          inner elements of each object. This is appropriate for string
-          compilers and other traversals that need to call upon the inner
-          elements in a specific pattern.
-
-        * internal traversal collections ``_children_traversal``,
-          ``_cache_key_traversal``, ``_copy_internals_traversal``, generated
-          from an optional ``_traverse_internals`` collection of symbols which
-          comes from the :class:`.InternalTraversal` list of symbols. This is
-          called "internal traversal".
-
-        """
+    def _generate_compiler_dispatch(cls) -> None:
         visit_name = cls.__visit_name__
 
         if "_compiler_dispatch" in cls.__dict__:
@@ -161,7 +118,9 @@ class Visitable:
         name = "visit_%s" % visit_name
         getter = operator.attrgetter(name)
 
-        def _compiler_dispatch(self, visitor, **kw):
+        def _compiler_dispatch(
+            self: Visitable, visitor: Any, **kw: Any
+        ) -> str:
             """Look for an attribute named "visit_<visit_name>" on the
             visitor, and call it with the same kw params.
 
@@ -169,105 +128,20 @@ class Visitable:
             try:
                 meth = getter(visitor)
             except AttributeError as err:
-                return visitor.visit_unsupported_compilation(self, err, **kw)
+                return visitor.visit_unsupported_compilation(self, err, **kw)  # type: ignore  # noqa E501
             else:
-                return meth(self, **kw)
+                return meth(self, **kw)  # type: ignore  # noqa E501
 
-        cls._compiler_dispatch = (
+        cls._compiler_dispatch = (  # type: ignore
             cls._original_compiler_dispatch
         ) = _compiler_dispatch
 
-    def __class_getitem__(cls, key):
+    def __class_getitem__(cls, key: str) -> Any:
         # allow generic classes in py3.9+
         return cls
 
 
-class _HasTraversalDispatch:
-    r"""Define infrastructure for the :class:`.InternalTraversal` class.
-
-    .. versionadded:: 2.0
-
-    """
-
-    __slots__ = ()
-
-    def __init_subclass__(cls) -> None:
-        cls._generate_traversal_dispatch()
-        super().__init_subclass__()
-
-    def dispatch(self, visit_symbol):
-        """Given a method from :class:`._HasTraversalDispatch`, return the
-        corresponding method on a subclass.
-
-        """
-        name = self._dispatch_lookup[visit_symbol]
-        return getattr(self, name, None)
-
-    def run_generated_dispatch(
-        self, target, internal_dispatch, generate_dispatcher_name
-    ):
-        try:
-            dispatcher = target.__class__.__dict__[generate_dispatcher_name]
-        except KeyError:
-            # most of the dispatchers are generated up front
-            # in sqlalchemy/sql/__init__.py ->
-            # traversals.py-> _preconfigure_traversals().
-            # this block will generate any remaining dispatchers.
-            dispatcher = self.generate_dispatch(
-                target.__class__, internal_dispatch, generate_dispatcher_name
-            )
-        return dispatcher(target, self)
-
-    def generate_dispatch(
-        self, target_cls, internal_dispatch, generate_dispatcher_name
-    ):
-        dispatcher = self._generate_dispatcher(
-            internal_dispatch, generate_dispatcher_name
-        )
-        # assert isinstance(target_cls, type)
-        setattr(target_cls, generate_dispatcher_name, dispatcher)
-        return dispatcher
-
-    @classmethod
-    def _generate_traversal_dispatch(cls):
-        lookup = {}
-        clsdict = cls.__dict__
-        for key, sym in clsdict.items():
-            if key.startswith("dp_"):
-                visit_key = key.replace("dp_", "visit_")
-                sym_name = sym.name
-                assert sym_name not in lookup, sym_name
-                lookup[sym] = lookup[sym_name] = visit_key
-        if hasattr(cls, "_dispatch_lookup"):
-            lookup.update(cls._dispatch_lookup)
-        cls._dispatch_lookup = lookup
-
-    def _generate_dispatcher(self, internal_dispatch, method_name):
-        names = []
-        for attrname, visit_sym in internal_dispatch:
-            meth = self.dispatch(visit_sym)
-            if meth:
-                visit_name = ExtendedInternalTraversal._dispatch_lookup[
-                    visit_sym
-                ]
-                names.append((attrname, visit_name))
-
-        code = (
-            ("    return [\n")
-            + (
-                ", \n".join(
-                    "        (%r, self.%s, visitor.%s)"
-                    % (attrname, attrname, visit_name)
-                    for attrname, visit_name in names
-                )
-            )
-            + ("\n    ]\n")
-        )
-        meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
-        return langhelpers._exec_code_in_env(meth_text, {}, method_name)
-
-
-class InternalTraversal(_HasTraversalDispatch):
+class InternalTraversal(Enum):
     r"""Defines visitor symbols used for internal traversal.
 
     The :class:`.InternalTraversal` class is used in two ways.  One is that
@@ -306,18 +180,16 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    __slots__ = ()
-
-    dp_has_cache_key = symbol("HC")
+    dp_has_cache_key = "HC"
     """Visit a :class:`.HasCacheKey` object."""
 
-    dp_has_cache_key_list = symbol("HL")
+    dp_has_cache_key_list = "HL"
     """Visit a list of :class:`.HasCacheKey` objects."""
 
-    dp_clauseelement = symbol("CE")
+    dp_clauseelement = "CE"
     """Visit a :class:`_expression.ClauseElement` object."""
 
-    dp_fromclause_canonical_column_collection = symbol("FC")
+    dp_fromclause_canonical_column_collection = "FC"
     """Visit a :class:`_expression.FromClause` object in the context of the
     ``columns`` attribute.
 
@@ -329,30 +201,30 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_clauseelement_tuples = symbol("CTS")
+    dp_clauseelement_tuples = "CTS"
     """Visit a list of tuples which contain :class:`_expression.ClauseElement`
     objects.
 
     """
 
-    dp_clauseelement_list = symbol("CL")
+    dp_clauseelement_list = "CL"
     """Visit a list of :class:`_expression.ClauseElement` objects.
 
     """
 
-    dp_clauseelement_tuple = symbol("CT")
+    dp_clauseelement_tuple = "CT"
     """Visit a tuple of :class:`_expression.ClauseElement` objects.
 
     """
 
-    dp_executable_options = symbol("EO")
+    dp_executable_options = "EO"
 
-    dp_with_context_options = symbol("WC")
+    dp_with_context_options = "WC"
 
-    dp_fromclause_ordered_set = symbol("CO")
+    dp_fromclause_ordered_set = "CO"
     """Visit an ordered set of :class:`_expression.FromClause` objects. """
 
-    dp_string = symbol("S")
+    dp_string = "S"
     """Visit a plain string value.
 
     Examples include table and column names, bound parameter keys, special
@@ -363,10 +235,10 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_string_list = symbol("SL")
+    dp_string_list = "SL"
     """Visit a list of strings."""
 
-    dp_anon_name = symbol("AN")
+    dp_anon_name = "AN"
     """Visit a potentially "anonymized" string value.
 
     The string value is considered to be significant for cache key
@@ -374,7 +246,7 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_boolean = symbol("B")
+    dp_boolean = "B"
     """Visit a boolean value.
 
     The boolean value is considered to be significant for cache key
@@ -382,7 +254,7 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_operator = symbol("O")
+    dp_operator = "O"
     """Visit an operator.
 
     The operator is a function from the :mod:`sqlalchemy.sql.operators`
@@ -393,7 +265,7 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_type = symbol("T")
+    dp_type = "T"
     """Visit a :class:`.TypeEngine` object
 
     The type object is considered to be significant for cache key
@@ -401,7 +273,7 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_plain_dict = symbol("PD")
+    dp_plain_dict = "PD"
     """Visit a dictionary with string keys.
 
     The keys of the dictionary should be strings, the values should
@@ -410,22 +282,22 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_dialect_options = symbol("DO")
+    dp_dialect_options = "DO"
     """Visit a dialect options structure."""
 
-    dp_string_clauseelement_dict = symbol("CD")
+    dp_string_clauseelement_dict = "CD"
     """Visit a dictionary of string keys to :class:`_expression.ClauseElement`
     objects.
 
     """
 
-    dp_string_multi_dict = symbol("MD")
+    dp_string_multi_dict = "MD"
     """Visit a dictionary of string keys to values which may either be
     plain immutable/hashable or :class:`.HasCacheKey` objects.
 
     """
 
-    dp_annotations_key = symbol("AK")
+    dp_annotations_key = "AK"
     """Visit the _annotations_cache_key element.
 
     This is a dictionary of additional information about a ClauseElement
@@ -436,7 +308,7 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_plain_obj = symbol("PO")
+    dp_plain_obj = "PO"
     """Visit a plain python object.
 
     The value should be immutable and hashable, such as an integer.
@@ -444,7 +316,7 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_named_ddl_element = symbol("DD")
+    dp_named_ddl_element = "DD"
     """Visit a simple named DDL element.
 
     The current object used by this method is the :class:`.Sequence`.
@@ -454,57 +326,56 @@ class InternalTraversal(_HasTraversalDispatch):
 
     """
 
-    dp_prefix_sequence = symbol("PS")
+    dp_prefix_sequence = "PS"
     """Visit the sequence represented by :class:`_expression.HasPrefixes`
     or :class:`_expression.HasSuffixes`.
 
     """
 
-    dp_table_hint_list = symbol("TH")
+    dp_table_hint_list = "TH"
     """Visit the ``_hints`` collection of a :class:`_expression.Select`
     object.
 
     """
 
-    dp_setup_join_tuple = symbol("SJ")
+    dp_setup_join_tuple = "SJ"
 
-    dp_memoized_select_entities = symbol("ME")
+    dp_memoized_select_entities = "ME"
 
-    dp_statement_hint_list = symbol("SH")
+    dp_statement_hint_list = "SH"
     """Visit the ``_statement_hints`` collection of a
     :class:`_expression.Select`
     object.
 
     """
 
-    dp_unknown_structure = symbol("UK")
+    dp_unknown_structure = "UK"
     """Visit an unknown structure.
 
     """
 
-    dp_dml_ordered_values = symbol("DML_OV")
+    dp_dml_ordered_values = "DML_OV"
     """Visit the values() ordered tuple list of an
     :class:`_expression.Update` object."""
 
-    dp_dml_values = symbol("DML_V")
+    dp_dml_values = "DML_V"
     """Visit the values() dictionary of a :class:`.ValuesBase`
     (e.g. Insert or Update) object.
 
     """
 
-    dp_dml_multi_values = symbol("DML_MV")
+    dp_dml_multi_values = "DML_MV"
     """Visit the values() multi-valued list of dictionaries of an
     :class:`_expression.Insert` object.
 
     """
 
-    dp_propagate_attrs = symbol("PA")
+    dp_propagate_attrs = "PA"
     """Visit the propagate attrs dict.  This hardcodes to the particular
     elements we care about right now."""
 
-
-class ExtendedInternalTraversal(InternalTraversal):
-    """Defines additional symbols that are useful in caching applications.
+    """Symbols that follow are additional symbols that are useful in
+    caching applications.
 
     Traversals for :class:`_expression.ClauseElement` objects only need to use
     those symbols present in :class:`.InternalTraversal`.  However, for
@@ -513,9 +384,7 @@ class ExtendedInternalTraversal(InternalTraversal):
 
     """
 
-    __slots__ = ()
-
-    dp_ignore = symbol("IG")
+    dp_ignore = "IG"
     """Specify an object that should be ignored entirely.
 
     This currently applies function call argument caching where some
@@ -523,29 +392,235 @@ class ExtendedInternalTraversal(InternalTraversal):
 
     """
 
-    dp_inspectable = symbol("IS")
+    dp_inspectable = "IS"
     """Visit an inspectable object where the return value is a
     :class:`.HasCacheKey` object."""
 
-    dp_multi = symbol("M")
+    dp_multi = "M"
     """Visit an object that may be a :class:`.HasCacheKey` or may be a
     plain hashable object."""
 
-    dp_multi_list = symbol("MT")
+    dp_multi_list = "MT"
     """Visit a tuple containing elements that may be :class:`.HasCacheKey` or
     may be a plain hashable object."""
 
-    dp_has_cache_key_tuples = symbol("HT")
+    dp_has_cache_key_tuples = "HT"
     """Visit a list of tuples which contain :class:`.HasCacheKey`
     objects.
 
     """
 
-    dp_inspectable_list = symbol("IL")
+    dp_inspectable_list = "IL"
     """Visit a list of inspectable objects which upon inspection are
     HasCacheKey objects."""
 
 
+_TraverseInternalsType = List[Tuple[str, InternalTraversal]]
+"""a structure that defines how a HasTraverseInternals should be
+traversed.
+
+This structure consists of a list of (attributename, internaltraversal)
+tuples, where the "attributename" refers to the name of an attribute on an
+instance of the HasTraverseInternals object, and "internaltraversal" refers
+to an :class:`.InternalTraversal` enumeration symbol defining what kind
+of data this attribute stores, which indicates to the traverser how it should
+be handled.
+
+"""
+
+
+class HasTraverseInternals:
+    """base for classes that have a "traverse internals" element,
+    which defines all kinds of ways of traversing the elements of an object.
+
+    Compared to :class:`.Visitable`, which relies upon an external visitor to
+    define how the object is travered (i.e. the :class:`.SQLCompiler`), the
+    :class:`.HasTraverseInternals` interface allows classes to define their own
+    traversal, that is, what attributes are accessed and in what order.
+
+    """
+
+    __slots__ = ()
+
+    _traverse_internals: _TraverseInternalsType
+
+    @util.preload_module("sqlalchemy.sql.traversals")
+    def get_children(
+        self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+    ) -> Iterable[HasTraverseInternals]:
+        r"""Return immediate child :class:`.visitors.HasTraverseInternals`
+        elements of this :class:`.visitors.HasTraverseInternals`.
+
+        This is used for visit traversal.
+
+        \**kw may contain flags that change the collection that is
+        returned, for example to return a subset of items in order to
+        cut down on larger traversals, or to return child items from a
+        different context (such as schema-level collections instead of
+        clause-level).
+
+        """
+
+        traversals = util.preloaded.sql_traversals
+
+        try:
+            traverse_internals = self._traverse_internals
+        except AttributeError:
+            # user-defined classes may not have a _traverse_internals
+            return []
+
+        dispatch = traversals._get_children.run_generated_dispatch
+        return itertools.chain.from_iterable(
+            meth(obj, **kw)
+            for attrname, obj, meth in dispatch(
+                self, traverse_internals, "_generated_get_children_traversal"
+            )
+            if attrname not in omit_attrs and obj is not None
+        )
+
+
+class _InternalTraversalDispatchType(Protocol):
+    def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any:
+        ...
+
+
+class HasTraversalDispatch:
+    r"""Define infrastructure for classes that perform internal traversals
+
+    .. versionadded:: 2.0
+
+    """
+
+    __slots__ = ()
+
+    _dispatch_lookup: ClassVar[Dict[Union[InternalTraversal, str], str]] = {}
+
+    def dispatch(self, visit_symbol: InternalTraversal) -> Callable[..., Any]:
+        """Given a method from :class:`.HasTraversalDispatch`, return the
+        corresponding method on a subclass.
+
+        """
+        name = _dispatch_lookup[visit_symbol]
+        return getattr(self, name, None)  # type: ignore
+
+    def run_generated_dispatch(
+        self,
+        target: object,
+        internal_dispatch: _TraverseInternalsType,
+        generate_dispatcher_name: str,
+    ) -> Any:
+        dispatcher: _InternalTraversalDispatchType
+        try:
+            dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+        except KeyError:
+            # traversals.py -> _preconfigure_traversals()
+            # may be used to run these ahead of time, but
+            # is not enabled right now.
+            # this block will generate any remaining dispatchers.
+            dispatcher = self.generate_dispatch(
+                target.__class__, internal_dispatch, generate_dispatcher_name
+            )
+        return dispatcher(target, self)
+
+    def generate_dispatch(
+        self,
+        target_cls: Type[object],
+        internal_dispatch: _TraverseInternalsType,
+        generate_dispatcher_name: str,
+    ) -> _InternalTraversalDispatchType:
+        dispatcher = self._generate_dispatcher(
+            internal_dispatch, generate_dispatcher_name
+        )
+        # assert isinstance(target_cls, type)
+        setattr(target_cls, generate_dispatcher_name, dispatcher)
+        return dispatcher
+
+    def _generate_dispatcher(
+        self, internal_dispatch: _TraverseInternalsType, method_name: str
+    ) -> _InternalTraversalDispatchType:
+        names = []
+        for attrname, visit_sym in internal_dispatch:
+            meth = self.dispatch(visit_sym)
+            if meth:
+                visit_name = _dispatch_lookup[visit_sym]
+                names.append((attrname, visit_name))
+
+        code = (
+            ("    return [\n")
+            + (
+                ", \n".join(
+                    "        (%r, self.%s, visitor.%s)"
+                    % (attrname, attrname, visit_name)
+                    for attrname, visit_name in names
+                )
+            )
+            + ("\n    ]\n")
+        )
+        meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+        return cast(
+            _InternalTraversalDispatchType,
+            langhelpers._exec_code_in_env(meth_text, {}, method_name),
+        )
+
+
+ExtendedInternalTraversal = InternalTraversal
+
+
+def _generate_traversal_dispatch() -> None:
+    lookup = _dispatch_lookup
+
+    for sym in InternalTraversal:
+        key = sym.name
+        if key.startswith("dp_"):
+            visit_key = key.replace("dp_", "visit_")
+            sym_name = sym.value
+            assert sym_name not in lookup, sym_name
+            lookup[sym] = lookup[sym_name] = visit_key
+
+
+_dispatch_lookup = HasTraversalDispatch._dispatch_lookup
+_generate_traversal_dispatch()
+
+
+class ExternallyTraversible(HasTraverseInternals, Visitable):
+    __slots__ = ()
+
+    _annotations: Collection[Any] = ()
+
+    if typing.TYPE_CHECKING:
+
+        def get_children(
+            self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+        ) -> Iterable[ExternallyTraversible]:
+            ...
+
+    def _clone(self: Self, **kw: Any) -> Self:
+        """clone this element"""
+        raise NotImplementedError()
+
+    def _copy_internals(
+        self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+    ) -> Self:
+        """Reassign internal elements to be clones of themselves.
+
+        Called during a copy-and-traverse operation on newly
+        shallow-copied elements to create a deep copy.
+
+        The given clone function should be used, which may be applying
+        additional transformations to the element (i.e. replacement
+        traversal, cloned traversal, annotations).
+
+        """
+        raise NotImplementedError()
+
+
+_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_TraverseCallableType = Callable[[_ET], None]
+_TraverseTransformCallableType = Callable[
+    [ExternallyTraversible], Optional[ExternallyTraversible]
+]
+
+
 class ExternalTraversal:
     """Base class for visitor objects which can traverse externally using
     the :func:`.visitors.traverse` function.
@@ -555,7 +630,8 @@ class ExternalTraversal:
 
     """
 
-    __traverse_options__ = {}
+    __traverse_options__: Dict[str, Any] = {}
+    _next: Optional[ExternalTraversal]
 
     def traverse_single(self, obj: Visitable, **kw: Any) -> Any:
         for v in self.visitor_iterator:
@@ -563,20 +639,22 @@ class ExternalTraversal:
             if meth:
                 return meth(obj, **kw)
 
-    def iterate(self, obj):
+    def iterate(
+        self, obj: ExternallyTraversible
+    ) -> Iterator[ExternallyTraversible]:
         """Traverse the given expression structure, returning an iterator
         of all elements.
 
         """
         return iterate(obj, self.__traverse_options__)
 
-    def traverse(self, obj):
+    def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
         """Traverse and visit the given expression structure."""
 
         return traverse(obj, self.__traverse_options__, self._visitor_dict)
 
     @util.memoized_property
-    def _visitor_dict(self):
+    def _visitor_dict(self) -> Dict[str, _TraverseCallableType[Any]]:
         visitors = {}
 
         for name in dir(self):
@@ -585,16 +663,16 @@ class ExternalTraversal:
         return visitors
 
     @property
-    def visitor_iterator(self):
+    def visitor_iterator(self) -> Iterator[ExternalTraversal]:
         """Iterate through this visitor and each 'chained' visitor."""
 
-        v = self
+        v: Optional[ExternalTraversal] = self
         while v:
             yield v
             v = getattr(v, "_next", None)
 
-    def chain(self, visitor):
-        """'Chain' an additional ClauseVisitor onto this ClauseVisitor.
+    def chain(self, visitor: ExternalTraversal) -> ExternalTraversal:
+        """'Chain' an additional ExternalTraversal onto this ExternalTraversal
 
         The chained visitor will receive all visit events after this one.
 
@@ -614,14 +692,16 @@ class CloningExternalTraversal(ExternalTraversal):
 
     """
 
-    def copy_and_process(self, list_):
+    def copy_and_process(
+        self, list_: List[ExternallyTraversible]
+    ) -> List[ExternallyTraversible]:
         """Apply cloned traversal to the given list of elements, and return
         the new list.
 
         """
         return [self.traverse(x) for x in list_]
 
-    def traverse(self, obj):
+    def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
         """Traverse and visit the given expression structure."""
 
         return cloned_traverse(
@@ -638,7 +718,9 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
 
     """
 
-    def replace(self, elem):
+    def replace(
+        self, elem: ExternallyTraversible
+    ) -> Optional[ExternallyTraversible]:
         """Receive pre-copied elements during a cloning traversal.
 
         If the method returns a new element, the element is used
@@ -647,15 +729,19 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
         """
         return None
 
-    def traverse(self, obj):
+    def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
         """Traverse and visit the given expression structure."""
 
-        def replace(elem):
+        def replace(
+            elem: ExternallyTraversible,
+        ) -> Optional[ExternallyTraversible]:
             for v in self.visitor_iterator:
-                e = v.replace(elem)
+                e = cast(ReplacingExternalTraversal, v).replace(elem)
                 if e is not None:
                     return e
 
+            return None
+
         return replacement_traverse(obj, self.__traverse_options__, replace)
 
 
@@ -667,7 +753,9 @@ CloningVisitor = CloningExternalTraversal
 ReplacingCloningVisitor = ReplacingExternalTraversal
 
 
-def iterate(obj, opts=util.immutabledict()):
+def iterate(
+    obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT
+) -> Iterator[ExternallyTraversible]:
     r"""Traverse the given expression structure, returning an iterator.
 
     Traversal is configured to be breadth-first.
@@ -702,7 +790,11 @@ def iterate(obj, opts=util.immutabledict()):
             stack.append(t.get_children(**opts))
 
 
-def traverse_using(iterator, obj, visitors):
+def traverse_using(
+    iterator: Iterable[ExternallyTraversible],
+    obj: ExternallyTraversible,
+    visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
     """Visit the given expression structure using the given iterator of
     objects.
 
@@ -734,7 +826,11 @@ def traverse_using(iterator, obj, visitors):
     return obj
 
 
-def traverse(obj, opts, visitors):
+def traverse(
+    obj: ExternallyTraversible,
+    opts: Mapping[str, Any],
+    visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
     """Traverse and visit the given expression structure using the default
     iterator.
 
@@ -767,7 +863,11 @@ def traverse(obj, opts, visitors):
     return traverse_using(iterate(obj, opts), obj, visitors)
 
 
-def cloned_traverse(obj, opts, visitors):
+def cloned_traverse(
+    obj: ExternallyTraversible,
+    opts: Mapping[str, Any],
+    visitors: Mapping[str, _TraverseTransformCallableType],
+) -> ExternallyTraversible:
     """Clone the given expression structure, allowing modifications by
     visitors.
 
@@ -794,20 +894,24 @@ def cloned_traverse(obj, opts, visitors):
 
     """
 
-    cloned = {}
+    cloned: Dict[int, ExternallyTraversible] = {}
     stop_on = set(opts.get("stop_on", []))
 
-    def deferred_copy_internals(obj):
+    def deferred_copy_internals(
+        obj: ExternallyTraversible,
+    ) -> ExternallyTraversible:
         return cloned_traverse(obj, opts, visitors)
 
-    def clone(elem, **kw):
+    def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
         if elem in stop_on:
             return elem
         else:
             if id(elem) not in cloned:
 
                 if "replace" in kw:
-                    newelem = kw["replace"](elem)
+                    newelem = cast(
+                        Optional[ExternallyTraversible], kw["replace"](elem)
+                    )
                     if newelem is not None:
                         cloned[id(elem)] = newelem
                         return newelem
@@ -823,11 +927,15 @@ def cloned_traverse(obj, opts, visitors):
         obj = clone(
             obj, deferred_copy_internals=deferred_copy_internals, **opts
         )
-    clone = None  # remove gc cycles
+    clone = None  # type: ignore[assignment]  # remove gc cycles
     return obj
 
 
-def replacement_traverse(obj, opts, replace):
+def replacement_traverse(
+    obj: ExternallyTraversible,
+    opts: Mapping[str, Any],
+    replace: _TraverseTransformCallableType,
+) -> ExternallyTraversible:
     """Clone the given expression structure, allowing element
     replacement by a given replacement function.
 
@@ -854,10 +962,12 @@ def replacement_traverse(obj, opts, replace):
     cloned = {}
     stop_on = {id(x) for x in opts.get("stop_on", [])}
 
-    def deferred_copy_internals(obj):
+    def deferred_copy_internals(
+        obj: ExternallyTraversible,
+    ) -> ExternallyTraversible:
         return replacement_traverse(obj, opts, replace)
 
-    def clone(elem, **kw):
+    def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
         if (
             id(elem) in stop_on
             or "no_replacement_traverse" in elem._annotations
@@ -888,5 +998,5 @@ def replacement_traverse(obj, opts, replace):
         obj = clone(
             obj, deferred_copy_internals=deferred_copy_internals, **opts
         )
-    clone = None  # remove gc cycles
+    clone = None  # type: ignore[assignment]  # remove gc cycles
     return obj
index 8cb84f73f5193a8f3fbbaed45b49b962481999e0..ae61155ffaa7f86b106ff218beb9272f5ea952fe 100644 (file)
@@ -305,9 +305,11 @@ def _update_argspec_defaults_into_env(spec, env):
         return spec
 
 
-def _exec_code_in_env(code, env, fn_name):
+def _exec_code_in_env(
+    code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
+) -> Callable[..., Any]:
     exec(code, env)
-    return env[fn_name]
+    return env[fn_name]  # type: ignore[no-any-return]
 
 
 _PF = TypeVar("_PF")
@@ -1181,7 +1183,7 @@ class memoized_property(Generic[_T]):
         obj.__dict__.pop(name, None)
 
 
-def memoized_instancemethod(fn):
+def memoized_instancemethod(fn: _F) -> _F:
     """Decorate a method memoize its return value.
 
     Best applied to no-arg methods: memoization is not sensitive to
@@ -1201,7 +1203,7 @@ def memoized_instancemethod(fn):
         self.__dict__[fn.__name__] = memo
         return result
 
-    return update_wrapper(oneshot, fn)
+    return update_wrapper(oneshot, fn)  # type: ignore
 
 
 class HasMemoized:
index 160eabd85fcdf4e3bce0296dec7dd6b26af67e66..c089616e4e65b68f0923d77d4cf6fa8b929e0a30 100644 (file)
@@ -3,10 +3,11 @@ from __future__ import annotations
 import sys
 import typing
 from typing import Any
-from typing import Callable  # noqa
 from typing import cast
 from typing import Dict
 from typing import ForwardRef
+from typing import Iterable
+from typing import Tuple
 from typing import Type
 from typing import TypeVar
 from typing import Union
@@ -16,6 +17,11 @@ from typing_extensions import NotRequired as NotRequired  # noqa
 from . import compat
 
 _T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT")
+_KT_co = TypeVar("_KT_co", covariant=True)
+_KT_contra = TypeVar("_KT_contra", contravariant=True)
+_VT = TypeVar("_VT")
+_VT_co = TypeVar("_VT_co", covariant=True)
 
 Self = TypeVar("Self", bound=Any)
 
@@ -45,6 +51,18 @@ else:
     from typing_extensions import Protocol as Protocol  # noqa F401
     from typing_extensions import TypedDict as TypedDict  # noqa F401
 
+# copied from TypeShed, required in order to implement
+# MutableMapping.update()
+
+
+class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
+    def keys(self) -> Iterable[_KT]:
+        ...
+
+    def __getitem__(self, __k: _KT) -> _VT_co:
+        ...
+
+
 # work around https://github.com/microsoft/pyright/issues/3025
 _LiteralStar = Literal["*"]
 
@@ -120,7 +138,9 @@ def make_union_type(*types):
     return cast(Any, Union).__getitem__(types)
 
 
-def expand_unions(type_, include_union=False, discard_none=False):
+def expand_unions(
+    type_: Type[Any], include_union: bool = False, discard_none: bool = False
+) -> Tuple[Type[Any], ...]:
     """Return a type as as a tuple of individual types, expanding for
     ``Union`` types."""
 
index b90feae49896707a52eac6ea094ec87a643377cd..407af71c3f64d86c85fb127e8d1b01fa209cf17f 100644 (file)
@@ -40,30 +40,12 @@ markers = [
 ]
 
 [tool.pyright]
-include = [
-    "lib/sqlalchemy/engine/base.py",
-    "lib/sqlalchemy/engine/events.py",
-    "lib/sqlalchemy/engine/interfaces.py",
-    "lib/sqlalchemy/engine/_py_row.py",
-    "lib/sqlalchemy/engine/result.py",
-    "lib/sqlalchemy/engine/row.py",
-    "lib/sqlalchemy/engine/util.py",
-    "lib/sqlalchemy/engine/url.py",
-    "lib/sqlalchemy/pool/",
-    "lib/sqlalchemy/event/",
-    "lib/sqlalchemy/events.py",
-    "lib/sqlalchemy/exc.py",
-    "lib/sqlalchemy/log.py",
-    "lib/sqlalchemy/inspection.py",
-    "lib/sqlalchemy/schema.py",
-    "lib/sqlalchemy/types.py",
-    "lib/sqlalchemy/util/",
-]
+
 
 reportPrivateUsage = "none"
 reportUnusedClass = "none"
 reportUnusedFunction = "none"
-
+reportTypedDictNotRequiredAccess = "warning"
 
 [tool.mypy]
 mypy_path = "./lib/"
@@ -99,6 +81,11 @@ ignore_errors = true
 # strict checking
 [[tool.mypy.overrides]]
 module = [
+    "sqlalchemy.sql.annotation",
+    "sqlalchemy.sql.cache_key",
+    "sqlalchemy.sql.roles",
+    "sqlalchemy.sql.visitors",
+    "sqlalchemy.sql._py_util",
     "sqlalchemy.connectors.*",
     "sqlalchemy.engine.*",
     "sqlalchemy.ext.associationproxy",
@@ -117,6 +104,12 @@ strict = true
 [[tool.mypy.overrides]]
 
 module = [
+    "sqlalchemy.sql.coercions",
+    "sqlalchemy.sql.compiler",
+    #"sqlalchemy.sql.crud",
+    #"sqlalchemy.sql.default_comparator",
+    "sqlalchemy.sql.naming",
+    "sqlalchemy.sql.traversals",
     "sqlalchemy.util.*",
     "sqlalchemy.engine.cursor",
     "sqlalchemy.engine.default",