From: Mike Bayer Date: Tue, 8 Mar 2022 22:14:41 +0000 (-0500) Subject: pep-484: sqlalchemy.sql pass one X-Git-Tag: rel_2_0_0b1~427 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=769fa67d842035dd852ab8b6a26ea3f110a51131;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484: sqlalchemy.sql pass one sqlalchemy.sql will require many passes to get all modules even gradually typed. Will have to pick and choose what modules can be strictly typed vs. which can be gradual. in this patch, emphasis is on visitors.py, cache_key.py, annotations.py for strict typing, compiler.py is on gradual typing but has much more structure, in particular where it connects with the outside world. The work within compiler.py also reached back out to engine/cursor.py , default.py quite a bit. References: #6810 Change-Id: I6e8a29f6013fd216e43d45091bc193f8be0368fd --- diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 821c0cb8e3..f776e59753 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -604,18 +604,20 @@ class CursorResultMetaData(ResultMetaData): cls, result_columns: List[ResultColumnsEntry], loose_column_name_matching: bool = False, - ) -> Dict[Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]]: + ) -> Dict[ + Union[str, object], Tuple[str, Tuple[Any, ...], TypeEngine[Any], int] + ]: """when matching cursor.description to a set of names that are present in a Compiled object, as is the case with TextualSelect, get all the names we expect might match those in cursor.description. """ d: Dict[ - Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int] + Union[str, object], + Tuple[str, Tuple[Any, ...], TypeEngine[Any], int], ] = {} for ridx, elem in enumerate(result_columns): key = elem[RM_RENDERED_NAME] - if key in d: # conflicting keyname - just add the column-linked objects # to the existing record. if there is a duplicate column diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 2579f573c5..c9fb1ebf2c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -46,6 +46,7 @@ from .interfaces import ExecutionContext from .. import event from .. import exc from .. import pool +from .. import TupleType from .. import types as sqltypes from .. import util from ..sql import compiler @@ -76,6 +77,8 @@ if typing.TYPE_CHECKING: from ..sql.compiler import Compiled from ..sql.compiler import ResultColumnsEntry from ..sql.compiler import TypeCompiler + from ..sql.dml import DMLState + from ..sql.elements import BindParameter from ..sql.schema import Column from ..sql.type_api import TypeEngine @@ -820,7 +823,7 @@ class DefaultExecutionContext(ExecutionContext): cursor: DBAPICursor compiled_parameters: List[_MutableCoreSingleExecuteParams] parameters: _DBAPIMultiExecuteParams - extracted_parameters: _CoreSingleExecuteParams + extracted_parameters: Optional[Sequence[BindParameter[Any]]] _empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT) @@ -878,7 +881,7 @@ class DefaultExecutionContext(ExecutionContext): compiled: SQLCompiler, parameters: _CoreMultiExecuteParams, invoked_statement: Executable, - extracted_parameters: _CoreSingleExecuteParams, + extracted_parameters: Optional[Sequence[BindParameter[Any]]], cache_hit: CacheStats = CacheStats.CACHING_DISABLED, ) -> ExecutionContext: """Initialize execution context for a Compiled construct.""" @@ -1513,9 +1516,10 @@ class DefaultExecutionContext(ExecutionContext): inputsizes, self.cursor, self.statement, self.parameters, self ) - has_escaped_names = bool(compiled.escaped_bind_names) - if has_escaped_names: + if compiled.escaped_bind_names: escaped_bind_names = compiled.escaped_bind_names + else: + escaped_bind_names = None if dialect.positional: items = [ @@ -1535,17 +1539,18 @@ class DefaultExecutionContext(ExecutionContext): if key in self._expanded_parameters: if bindparam.type._is_tuple_type: - num = len(bindparam.type.types) + tup_type = cast(TupleType, bindparam.type) + num = len(tup_type.types) dbtypes = inputsizes[bindparam] generic_inputsizes.extend( ( ( escaped_bind_names.get(paramname, paramname) - if has_escaped_names + if escaped_bind_names is not None else paramname ), dbtypes[idx % num], - bindparam.type.types[idx % num], + tup_type.types[idx % num], ) for idx, paramname in enumerate( self._expanded_parameters[key] @@ -1557,7 +1562,7 @@ class DefaultExecutionContext(ExecutionContext): ( ( escaped_bind_names.get(paramname, paramname) - if has_escaped_names + if escaped_bind_names is not None else paramname ), dbtype, @@ -1570,7 +1575,7 @@ class DefaultExecutionContext(ExecutionContext): escaped_name = ( escaped_bind_names.get(key, key) - if has_escaped_names + if escaped_bind_names is not None else key ) @@ -1702,7 +1707,9 @@ class DefaultExecutionContext(ExecutionContext): else: assert column is not None assert parameters is not None - compile_state = cast(SQLCompiler, self.compiled).compile_state + compile_state = cast( + "DMLState", cast(SQLCompiler, self.compiled).compile_state + ) assert compile_state is not None if ( isolate_multiinsert_groups @@ -1715,6 +1722,7 @@ class DefaultExecutionContext(ExecutionContext): else: d = {column.key: parameters[column.key]} index = 0 + assert compile_state._dict_parameters is not None keys = compile_state._dict_parameters.keys() d.update( (key, parameters["%s_m%d" % (key, index)]) for key in keys diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index e65546eb77..e13295d6d6 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -54,9 +54,10 @@ if TYPE_CHECKING: from ..sql.compiler import IdentifierPreparer from ..sql.compiler import Linting from ..sql.compiler import SQLCompiler + from ..sql.elements import BindParameter from ..sql.elements import ClauseElement from ..sql.schema import Column - from ..sql.schema import ColumnDefault + from ..sql.schema import DefaultGenerator from ..sql.schema import Sequence as Sequence_SchemaItem from ..sql.sqltypes import Integer from ..sql.type_api import TypeEngine @@ -813,6 +814,9 @@ class Dialect(EventTarget): """ + _supports_statement_cache: bool + """internal evaluation for supports_statement_cache""" + bind_typing = BindTyping.NONE """define a means of passing typing information to the database and/or driver for bound parameters. @@ -2269,7 +2273,7 @@ class ExecutionContext: compiled: SQLCompiler, parameters: _CoreMultiExecuteParams, invoked_statement: Executable, - extracted_parameters: _CoreSingleExecuteParams, + extracted_parameters: Optional[Sequence[BindParameter[Any]]], cache_hit: CacheStats = CacheStats.CACHING_DISABLED, ) -> ExecutionContext: raise NotImplementedError() @@ -2299,7 +2303,7 @@ class ExecutionContext: def _exec_default( self, column: Optional[Column[Any]], - default: ColumnDefault, + default: DefaultGenerator, type_: Optional[TypeEngine[Any]], ) -> Any: raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 87d3cac1c7..d428b8a9d4 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -17,6 +17,7 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import NoReturn @@ -326,7 +327,7 @@ class SimpleResultMetaData(ResultMetaData): def result_tuple( fields: Sequence[str], extra: Optional[Any] = None -) -> Callable[[_RawRowType], Row]: +) -> Callable[[Iterable[Any]], Row]: parent = SimpleResultMetaData(fields, extra) return functools.partial( Row, parent, parent._processors, parent._keymap, Row._default_key_style diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index cc78e0971c..8f4b963eba 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -33,6 +33,7 @@ if typing.TYPE_CHECKING: from .engine.interfaces import _DBAPIAnyExecuteParams from .engine.interfaces import Dialect from .sql.compiler import Compiled + from .sql.compiler import TypeCompiler from .sql.elements import ClauseElement if typing.TYPE_CHECKING: @@ -221,8 +222,8 @@ class UnsupportedCompilationError(CompileError): def __init__( self, - compiler: "Compiled", - element_type: Type["ClauseElement"], + compiler: Union[Compiled, TypeCompiler], + element_type: Type[ClauseElement], message: Optional[str] = None, ): super(UnsupportedCompilationError, self).__init__( diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 709c13c146..e490a4f03d 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -59,6 +59,7 @@ from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self from ..util.typing import SupportsIndex +from ..util.typing import SupportsKeysAndGetItem if typing.TYPE_CHECKING: from ..orm.attributes import InstrumentedAttribute @@ -1660,7 +1661,9 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): return (item[0], self._get(item[1])) @overload - def update(self, __m: Mapping[_KT, _VT], **kwargs: _VT) -> None: + def update( + self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT + ) -> None: ... @overload diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py index 96e8f6b2c7..9f18b882d7 100644 --- a/lib/sqlalchemy/sql/_py_util.py +++ b/lib/sqlalchemy/sql/_py_util.py @@ -7,7 +7,16 @@ from __future__ import annotations +import typing +from typing import Any from typing import Dict +from typing import Tuple +from typing import Union + +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .cache_key import CacheConst class prefix_anon_map(Dict[str, str]): @@ -22,16 +31,18 @@ class prefix_anon_map(Dict[str, str]): """ - def __missing__(self, key): + def __missing__(self, key: str) -> str: (ident, derived) = key.split(" ", 1) anonymous_counter = self.get(derived, 1) - self[derived] = anonymous_counter + 1 + self[derived] = anonymous_counter + 1 # type: ignore value = f"{derived}_{anonymous_counter}" self[key] = value return value -class cache_anon_map(Dict[int, str]): +class cache_anon_map( + Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]] +): """A map that creates new keys for missing key access. Produces an incrementing sequence given a series of unique keys. @@ -45,11 +56,13 @@ class cache_anon_map(Dict[int, str]): _index = 0 - def get_anon(self, object_): + def get_anon(self, object_: Any) -> Tuple[str, bool]: idself = id(object_) if idself in self: - return self[idself], True + s_val = self[idself] + assert s_val is not True + return s_val, True else: # inline of __missing__ self[idself] = id_ = str(self._index) @@ -57,7 +70,7 @@ class cache_anon_map(Dict[int, str]): return id_, False - def __missing__(self, key): + def __missing__(self, key: int) -> str: self[key] = val = str(self._index) self._index += 1 return val diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index b76393ad6b..7afc2de977 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -13,22 +13,77 @@ associations. from __future__ import annotations +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar + from . import operators -from .base import HasCacheKey -from .traversals import anon_map +from .cache_key import HasCacheKey +from .visitors import anon_map +from .visitors import ExternallyTraversible from .visitors import InternalTraversal from .. import util +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .visitors import _TraverseInternalsType + from ..util.typing import Self + +_AnnotationDict = Mapping[str, Any] + +EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT + -EMPTY_ANNOTATIONS = util.immutabledict() +SelfSupportsAnnotations = TypeVar( + "SelfSupportsAnnotations", bound="SupportsAnnotations" +) -class SupportsAnnotations: +class SupportsAnnotations(ExternallyTraversible): __slots__ = () - _annotations = EMPTY_ANNOTATIONS + _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS + proxy_set: Set[SupportsAnnotations] + _is_immutable: bool + + def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations: + raise NotImplementedError() + + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + raise NotImplementedError() @util.memoized_property - def _annotations_cache_key(self): + def _annotations_cache_key(self) -> Tuple[Any, ...]: anon_map_ = anon_map() return ( "_annotations", @@ -47,14 +102,22 @@ class SupportsAnnotations: ) +SelfSupportsCloneAnnotations = TypeVar( + "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations" +) + + class SupportsCloneAnnotations(SupportsAnnotations): - __slots__ = () + if not typing.TYPE_CHECKING: + __slots__ = () - _clone_annotations_traverse_internals = [ + _clone_annotations_traverse_internals: _TraverseInternalsType = [ ("_annotations", InternalTraversal.dp_annotations_key) ] - def _annotate(self, values): + def _annotate( + self: SelfSupportsCloneAnnotations, values: _AnnotationDict + ) -> SelfSupportsCloneAnnotations: """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -65,7 +128,9 @@ class SupportsCloneAnnotations(SupportsAnnotations): new.__dict__.pop("_generate_cache_key", None) return new - def _with_annotations(self, values): + def _with_annotations( + self: SelfSupportsCloneAnnotations, values: _AnnotationDict + ) -> SelfSupportsCloneAnnotations: """return a copy of this ClauseElement with annotations replaced by the given dictionary. @@ -76,7 +141,27 @@ class SupportsCloneAnnotations(SupportsAnnotations): new.__dict__.pop("_generate_cache_key", None) return new - def _deannotate(self, values=None, clone=False): + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: """return a copy of this :class:`_expression.ClauseElement` with annotations removed. @@ -96,24 +181,52 @@ class SupportsCloneAnnotations(SupportsAnnotations): return self +SelfSupportsWrappingAnnotations = TypeVar( + "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations" +) + + class SupportsWrappingAnnotations(SupportsAnnotations): __slots__ = () - def _annotate(self, values): + _constructor: Callable[..., SupportsWrappingAnnotations] + entity_namespace: Mapping[str, Any] + + def _annotate(self, values: _AnnotationDict) -> Annotated: """return a copy of this ClauseElement with annotations updated by the given dictionary. """ - return Annotated(self, values) + return Annotated._as_annotated_instance(self, values) - def _with_annotations(self, values): + def _with_annotations(self, values: _AnnotationDict) -> Annotated: """return a copy of this ClauseElement with annotations replaced by the given dictionary. """ - return Annotated(self, values) - - def _deannotate(self, values=None, clone=False): + return Annotated._as_annotated_instance(self, values) + + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: """return a copy of this :class:`_expression.ClauseElement` with annotations removed. @@ -129,8 +242,11 @@ class SupportsWrappingAnnotations(SupportsAnnotations): return self -class Annotated: - """clones a SupportsAnnotated and applies an 'annotations' dictionary. +SelfAnnotated = TypeVar("SelfAnnotated", bound="Annotated") + + +class Annotated(SupportsAnnotations): + """clones a SupportsAnnotations and applies an 'annotations' dictionary. Unlike regular clones, this clone also mimics __hash__() and __cmp__() of the original element so that it takes its place @@ -151,21 +267,26 @@ class Annotated: _is_column_operators = False - def __new__(cls, *args): - if not args: - # clone constructor - return object.__new__(cls) - else: - element, values = args - # pull appropriate subclass from registry of annotated - # classes - try: - cls = annotated_classes[element.__class__] - except KeyError: - cls = _new_annotation_type(element.__class__, cls) - return object.__new__(cls) - - def __init__(self, element, values): + @classmethod + def _as_annotated_instance( + cls, element: SupportsWrappingAnnotations, values: _AnnotationDict + ) -> Annotated: + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = _new_annotation_type(element.__class__, cls) + return cls(element, values) + + _annotations: util.immutabledict[str, Any] + __element: SupportsWrappingAnnotations + _hash: int + + def __new__(cls: Type[SelfAnnotated], *args: Any) -> SelfAnnotated: + return object.__new__(cls) + + def __init__( + self, element: SupportsWrappingAnnotations, values: _AnnotationDict + ): self.__dict__ = element.__dict__.copy() self.__dict__.pop("_annotations_cache_key", None) self.__dict__.pop("_generate_cache_key", None) @@ -173,11 +294,15 @@ class Annotated: self._annotations = util.immutabledict(values) self._hash = hash(element) - def _annotate(self, values): + def _annotate( + self: SelfAnnotated, values: _AnnotationDict + ) -> SelfAnnotated: _values = self._annotations.union(values) return self._with_annotations(_values) - def _with_annotations(self, values): + def _with_annotations( + self: SelfAnnotated, values: util.immutabledict[str, Any] + ) -> SelfAnnotated: clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() clone.__dict__.pop("_annotations_cache_key", None) @@ -185,7 +310,27 @@ class Annotated: clone._annotations = values return clone - def _deannotate(self, values=None, clone=True): + @overload + def _deannotate( + self: SelfAnnotated, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfAnnotated: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> Annotated: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = True, + ) -> SupportsAnnotations: if values is None: return self.__element else: @@ -199,14 +344,18 @@ class Annotated: ) ) - def _compiler_dispatch(self, visitor, **kw): - return self.__element.__class__._compiler_dispatch(self, visitor, **kw) + if not typing.TYPE_CHECKING: + # manually proxy some methods that need extra attention + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any: + return self.__element.__class__._compiler_dispatch( + self, visitor, **kw + ) - @property - def _constructor(self): - return self.__element._constructor + @property + def _constructor(self): + return self.__element._constructor - def _clone(self, **kw): + def _clone(self: SelfAnnotated, **kw: Any) -> SelfAnnotated: clone = self.__element._clone(**kw) if clone is self.__element: # detect immutable, don't change anything @@ -217,22 +366,25 @@ class Annotated: clone.__dict__.update(self.__dict__) return self.__class__(clone, self._annotations) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]: return self.__class__, (self.__element, self._annotations) - def __hash__(self): + def __hash__(self) -> int: return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self._is_column_operators: return self.__element.__class__.__eq__(self, other) else: return hash(other) == hash(self) @property - def entity_namespace(self): + def entity_namespace(self) -> Mapping[str, Any]: if "entity_namespace" in self._annotations: - return self._annotations["entity_namespace"].entity_namespace + return cast( + SupportsWrappingAnnotations, + self._annotations["entity_namespace"], + ).entity_namespace else: return self.__element.entity_namespace @@ -242,12 +394,19 @@ class Annotated: # so that the resulting objects are pickleable; additionally, other # decisions can be made up front about the type of object being annotated # just once per class rather than per-instance. -annotated_classes = {} +annotated_classes: Dict[ + Type[SupportsWrappingAnnotations], Type[Annotated] +] = {} + +_SA = TypeVar("_SA", bound="SupportsAnnotations") def _deep_annotate( - element, annotations, exclude=None, detect_subquery_cols=False -): + element: _SA, + annotations: _AnnotationDict, + exclude: Optional[Sequence[SupportsAnnotations]] = None, + detect_subquery_cols: bool = False, +) -> _SA: """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. @@ -258,9 +417,9 @@ def _deep_annotate( # annotated objects hack the __hash__() method so if we want to # uniquely process them we have to use id() - cloned_ids = {} + cloned_ids: Dict[int, SupportsAnnotations] = {} - def clone(elem, **kw): + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: kw["detect_subquery_cols"] = detect_subquery_cols id_ = id(elem) @@ -285,17 +444,20 @@ def _deep_annotate( return newelem if element is not None: - element = clone(element) - clone = None # remove gc cycles + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles return element -def _deep_deannotate(element, values=None): +def _deep_deannotate( + element: _SA, values: Optional[Sequence[str]] = None +) -> _SA: """Deep copy the given element, removing annotations.""" - cloned = {} + cloned: Dict[Any, SupportsAnnotations] = {} - def clone(elem, **kw): + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: + key: Any if values: key = id(elem) else: @@ -310,12 +472,14 @@ def _deep_deannotate(element, values=None): return cloned[key] if element is not None: - element = clone(element) - clone = None # remove gc cycles + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles return element -def _shallow_annotate(element, annotations): +def _shallow_annotate( + element: SupportsAnnotations, annotations: _AnnotationDict +) -> SupportsAnnotations: """Annotate the given ClauseElement and copy its internals so that internal objects refer to the new annotated object. @@ -328,7 +492,13 @@ def _shallow_annotate(element, annotations): return element -def _new_annotation_type(cls, base_cls): +def _new_annotation_type( + cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated] +) -> Type[Annotated]: + """Generates a new class that subclasses Annotated and proxies a given + element type. + + """ if issubclass(cls, Annotated): return cls elif cls in annotated_classes: @@ -342,8 +512,9 @@ def _new_annotation_type(cls, base_cls): base_cls = annotated_classes[super_] break - annotated_classes[cls] = anno_cls = type( - "Annotated%s" % cls.__name__, (base_cls, cls), {} + annotated_classes[cls] = anno_cls = cast( + Type[Annotated], + type("Annotated%s" % cls.__name__, (base_cls, cls), {}), ) globals()["Annotated%s" % cls.__name__] = anno_cls @@ -359,13 +530,15 @@ def _new_annotation_type(cls, base_cls): # some classes include this even if they have traverse_internals # e.g. BindParameter, add it if present. if cls.__dict__.get("inherit_cache", False): - anno_cls.inherit_cache = True + anno_cls.inherit_cache = True # type: ignore anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) return anno_cls -def _prepare_annotations(target_hierarchy, base_cls): +def _prepare_annotations( + target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated] +) -> None: for cls in util.walk_subclasses(target_hierarchy): _new_annotation_type(cls, base_cls) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a94590da1c..a408a010a0 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -19,8 +19,10 @@ from itertools import zip_longest import operator import re import typing +from typing import MutableMapping from typing import Optional from typing import Sequence +from typing import Set from typing import TypeVar from . import roles @@ -36,14 +38,9 @@ from .. import util from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing -from ..util._has_cy import HAS_CYEXTENSION - -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_util import prefix_anon_map # noqa -else: - from sqlalchemy.cyextension.util import prefix_anon_map # noqa if typing.TYPE_CHECKING: + from .elements import ColumnElement from ..engine import Connection from ..engine import Result from ..engine.interfaces import _CoreMultiExecuteParams @@ -63,6 +60,8 @@ NO_ARG = util.symbol("NO_ARG") # symbols, mypy reports: "error: _Fn? not callable" _Fn = typing.TypeVar("_Fn", bound=typing.Callable) +_AmbiguousTableNameMap = MutableMapping[str, str] + class Immutable: """mark a ClauseElement as 'immutable' when expressions are cloned.""" @@ -87,6 +86,10 @@ class SingletonConstant(Immutable): _is_singleton_constant = True + _singleton: SingletonConstant + + proxy_set: Set[ColumnElement] + def __new__(cls, *arg, **kw): return cls._singleton @@ -519,6 +522,8 @@ class CompileState: plugins = {} + _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] + @classmethod def create_for_statement(cls, statement, compiler, **kw): # factory construction. diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index ff659b77de..fca58f98e5 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -11,21 +11,41 @@ import enum from itertools import zip_longest import typing from typing import Any -from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from typing import Union from .visitors import anon_map -from .visitors import ExtendedInternalTraversal +from .visitors import HasTraversalDispatch +from .visitors import HasTraverseInternals from .visitors import InternalTraversal +from .visitors import prefix_anon_map from .. import util from ..inspection import inspect from ..util import HasMemoized from ..util.typing import Literal - +from ..util.typing import Protocol if typing.TYPE_CHECKING: from .elements import BindParameter + from .elements import ClauseElement + from .visitors import _TraverseInternalsType + from ..engine.base import _CompiledCacheType + from ..engine.interfaces import _CoreSingleExecuteParams + + +class _CacheKeyTraversalDispatchType(Protocol): + def __call__( + s, self: HasCacheKey, visitor: _CacheKeyTraversal + ) -> CacheKey: + ... class CacheConst(enum.Enum): @@ -70,7 +90,9 @@ class HasCacheKey: __slots__ = () - _cache_key_traversal = NO_CACHE + _cache_key_traversal: Union[ + _TraverseInternalsType, Literal[CacheConst.NO_CACHE] + ] = NO_CACHE _is_has_cache_key = True @@ -83,7 +105,7 @@ class HasCacheKey: """ - inherit_cache = None + inherit_cache: Optional[bool] = None """Indicate if this :class:`.HasCacheKey` instance should make use of the cache key generation scheme used by its immediate superclass. @@ -106,8 +128,12 @@ class HasCacheKey: __slots__ = () + _generated_cache_key_traversal: Any + @classmethod - def _generate_cache_attrs(cls): + def _generate_cache_attrs( + cls, + ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]: """generate cache key dispatcher for a new class. This sets the _generated_cache_key_traversal attribute once called @@ -121,8 +147,11 @@ class HasCacheKey: _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) if _cache_key_traversal is None: try: - # this would be HasTraverseInternals - _cache_key_traversal = cls._traverse_internals + # check for _traverse_internals, which is part of + # HasTraverseInternals + _cache_key_traversal = cast( + "Type[HasTraverseInternals]", cls + )._traverse_internals except AttributeError: cls._generated_cache_key_traversal = NO_CACHE return NO_CACHE @@ -138,7 +167,9 @@ class HasCacheKey: # more complicated, so for the moment this is a little less # efficient on startup but simpler. return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) else: _cache_key_traversal = cls.__dict__.get( @@ -170,11 +201,15 @@ class HasCacheKey: return NO_CACHE return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) @util.preload_module("sqlalchemy.sql.elements") - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Optional[Tuple[Any, ...]]: """return an optional cache key. The cache key is a tuple which can contain any series of @@ -202,15 +237,15 @@ class HasCacheKey: dispatcher: Union[ Literal[CacheConst.NO_CACHE], - Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"], + _CacheKeyTraversalDispatchType, ] try: dispatcher = cls.__dict__["_generated_cache_key_traversal"] except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. # this block will generate any remaining dispatchers. dispatcher = cls._generate_cache_attrs() @@ -218,7 +253,7 @@ class HasCacheKey: anon_map[NO_CACHE] = True return None - result = (id_, cls) + result: Tuple[Any, ...] = (id_, cls) # inline of _cache_key_traversal_visitor.run_generated_dispatch() @@ -268,7 +303,7 @@ class HasCacheKey: # Columns, this should be long lived. For select() # statements, not so much, but they usually won't have # annotations. - result += self._annotations_cache_key + result += self._annotations_cache_key # type: ignore elif ( meth is InternalTraversal.dp_clauseelement_list or meth is InternalTraversal.dp_clauseelement_tuple @@ -290,7 +325,7 @@ class HasCacheKey: ) return result - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: """return a cache key. The cache key is a tuple which can contain any series of @@ -322,32 +357,40 @@ class HasCacheKey: """ - bindparams = [] + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = self._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) @classmethod - def _generate_cache_key_for_object(cls, obj): - bindparams = [] + def _generate_cache_key_for_object( + cls, obj: HasCacheKey + ) -> Optional[CacheKey]: + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = obj._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) +class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey): + pass + + class MemoizedHasCacheKey(HasCacheKey, HasMemoized): __slots__ = () @HasMemoized.memoized_instancemethod - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: return HasCacheKey._generate_cache_key(self) @@ -362,14 +405,22 @@ class CacheKey(NamedTuple): """ key: Tuple[Any, ...] - bindparams: Sequence[BindParameter] + bindparams: Sequence[BindParameter[Any]] - def __hash__(self): + # can't set __hash__ attribute because it interferes + # with namedtuple + # can't use "if not TYPE_CHECKING" because mypy rejects it + # inside of a NamedTuple + def __hash__(self) -> Optional[int]: # type: ignore """CacheKey itself is not hashable - hash the .key portion""" - return None - def to_offline_string(self, statement_cache, statement, parameters): + def to_offline_string( + self, + statement_cache: _CompiledCacheType, + statement: ClauseElement, + parameters: _CoreSingleExecuteParams, + ) -> str: """Generate an "offline string" form of this :class:`.CacheKey` The "offline string" is basically the string SQL for the @@ -400,21 +451,21 @@ class CacheKey(NamedTuple): return repr((sql_str, param_tuple)) - def __eq__(self, other): - return self.key == other.key + def __eq__(self, other: Any) -> bool: + return bool(self.key == other.key) @classmethod - def _diff_tuples(cls, left, right): + def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str: ck1 = CacheKey(left, []) ck2 = CacheKey(right, []) return ck1._diff(ck2) - def _whats_different(self, other): + def _whats_different(self, other: CacheKey) -> Iterator[str]: k1 = self.key k2 = other.key - stack = [] + stack: List[int] = [] pickup_index = 0 while True: s1, s2 = k1, k2 @@ -440,11 +491,11 @@ class CacheKey(NamedTuple): pickup_index = stack.pop(-1) break - def _diff(self, other): + def _diff(self, other: CacheKey) -> str: return ", ".join(self._whats_different(other)) - def __str__(self): - stack = [self.key] + def __str__(self) -> str: + stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key] output = [] sentinel = object() @@ -473,15 +524,15 @@ class CacheKey(NamedTuple): return "CacheKey(key=%s)" % ("\n".join(output),) - def _generate_param_dict(self): + def _generate_param_dict(self) -> Dict[str, Any]: """used for testing""" - from .compiler import prefix_anon_map - _anon_map = prefix_anon_map() return {b.key % _anon_map: b.effective_value for b in self.bindparams} - def _apply_params_to_element(self, original_cache_key, target_element): + def _apply_params_to_element( + self, original_cache_key: CacheKey, target_element: ClauseElement + ) -> ClauseElement: translate = { k.key: v.value for k, v in zip(original_cache_key.bindparams, self.bindparams) @@ -490,7 +541,7 @@ class CacheKey(NamedTuple): return target_element.params(translate) -class _CacheKeyTraversal(ExtendedInternalTraversal): +class _CacheKeyTraversal(HasTraversalDispatch): # very common elements are inlined into the main _get_cache_key() method # to produce a dramatic savings in Python function call overhead @@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): visit_propagate_attrs = PROPAGATE_ATTRS def visit_with_context_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple((fn.__code__, c_key) for fn, c_key in obj) - def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + def visit_inspectable( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) - def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_string_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple(obj) - def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, obj._gen_cache_key(anon_map, bindparams) @@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): else obj, ) - def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_executable_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_inspectable_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_list( attrname, [inspect(o) for o in obj], parent, anon_map, bindparams ) def visit_clauseelement_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_tuples( attrname, obj, parent, anon_map, bindparams ) def visit_fromclause_ordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_clauseelement_unordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () cache_keys = [ @@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_named_ddl_element( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, obj.name) def visit_prefix_sequence( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_setup_join_tuple( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple( ( target._gen_cache_key(anon_map, bindparams), @@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_table_hint_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + def visit_plain_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) def visit_dialect_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_clauseelement_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_multi_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_fromclause_canonical_column_collection( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # inlining into the internals of ColumnCollection return ( attrname, @@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_unknown_structure( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: anon_map[NO_CACHE] = True return () def visit_dml_ordered_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + def visit_dml_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # in py37 we can assume two dictionaries created in the same # insert ordering will retain that sorting return ( @@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_dml_multi_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # multivalues are simply not cacheable right now anon_map[NO_CACHE] = True return () diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index d616417ab3..834bfb75d4 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -13,6 +13,9 @@ import re import typing from typing import Any from typing import Any as TODO_Any +from typing import Dict +from typing import List +from typing import NoReturn from typing import Optional from typing import Type from typing import TypeVar @@ -42,6 +45,7 @@ if typing.TYPE_CHECKING: from . import selectable from . import traversals from .elements import ClauseElement + from .elements import ColumnClause _SR = TypeVar("_SR", bound=roles.SQLRole) _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) @@ -252,7 +256,7 @@ def expect_col_expression_collection(role, expressions): if isinstance(resolved, str): strname = resolved = expr else: - cols = [] + cols: List[ColumnClause[Any]] = [] visitors.traverse(resolved, {}, {"column": cols.append}) if cols: column = cols[0] @@ -266,7 +270,7 @@ class RoleImpl: def _literal_coercion(self, element, **kw): raise NotImplementedError() - _post_coercion = None + _post_coercion: Any = None _resolve_literal_only = False _skip_clauseelement_for_target_match = False @@ -276,19 +280,24 @@ class RoleImpl: self._use_inspection = issubclass(role_class, roles.UsesInspection) def _implicit_coercions( - self, element, resolved, argname=None, **kw + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, ) -> Any: self._raise_for_expected(element, argname, resolved) def _raise_for_expected( self, - element, - argname=None, - resolved=None, - advice=None, - code=None, - err=None, - ): + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: if resolved is not None and resolved is not element: got = "%r object resolved from %r object" % (resolved, element) else: @@ -324,22 +333,20 @@ class _StringOnly: _resolve_literal_only = True -class _ReturnsStringKey: +class _ReturnsStringKey(RoleImpl): __slots__ = () - def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, str): - return original_element + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if isinstance(element, str): + return element else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, **kw): return element -class _ColumnCoercions: +class _ColumnCoercions(RoleImpl): __slots__ = () def _warn_for_scalar_subquery_coercion(self): @@ -368,8 +375,12 @@ class _ColumnCoercions: def _no_text_coercion( - element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None -): + element: Any, + argname: Optional[str] = None, + exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError, + extra: Optional[str] = None, + err: Optional[Exception] = None, +) -> NoReturn: raise exc_cls( "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " "explicitly declared as text(%(expr)r)" @@ -381,7 +392,7 @@ def _no_text_coercion( ) from err -class _NoTextCoercion: +class _NoTextCoercion(RoleImpl): __slots__ = () def _literal_coercion(self, element, argname=None, **kw): @@ -393,7 +404,7 @@ class _NoTextCoercion: self._raise_for_expected(element, argname) -class _CoerceLiterals: +class _CoerceLiterals(RoleImpl): __slots__ = () _coerce_consts = False _coerce_star = False @@ -440,12 +451,19 @@ class LiteralValueImpl(RoleImpl): return element -class _SelectIsNotFrom: +class _SelectIsNotFrom(RoleImpl): __slots__ = () def _raise_for_expected( - self, element, argname=None, resolved=None, advice=None, **kw - ): + self, + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: if ( not advice and isinstance(element, roles.SelectStatementRole) @@ -460,26 +478,33 @@ class _SelectIsNotFrom: else: code = None - return super(_SelectIsNotFrom, self)._raise_for_expected( + super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, code=code, + err=err, **kw, ) + # never reached + assert False class HasCacheKeyImpl(RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, traversals.HasCacheKey): - return original_element + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, HasCacheKey): + return element else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, **kw): return element @@ -489,12 +514,16 @@ class ExecutableOptionImpl(RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, ExecutableOption): - return original_element + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, ExecutableOption): + return element else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, **kw): return element @@ -560,8 +589,12 @@ class InElementImpl(RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) @@ -573,7 +606,7 @@ class InElementImpl(RoleImpl): self._warn_for_implicit_coercion(resolved) return self._post_coercion(resolved.select(), **kw) else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _warn_for_implicit_coercion(self, elem): util.warn( @@ -586,12 +619,16 @@ class InElementImpl(RoleImpl): if isinstance(element, collections_abc.Iterable) and not isinstance( element, str ): - non_literal_expressions = {} + non_literal_expressions: Dict[ + Optional[operators.ColumnOperators[Any]], + operators.ColumnOperators[Any], + ] = {} element = list(element) for o in element: if not _is_literal(o): if not isinstance(o, operators.ColumnOperators): self._raise_for_expected(element, **kw) + else: non_literal_expressions[o] = o elif o is None: @@ -712,8 +749,12 @@ class GroupByImpl(ByOfImpl, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if isinstance(resolved, roles.StrictFromClauseRole): return elements.ClauseList(*resolved.c) else: @@ -748,12 +789,16 @@ class TruncatedLabelImpl(_StringOnly, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, str): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, str): return resolved else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, argname=None, **kw): """coerce the given value to :class:`._truncated_label`. @@ -794,7 +839,13 @@ class DDLReferredColumnImpl(DDLConstraintColumnImpl): class LimitOffsetImpl(RoleImpl): __slots__ = () - def _implicit_coercions(self, element, resolved, argname=None, **kw): + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved is None: return None else: @@ -814,18 +865,22 @@ class LabeledColumnExprImpl(ExpressionElementImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if isinstance(resolved, roles.ExpressionElementRole): return resolved.label(None) else: new = super(LabeledColumnExprImpl, self)._implicit_coercions( - original_element, resolved, argname=argname, **kw + element, resolved, argname=argname, **kw ) if isinstance(new, roles.ExpressionElementRole): return new.label(None) else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): @@ -899,13 +954,17 @@ class StatementImpl(_CoerceLiterals, RoleImpl): return resolved def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_lambda_element: return resolved else: - return super(StatementImpl, self)._implicit_coercions( - original_element, resolved, argname=argname, **kw + return super()._implicit_coercions( + element, resolved, argname=argname, **kw ) @@ -913,12 +972,16 @@ class SelectStatementImpl(_NoTextCoercion, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_text_clause: return resolved.columns() else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class HasCTEImpl(ReturnsRowsImpl): @@ -938,13 +1001,18 @@ class JoinTargetImpl(RoleImpl): self._raise_for_expected(element, argname) def _implicit_coercions( - self, original_element, resolved, argname=None, legacy=False, **kw - ): - if isinstance(original_element, roles.JoinTargetRole): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + legacy: bool = False, + **kw: Any, + ) -> Any: + if isinstance(element, roles.JoinTargetRole): # note that this codepath no longer occurs as of # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match # were set to False. - return original_element + return element elif legacy and resolved._is_select_statement: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT " @@ -959,7 +1027,7 @@ class JoinTargetImpl(RoleImpl): # in _ORMJoin->Join return resolved else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): @@ -967,13 +1035,13 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): def _implicit_coercions( self, - original_element, - resolved, - argname=None, - explicit_subquery=False, - allow_select=True, - **kw, - ): + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = True, + **kw: Any, + ) -> Any: if resolved._is_select_statement: if explicit_subquery: return resolved.subquery() @@ -989,7 +1057,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): elif resolved._is_text_clause: return resolved else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _post_coercion(self, element, deannotate=False, **kw): if deannotate: @@ -1003,12 +1071,13 @@ class StrictFromClauseImpl(FromClauseImpl): def _implicit_coercions( self, - original_element, - resolved, - argname=None, - allow_select=False, - **kw, - ): + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = False, + **kw: Any, + ) -> Any: if resolved._is_select_statement and allow_select: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT constructs " @@ -1019,7 +1088,7 @@ class StrictFromClauseImpl(FromClauseImpl): ) return resolved._implicit_subquery else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class AnonymizedFromClauseImpl(StrictFromClauseImpl): @@ -1045,8 +1114,12 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) @@ -1056,7 +1129,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl): else: return resolved.select() else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class CompoundElementImpl(_NoTextCoercion, RoleImpl): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 423c3d446e..f28dceefce 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -35,14 +35,19 @@ from time import perf_counter import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict +from typing import FrozenSet +from typing import Iterable from typing import List from typing import Mapping from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type from typing import Union from . import base @@ -54,19 +59,42 @@ from . import operators from . import schema from . import selectable from . import sqltypes +from .base import _from_objects from .base import NO_ARG -from .base import prefix_anon_map from .elements import quoted_name from .schema import Column +from .sqltypes import TupleType from .type_api import TypeEngine +from .visitors import prefix_anon_map from .. import exc from .. import util from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypedDict if typing.TYPE_CHECKING: + from .annotation import _AnnotationDict + from .base import _AmbiguousTableNameMap + from .base import CompileState + from .cache_key import CacheKey + from .elements import BindParameter + from .elements import ColumnClause + from .elements import Label + from .functions import Function + from .selectable import Alias + from .selectable import AliasedReturnsRows + from .selectable import CompoundSelectState from .selectable import CTE from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectState + from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _MutableCoreSingleExecuteParams + from ..engine.interfaces import _SchemaTranslateMapType from ..engine.result import _ProcessorType _FromHintsType = Dict["FromClause", str] @@ -236,7 +264,7 @@ OPERATORS = { operators.nulls_last_op: " NULLS LAST", } -FUNCTIONS = { +FUNCTIONS: Dict[Type[Function], str] = { functions.coalesce: "coalesce", functions.current_date: "CURRENT_DATE", functions.current_time: "CURRENT_TIME", @@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple): name: str """column name, may be labeled""" - objects: List[Any] - """list of objects that should be able to locate this column + objects: Tuple[Any, ...] + """sequence of objects that should be able to locate this column in a RowMapping. This is typically string names and aliases as well as Column objects. @@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple): """ +class _ResultMapAppender(Protocol): + def __call__( + self, + keyname: str, + name: str, + objects: Sequence[Any], + type_: TypeEngine[Any], + ) -> None: + ... + + # integer indexes into ResultColumnsEntry used by cursor.py. # some profiling showed integer access faster than named tuple RM_RENDERED_NAME: Literal[0] = 0 @@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2 RM_TYPE: Literal[3] = 3 +class _BaseCompilerStackEntry(TypedDict): + asfrom_froms: Set[FromClause] + correlate_froms: Set[FromClause] + selectable: ReturnsRows + + +class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): + compile_state: CompileState + need_result_map_for_nested: bool + need_result_map_for_compound: bool + select_0: ReturnsRows + insert_from_select: Select + + class ExpandedState(NamedTuple): statement: str additional_parameters: _CoreSingleExecuteParams @@ -427,21 +480,23 @@ class Compiled: defaults. """ - _cached_metadata = None + _cached_metadata: Optional[CursorResultMetaData] = None _result_columns: Optional[List[ResultColumnsEntry]] = None - schema_translate_map = None + schema_translate_map: Optional[_SchemaTranslateMapType] = None - execution_options = util.EMPTY_DICT + execution_options: _ExecuteOptions = util.EMPTY_DICT """ Execution options propagated from the statement. In some cases, sub-elements of the statement can modify these. """ - _annotations = util.EMPTY_DICT + preparer: IdentifierPreparer + + _annotations: _AnnotationDict = util.EMPTY_DICT - compile_state = None + compile_state: Optional[CompileState] = None """Optional :class:`.CompileState` object that maintains additional state used by the compiler. @@ -457,9 +512,21 @@ class Compiled: """ - cache_key = None + cache_key: Optional[CacheKey] = None + """The :class:`.CacheKey` that was generated ahead of creating this + :class:`.Compiled` object. + + This is used for routines that need access to the original + :class:`.CacheKey` instance generated when the :class:`.Compiled` + instance was first cached, typically in order to reconcile + the original list of :class:`.BindParameter` objects with a + per-statement list that's generated on each call. + + """ _gen_time: float + """Generation time of this :class:`.Compiled`, used for reporting + cache stats.""" def __init__( self, @@ -543,7 +610,11 @@ class Compiled: return self.string or "" - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: """Return the bind params for this compiled object. :param params: a dict of string/object pairs whose values will @@ -646,6 +717,17 @@ class SQLCompiler(Compiled): isplaintext: bool = False + binds: Dict[str, BindParameter[Any]] + """a dictionary of bind parameter keys to BindParameter instances.""" + + bind_names: Dict[BindParameter[Any], str] + """a dictionary of BindParameter instances to "compiled" names + that are actually present in the generated SQL""" + + stack: List[_CompilerStackEntry] + """major statements such as SELECT, INSERT, UPDATE, DELETE are + tracked in this stack using an entry format.""" + result_columns: List[ResultColumnsEntry] """relates label names in the final SQL to a tuple of local column/label name, ColumnElement object (if any) and @@ -709,7 +791,7 @@ class SQLCompiler(Compiled): """ - insert_single_values_expr = None + insert_single_values_expr: Optional[str] = None """When an INSERT is compiled with a single set of parameters inside a VALUES expression, the string is assigned here, where it can be used for insert batching schemes to rewrite the VALUES expression. @@ -718,19 +800,19 @@ class SQLCompiler(Compiled): """ - literal_execute_params = frozenset() + literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as literal values at statement execution time. """ - post_compile_params = frozenset() + post_compile_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as bound parameter placeholders at statement execution time. """ - escaped_bind_names = util.EMPTY_DICT + escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT """Late escaping of bound parameter names that has to be converted to the original name when looking in the parameter dictionary. @@ -744,14 +826,25 @@ class SQLCompiler(Compiled): """if True, and this in insert, use cursor.lastrowid to populate result.inserted_primary_key. """ - _cache_key_bind_match = None + _cache_key_bind_match: Optional[ + Tuple[ + Dict[ + BindParameter[Any], + List[BindParameter[Any]], + ], + Dict[ + str, + BindParameter[Any], + ], + ] + ] = None """a mapping that will relate the BindParameter object we compile to those that are part of the extracted collection of parameters in the cache key, if we were given a cache key. """ - positiontup: Optional[Sequence[str]] = None + positiontup: Optional[List[str]] = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -768,6 +861,19 @@ class SQLCompiler(Compiled): inline: bool = False + ctes: Optional[MutableMapping[CTE, str]] + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + ctes_by_level_name: Dict[Tuple[int, str], CTE] + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] + + ctes_recursive: bool + cte_positional: Dict[CTE, List[str]] + def __init__( self, dialect, @@ -804,10 +910,9 @@ class SQLCompiler(Compiled): self.cache_key = cache_key if cache_key: - self._cache_key_bind_match = ckbm = { - b.key: b for b in cache_key[1] - } - ckbm.update({b: [b] for b in cache_key[1]}) + cksm = {b.key: b for b in cache_key[1]} + ckbm = {b: [b] for b in cache_key[1]} + self._cache_key_bind_match = (ckbm, cksm) # compile INSERT/UPDATE defaults/sequences to expect executemany # style execution, which may mean no pre-execute of defaults, @@ -911,14 +1016,14 @@ class SQLCompiler(Compiled): @property def prefetch(self): - return list(self.insert_prefetch + self.update_prefetch) + return list(self.insert_prefetch) + list(self.update_prefetch) @util.memoized_property def _global_attributes(self): return {} @util.memoized_instancemethod - def _init_cte_state(self) -> None: + def _init_cte_state(self) -> MutableMapping[CTE, str]: """Initialize collections related to CTEs only if a CTE is located, to save on the overhead of these collections otherwise. @@ -926,21 +1031,22 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # To store the query to print - Dict[cte, text_query] - self.ctes: MutableMapping[CTE, str] = util.OrderedDict() + ctes: MutableMapping[CTE, str] = util.OrderedDict() + self.ctes = ctes # Detect same CTE references - Dict[(level, name), cte] # Level is required for supporting nesting - self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {} + self.ctes_by_level_name = {} # To retrieve key/level in ctes_by_level_name - # Dict[cte_reference, (level, cte_name, cte_opts)] - self.level_name_by_cte: Dict[ - CTE, Tuple[int, str, selectable._CTEOpts] - ] = {} + self.level_name_by_cte = {} - self.ctes_recursive: bool = False + self.ctes_recursive = False if self.positional: - self.cte_positional: Dict[CTE, List[str]] = {} + self.cte_positional = {} + + return ctes @contextlib.contextmanager def _nested_result(self): @@ -985,7 +1091,7 @@ class SQLCompiler(Compiled): if not bindparam.type._is_tuple_type else tuple( elem_type._cached_bind_processor(self.dialect) - for elem_type in bindparam.type.types + for elem_type in cast(TupleType, bindparam.type).types ), ) for bindparam in self.bind_names @@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled): def construct_params( self, - params=None, - _group_number=None, - _check=True, - extracted_parameters=None, - ): + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + _group_number: Optional[int] = None, + _check: bool = True, + ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" has_escaped_names = bool(self.escaped_bind_names) @@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled): # way. The parameters present in self.bind_names may be clones of # these original cache key params in the case of DML but the .key # will be guaranteed to match. - try: - orig_extracted = self.cache_key[1] - except TypeError as err: + if self.cache_key is None: raise exc.CompileError( "This compiled object has no original cache key; " "can't pass extracted_parameters to construct_params" - ) from err + ) + else: + orig_extracted = self.cache_key[1] - ckbm = self._cache_key_bind_match + ckbm_tuple = self._cache_key_bind_match + assert ckbm_tuple is not None + ckbm, _ = ckbm_tuple resolved_extracted = { bind: extracted for b, extracted in zip(orig_extracted, extracted_parameters) @@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled): if bindparam.type._is_tuple_type: inputsizes[bindparam] = [ - lookup_type(typ) for typ in bindparam.type.types + lookup_type(typ) + for typ in cast(TupleType, bindparam.type).types ] else: inputsizes[bindparam] = lookup_type(bindparam.type) @@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled): def _process_parameters_for_postcompile( self, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_MutableCoreSingleExecuteParams] = None, _populate_self: bool = False, ) -> ExpandedState: """handle special post compile parameters. @@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled): parameters = self.construct_params() expanded_parameters = {} + positiontup: Optional[List[str]] + if self.positional: positiontup = [] else: positiontup = None processors = self._bind_processors + single_processors = cast("Mapping[str, _ProcessorType]", processors) + tuple_processors = cast( + "Mapping[str, Sequence[_ProcessorType]]", processors + ) - new_processors = {} + new_processors: Dict[str, _ProcessorType] = {} if self.positional and self._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric'. @@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled): "the 'numeric' paramstyle at this time." ) - replacement_expressions = {} - to_update_sets = {} + replacement_expressions: Dict[str, Any] = {} + to_update_sets: Dict[str, Any] = {} # notes: # *unescaped* parameter names in: @@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled): # *escaped* parameter names in: # construct_params(), replacement_expressions - for name in ( - self.positiontup if self.positional else self.bind_names.values() - ): + if self.positional and self.positiontup is not None: + names: Iterable[str] = self.positiontup + else: + names = self.bind_names.values() + + for name in names: escaped_name = ( self.escaped_bind_names.get(name, name) if self.escaped_bind_names @@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled): if parameter in self.post_compile_params: if escaped_name in replacement_expressions: to_update = to_update_sets[escaped_name] + values = None else: # we are removing the parameter from parameters # because it is a list value, which is not expected by @@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled): if not parameter.literal_execute: parameters.update(to_update) if parameter.type._is_tuple_type: + assert values is not None new_processors.update( ( "%s_%s_%s" % (name, i, j), - processors[name][j - 1], + tuple_processors[name][j - 1], ) for i, tuple_element in enumerate(values, 1) - for j, value in enumerate(tuple_element, 1) - if name in processors - and processors[name][j - 1] is not None + for j, _ in enumerate(tuple_element, 1) + if name in tuple_processors + and tuple_processors[name][j - 1] is not None ) else: new_processors.update( - (key, processors[name]) - for key, value in to_update - if name in processors + (key, single_processors[name]) + for key, _ in to_update + if name in single_processors ) - if self.positional: - positiontup.extend(name for name, value in to_update) + if positiontup is not None: + positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ - expand_key for expand_key, value in to_update + expand_key for expand_key, _ in to_update ] - elif self.positional: + elif positiontup is not None: positiontup.append(name) def process_expanding(m): @@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled): # special use cases. self.string = expanded_state.statement self._bind_processors.update(expanded_state.processors) - self.positiontup = expanded_state.positiontup + self.positiontup = list(expanded_state.positiontup or ()) self.post_compile_params = frozenset() for key in expanded_state.parameter_expansion: bind = self.binds.pop(key) @@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled): self._result_columns ) + _key_getters_for_crud_column: Tuple[ + Callable[[Union[str, Column[Any]]], str], + Callable[[Column[Any]], str], + Callable[[Column[Any]], str], + ] + @util.memoized_property def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._key_getters_for_crud_column[2] @@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled): @util.memoized_property @util.preload_module("sqlalchemy.engine.result") def _inserted_primary_key_from_returning_getter(self): - result = util.preloaded.engine_result + if typing.TYPE_CHECKING: + from ..engine import result + else: + result = util.preloaded.engine_result param_key_getter = self._within_exec_param_key_getter table = self.statement.table - ret = {col: idx for idx, col in enumerate(self.returning)} + returning = self.returning + assert returning is not None + ret = {col: idx for idx, col in enumerate(returning)} - getters = [ - (operator.itemgetter(ret[col]), True) - if col in ret - else ( - operator.methodcaller("get", param_key_getter(col), None), - False, - ) - for col in table.primary_key - ] + getters = cast( + "List[Tuple[Callable[[Any], Any], bool]]", + [ + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) + for col in table.primary_key + ], + ) row_fn = result.result_tuple([col.key for col in table.primary_key]) @@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled): self, element, within_columns_clause=False, **kwargs ): if self.stack and self.dialect.supports_simple_order_by_label: - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + raise exc.CompileError( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ) from ke ( with_cols, @@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled): # compiling the element outside of the context of a SELECT return self.process(element._text_clause) - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=ke, + ) + with_cols, only_froms, only_cols = compile_state._label_resolve_dict try: if within_columns_clause: @@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled): def visit_column( self, - column, - add_to_result_map=None, - include_table=True, - result_map_targets=(), - ambiguous_table_name_map=None, - **kwargs, - ): + column: ColumnClause[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + include_table: bool = True, + result_map_targets: Tuple[Any, ...] = (), + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + **kwargs: Any, + ) -> str: name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled): ) else: schema_prefix = "" - tablename = table.name + + tablename = cast("NamedFromClause", table).name if ( not effective_schema @@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - new_entry = { + new_entry: _CompilerStackEntry = { "correlate_froms": set(), "asfrom_froms": set(), "selectable": taf, @@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled): compiled_col = self.visit_column(element, **kw) return "(%s).%s" % (compiled_fn, compiled_col) - def visit_function(self, func, add_to_result_map=None, **kwargs): + def visit_function( + self, + func: Function, + add_to_result_map: Optional[_ResultMapAppender] = None, + **kwargs: Any, + ) -> str: if add_to_result_map is not None: add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + + text: str + if disp: text = disp(func, **kwargs) else: @@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled): if compound_stmt._independent_ctes: self._dispatch_independent_ctes(compound_stmt, kwargs) - keyword = self.compound_keywords.get(cs.keyword) + keyword = self.compound_keywords[cs.keyword] text = (" " + keyword + " ").join( ( @@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled): # a different set of parameter values. here, we accommodate for # parameters that may have been cloned both before and after the cache # key was been generated. - ckbm = self._cache_key_bind_match - if ckbm: + ckbm_tuple = self._cache_key_bind_match + + if ckbm_tuple: + ckbm, cksm = ckbm_tuple for bp in bindparam._cloned_set: - if bp.key in ckbm: - cb = ckbm[bp.key] + if bp.key in cksm: + cb = cksm[bp.key] ckbm[cb].append(bindparam) if bindparam.isoutparam: @@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled): if positional_names is not None: positional_names.append(name) else: - self.positiontup.append(name) + self.positiontup.append(name) # type: ignore[union-attr] elif not escaped_from: if _BIND_TRANSLATE_RE.search(name): @@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled): name = new_name if escaped_from: - if not self.escaped_bind_names: - self.escaped_bind_names = {} - self.escaped_bind_names[escaped_from] = name + self.escaped_bind_names = self.escaped_bind_names.union( + {escaped_from: name} + ) if post_compile: return "__[POSTCOMPILE_%s]" % name @@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled): cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), **kwargs: Any, ) -> Optional[str]: - self._init_cte_state() + self_ctes = self._init_cte_state() + assert self_ctes is self.ctes kwargs["visiting_cte"] = cte @@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled): # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". - del self.ctes[existing_cte] + del self_ctes[existing_cte] existing_cte_reference_cte = existing_cte._get_reference_cte() @@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled): if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) - if not cte_pre_alias_name and cte not in self.ctes: + if not cte_pre_alias_name and cte not in self_ctes: if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) @@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled): cte, cte._suffixes, **kwargs ) - self.ctes[cte] = text + self_ctes[cte] = text if asfrom: if from_linter: from_linter.froms[cte] = cte_name if not is_new_cte and embedded_in_current_named_cte: - return self.preparer.format_alias(cte, cte_name) + return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501 if cte_pre_alias_name: text = self.preparer.format_alias(cte, cte_pre_alias_name) @@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) + return None + def visit_table_valued_alias(self, element, **kw): if element._is_lateral: return self.visit_lateral(element, **kw) @@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled): self, keyname: str, name: str, - objects: List[Any], + objects: Tuple[Any, ...], type_: TypeEngine[Any], ) -> None: if keyname is None or keyname == "*": @@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled): def get_statement_hint_text(self, hint_texts): return " ".join(hint_texts) - _default_stack_entry = util.immutabledict( - [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] - ) + _default_stack_entry: _CompilerStackEntry + + if not typing.TYPE_CHECKING: + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select( self, select_stmt, asfrom, lateral=False, **kw @@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled): ) return froms - translate_select_structure = None + translate_select_structure: Any = None """if not ``None``, should be a callable which accepts ``(select_stmt, **kw)`` and returns a select object. this is used for structural changes mostly to accommodate for LIMIT/OFFSET schemes @@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled): ) self._result_columns = [ - (key, name, tuple(translate.get(o, o) for o in obj), type_) + ResultColumnsEntry( + key, name, tuple(translate.get(o, o) for o in obj), type_ + ) for key, name, obj, type_ in self._result_columns ] @@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled): implicit_correlate_froms=asfrom_froms, ) - new_correlate_froms = set(selectable._from_objects(*froms)) + new_correlate_froms = set(_from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) - new_entry = { + new_entry: _CompilerStackEntry = { "asfrom_froms": new_correlate_froms, "correlate_froms": all_correlate_froms, "selectable": select, @@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled): text += " \nWHERE " + t if warn_linting: + assert from_linter is not None from_linter.warn() if select._group_by_clauses: @@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled): if not self.ctes: return "" + ctes: MutableMapping[CTE, str] + if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): @@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled): ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: + assert self.positiontup is not None self.positiontup = ( - sum([self.cte_positional[cte] for cte in ctes], []) + list( + itertools.chain.from_iterable( + self.cte_positional[cte] for cte in ctes + ) + ) + self.positiontup ) + cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled): if is_multitable: # main table might be a JOIN - main_froms = set(selectable._from_objects(update_stmt.table)) + main_froms = set(_from_objects(update_stmt.table)) render_extra_froms = [ f for f in extra_froms if f not in main_froms ] @@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: return None def visit_ddl(self, ddl, **kwargs): @@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler): return get_col_spec(**kw) +class _SchemaForObjectCallable(Protocol): + def __call__(self, obj: Any) -> str: + ... + + class IdentifierPreparer: """Handle quoting and case-folding of identifiers based on options.""" @@ -5209,7 +5404,13 @@ class IdentifierPreparer: illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = operator.attrgetter("schema") + initial_quote: str + + final_quote: str + + _strings: MutableMapping[str, str] + + schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema") """Return the .schema attribute for an object. For the default IdentifierPreparer, the schema for an object is always @@ -5297,7 +5498,7 @@ class IdentifierPreparer: return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement) - def _escape_identifier(self, value): + def _escape_identifier(self, value: str) -> str: """Escape an identifier. Subclasses should override this to provide database-dependent @@ -5309,7 +5510,7 @@ class IdentifierPreparer: value = value.replace("%", "%%") return value - def _unescape_identifier(self, value): + def _unescape_identifier(self, value: str) -> str: """Canonicalize an escaped identifier. Subclasses should override this to provide database-dependent @@ -5336,7 +5537,7 @@ class IdentifierPreparer: ) return element - def quote_identifier(self, value): + def quote_identifier(self, value: str) -> str: """Quote an identifier. Subclasses should override this to provide database-dependent @@ -5349,7 +5550,7 @@ class IdentifierPreparer: + self.final_quote ) - def _requires_quotes(self, value): + def _requires_quotes(self, value: str) -> bool: """Return True if the given identifier requires quoting.""" lc_value = value.lower() return ( @@ -5364,7 +5565,7 @@ class IdentifierPreparer: not taking case convention into account.""" return not self.legal_characters.match(str(value)) - def quote_schema(self, schema, force=None): + def quote_schema(self, schema: str, force: Any = None) -> str: """Conditionally quote a schema name. @@ -5403,7 +5604,7 @@ class IdentifierPreparer: return self.quote(schema) - def quote(self, ident, force=None): + def quote(self, ident: str, force: Any = None) -> str: """Conditionally quote an identifier. The identifier is quoted if it is a reserved word, contains @@ -5474,11 +5675,19 @@ class IdentifierPreparer: name = self.quote_schema(effective_schema) + "." + name return name - def format_label(self, label, name=None): + def format_label( + self, label: Label[Any], name: Optional[str] = None + ) -> str: return self.quote(name or label.name) - def format_alias(self, alias, name=None): - return self.quote(name or alias.name) + def format_alias( + self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None + ) -> str: + if name is None: + assert alias is not None + return self.quote(alias.name) + else: + return self.quote(name) def format_savepoint(self, savepoint, name=None): # Running the savepoint name through quoting is unnecessary diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 5aded307b6..96e90b0ea1 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -13,6 +13,10 @@ from __future__ import annotations import collections.abc as collections_abc import typing +from typing import Any +from typing import List +from typing import MutableMapping +from typing import Optional from . import coercions from . import roles @@ -40,8 +44,8 @@ from .. import util class DMLState(CompileState): _no_parameters = True - _dict_parameters = None - _multi_parameters = None + _dict_parameters: Optional[MutableMapping[str, Any]] = None + _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None _ordered_values = None _parameter_ordering = None _has_multi_parameters = False diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 168da17ccc..08d632afd9 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -18,7 +18,9 @@ import re import typing from typing import Any from typing import Callable +from typing import Dict from typing import Generic +from typing import List from typing import Optional from typing import overload from typing import Sequence @@ -47,6 +49,7 @@ from .coercions import _document_text_coercion # noqa from .operators import ColumnOperators from .traversals import HasCopyInternals from .visitors import cloned_traverse +from .visitors import ExternallyTraversible from .visitors import InternalTraversal from .visitors import traverse from .visitors import Visitable @@ -68,6 +71,8 @@ if typing.TYPE_CHECKING: from ..engine import Connection from ..engine import Dialect from ..engine import Engine + from ..engine.base import _CompiledCacheType + from ..engine.base import _SchemaTranslateMapType _NUMERIC = Union[complex, "Decimal"] @@ -238,6 +243,7 @@ class ClauseElement( SupportsWrappingAnnotations, MemoizedHasCacheKey, HasCopyInternals, + ExternallyTraversible, CompilerElement, ): """Base class for elements of a programmatically constructed SQL @@ -398,7 +404,9 @@ class ClauseElement( """ return self._replace_params(True, optionaldict, kwargs) - def params(self, *optionaldict, **kwargs): + def params( + self, *optionaldict: Dict[str, Any], **kwargs: Any + ) -> ClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -415,7 +423,12 @@ class ClauseElement( """ return self._replace_params(False, optionaldict, kwargs) - def _replace_params(self, unique, optionaldict, kwargs): + def _replace_params( + self, + unique: bool, + optionaldict: Optional[Dict[str, Any]], + kwargs: Dict[str, Any], + ) -> ClauseElement: if len(optionaldict) == 1: kwargs.update(optionaldict[0]) @@ -487,12 +500,12 @@ class ClauseElement( def _compile_w_cache( self, - dialect, - compiled_cache=None, - column_keys=None, - for_executemany=False, - schema_translate_map=None, - **kw, + dialect: Dialect, + compiled_cache: Optional[_CompiledCacheType] = None, + column_keys: Optional[List[str]] = None, + for_executemany: bool = False, + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + **kw: Any, ): if compiled_cache is not None and dialect._supports_statement_cache: elem_cache_key = self._generate_cache_key() @@ -1383,7 +1396,7 @@ class ColumnElement( """ return Cast(self, type_) - def label(self, name): + def label(self, name: Optional[str]) -> Label[_T]: """Produce a column label, i.e. `` AS ``. This is a shortcut to the :func:`_expression.label` function. @@ -1608,6 +1621,9 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): ("value", InternalTraversal.dp_plain_obj), ] + key: str + type: TypeEngine + _is_crud = False _is_bind_parameter = True _key_is_anon = False diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index eb3d17ee46..6e5eec1271 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -12,6 +12,7 @@ from __future__ import annotations from typing import Any +from typing import Sequence from typing import TypeVar from . import annotation @@ -839,6 +840,8 @@ class Function(FunctionElement): identifier: str + packagenames: Sequence[str] + type: TypeEngine = sqltypes.NULLTYPE """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 64bd4b951b..1a7a5f4d4b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -7,14 +7,22 @@ from __future__ import annotations import typing +from typing import Any +from typing import Iterable +from typing import Mapping +from typing import Optional +from typing import Sequence -from sqlalchemy.util.langhelpers import TypingOnly from .. import util - +from ..util import TypingOnly +from ..util.typing import Literal if typing.TYPE_CHECKING: + from .base import ColumnCollection from .elements import ClauseElement + from .elements import Label from .selectable import FromClause + from .selectable import Subquery class SQLRole: @@ -35,7 +43,7 @@ class SQLRole: class UsesInspection: __slots__ = () - _post_inspect = None + _post_inspect: Literal[None] = None uses_inspection = True @@ -96,7 +104,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): _role_name = "Column expression or FROM clause" @property - def _select_iterable(self): + def _select_iterable(self) -> Sequence[ColumnsClauseRole]: raise NotImplementedError() @@ -150,6 +158,9 @@ class ExpressionElementRole(SQLRole): __slots__ = () _role_name = "SQL expression element" + def label(self, name: Optional[str]) -> Label[Any]: + raise NotImplementedError() + class ConstExprRole(ExpressionElementRole): __slots__ = () @@ -187,7 +198,7 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole): _is_subquery = False @property - def _hide_froms(self): + def _hide_froms(self) -> Iterable[FromClause]: raise NotImplementedError() @@ -195,8 +206,10 @@ class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects + c: ColumnCollection + @property - def description(self): + def description(self) -> str: raise NotImplementedError() @@ -204,7 +217,9 @@ class AnonymizedFromClauseRole(StrictFromClauseRole): __slots__ = () # calls .alias() as a post processor - def _anonymous_fromclause(self, name=None, flat=False): + def _anonymous_fromclause( + self, name: Optional[str] = None, flat: bool = False + ) -> FromClause: raise NotImplementedError() @@ -220,14 +235,14 @@ class StatementRole(SQLRole): __slots__ = () _role_name = "Executable SQL or text() construct" - _propagate_attrs = util.immutabledict() + _propagate_attrs: Mapping[str, Any] = util.immutabledict() class SelectStatementRole(StatementRole, ReturnsRowsRole): __slots__ = () _role_name = "SELECT construct or equivalent text() construct" - def subquery(self): + def subquery(self) -> Subquery: raise NotImplementedError( "All SelectStatementRole objects should implement a " ".subquery() method." diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index c270e15648..33e300bf61 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -51,7 +51,7 @@ from . import visitors from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import Executable -from .base import SchemaEventTarget +from .base import SchemaEventTarget as SchemaEventTarget from .coercions import _document_text_coercion from .elements import ClauseElement from .elements import ColumnClause @@ -2676,6 +2676,10 @@ class DefaultGenerator(Executable, SchemaItem): def __init__(self, for_update=False): self.for_update = for_update + @util.memoized_property + def is_callable(self): + raise NotImplementedError() + def _set_parent(self, column, **kw): self.column = column if self.for_update: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e5c2bef686..09befb0786 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -53,7 +53,6 @@ from .base import Generative from .base import HasCompileState from .base import HasMemoized from .base import Immutable -from .base import prefix_anon_map from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import BindParameter @@ -69,10 +68,10 @@ from .elements import literal_column from .elements import TableValuedColumn from .elements import UnaryExpression from .visitors import InternalTraversal +from .visitors import prefix_anon_map from .. import exc from .. import util - and_ = BooleanClauseList.and_ _T = TypeVar("_T", bound=Any) @@ -855,6 +854,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.alias(name=name) +class NamedFromClause(FromClause): + named_with_column = True + + name: str + + class SelectLabelStyle(Enum): """Label style constants that may be passed to :meth:`_sql.Select.set_label_style`.""" @@ -1317,15 +1322,16 @@ class NoInit: # -> Lateral -> FromClause, but we accept SelectBase # w/ non-deprecated coercion # -> TableSample -> only for FromClause -class AliasedReturnsRows(NoInit, FromClause): +class AliasedReturnsRows(NoInit, NamedFromClause): """Base class of aliases against tables, subqueries, and other selectables.""" _is_from_container = True - named_with_column = True _supports_derived_columns = False + element: ClauseElement + _traverse_internals = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), @@ -1423,6 +1429,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows): inherit_cache = True + element: FromClause + @classmethod def _factory(cls, selectable, name=None, flat=False): return coercions.expect( @@ -1689,6 +1697,8 @@ class CTE( + HasSuffixes._has_suffixes_traverse_internals ) + element: HasCTE + @classmethod def _factory(cls, selectable, name=None, recursive=False): r"""Return a new :class:`_expression.CTE`, @@ -1819,7 +1829,7 @@ class _CTEOpts(NamedTuple): nesting: bool -class HasCTE(roles.HasCTERole): +class HasCTE(roles.HasCTERole, ClauseElement): """Mixin that declares a class to include CTE support. .. versionadded:: 1.1 @@ -2247,6 +2257,8 @@ class Subquery(AliasedReturnsRows): inherit_cache = True + element: Select + @classmethod def _factory(cls, selectable, name=None): """Return a :class:`.Subquery` object.""" @@ -2331,7 +2343,7 @@ class FromGrouping(GroupedElement, FromClause): self.element = state["element"] -class TableClause(roles.DMLTableRole, Immutable, FromClause): +class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """Represents a minimal "table" construct. This is a lightweight table object that has only a name, a @@ -2371,8 +2383,6 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause): ("name", InternalTraversal.dp_string), ] - named_with_column = True - _is_table = True implicit_returning = False @@ -2542,7 +2552,7 @@ class ForUpdateArg(ClauseElement): SelfValues = typing.TypeVar("SelfValues", bound="Values") -class Values(Generative, FromClause): +class Values(Generative, NamedFromClause): """Represent a ``VALUES`` construct that can be used as a FROM element in a statement. @@ -2553,7 +2563,6 @@ class Values(Generative, FromClause): """ - named_with_column = True __visit_name__ = "values" _data = () diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 7d21f12624..b2b1d9bc29 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -35,13 +35,13 @@ from .elements import _NONE_NAME from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa -from .traversals import InternalTraversal from .type_api import Emulated from .type_api import NativeForEmulated # noqa from .type_api import to_instance from .type_api import TypeDecorator from .type_api import TypeEngine from .type_api import Variant # noqa +from .visitors import InternalTraversal from .. import event from .. import exc from .. import inspection diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 4fa23d370c..cf9487f939 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -15,7 +15,10 @@ import operator import typing from typing import Any from typing import Callable +from typing import Deque from typing import Dict +from typing import Set +from typing import Tuple from typing import Type from typing import TypeVar @@ -23,9 +26,9 @@ from . import operators from .cache_key import HasCacheKey from .visitors import _TraverseInternalsType from .visitors import anon_map -from .visitors import ExtendedInternalTraversal +from .visitors import ExternallyTraversible +from .visitors import HasTraversalDispatch from .visitors import HasTraverseInternals -from .visitors import InternalTraversal from .. import util from ..util import langhelpers @@ -35,6 +38,7 @@ COMPARE_SUCCEEDED = True def compare(obj1, obj2, **kw): + strategy: TraversalComparatorStrategy if kw.get("use_proxies", False): strategy = ColIdentityComparatorStrategy() else: @@ -45,16 +49,18 @@ def compare(obj1, obj2, **kw): def _preconfigure_traversals(target_hierarchy): for cls in util.walk_subclasses(target_hierarchy): - if hasattr(cls, "_traverse_internals"): - cls._generate_cache_attrs() + if hasattr(cls, "_generate_cache_attrs") and hasattr( + cls, "_traverse_internals" + ): + cls._generate_cache_attrs() # type: ignore _copy_internals.generate_dispatch( - cls, - cls._traverse_internals, + cls, # type: ignore + cls._traverse_internals, # type: ignore "_generated_copy_internals_traversal", ) _get_children.generate_dispatch( - cls, - cls._traverse_internals, + cls, # type: ignore + cls._traverse_internals, # type: ignore "_generated_get_children_traversal", ) @@ -125,54 +131,58 @@ class HasShallowCopy(HasTraverseInternals): meth_text = f"def {method_name}(self, d):\n{code}\n" return langhelpers._exec_code_in_env(meth_text, {}, method_name) - def _shallow_from_dict(self, d: Dict) -> None: + def _shallow_from_dict(self, d: Dict[str, Any]) -> None: cls = self.__class__ + shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None] try: shallow_from_dict = cls.__dict__[ "_generated_shallow_from_dict_traversal" ] except KeyError: - shallow_from_dict = ( - cls._generated_shallow_from_dict_traversal # type: ignore - ) = self._generate_shallow_from_dict( + shallow_from_dict = self._generate_shallow_from_dict( cls._traverse_internals, "_generated_shallow_from_dict_traversal", ) + cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa E501 + shallow_from_dict(self, d) def _shallow_to_dict(self) -> Dict[str, Any]: cls = self.__class__ + shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]] + try: shallow_to_dict = cls.__dict__[ "_generated_shallow_to_dict_traversal" ] except KeyError: - shallow_to_dict = ( - cls._generated_shallow_to_dict_traversal # type: ignore - ) = self._generate_shallow_to_dict( + shallow_to_dict = self._generate_shallow_to_dict( cls._traverse_internals, "_generated_shallow_to_dict_traversal" ) + cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa E501 return shallow_to_dict(self) - def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy): + def _shallow_copy_to( + self: SelfHasShallowCopy, other: SelfHasShallowCopy + ) -> None: cls = self.__class__ + shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None] try: shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"] except KeyError: - shallow_copy = ( - cls._generated_shallow_copy_traversal # type: ignore - ) = self._generate_shallow_copy( + shallow_copy = self._generate_shallow_copy( cls._traverse_internals, "_generated_shallow_copy_traversal" ) + cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501 shallow_copy(self, other) - def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy: + def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy: """Create a shallow copy""" c = self.__class__.__new__(self.__class__) self._shallow_copy_to(c) @@ -246,7 +256,7 @@ class HasCopyInternals(HasTraverseInternals): setattr(self, attrname, result) -class _CopyInternalsTraversal(InternalTraversal): +class _CopyInternalsTraversal(HasTraversalDispatch): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -381,7 +391,7 @@ def _flatten_clauseelement(element): return element -class _GetChildrenTraversal(InternalTraversal): +class _GetChildrenTraversal(HasTraversalDispatch): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -463,13 +473,13 @@ def _resolve_name_for_compare(element, name, anon_map, **kw): return name -class TraversalComparatorStrategy( - ExtendedInternalTraversal, util.MemoizedSlots -): +class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): __slots__ = "stack", "cache", "anon_map" def __init__(self): - self.stack = deque() + self.stack: Deque[ + Tuple[ExternallyTraversible, ExternallyTraversible] + ] = deque() self.cache = set() def _memoized_attr_anon_map(self): @@ -653,7 +663,7 @@ class TraversalComparatorStrategy( if seq1 is None: return seq2 is None - completed = set() + completed: Set[object] = set() for clause in seq1: for other_clause in set(seq2).difference(completed): if self.compare_inner(clause, other_clause, **kw): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e0248adf0d..5114a2431d 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -21,9 +21,9 @@ from . import coercions from . import operators from . import roles from . import visitors -from .annotation import _deep_annotate # noqa -from .annotation import _deep_deannotate # noqa -from .annotation import _shallow_annotate # noqa +from .annotation import _deep_annotate as _deep_annotate +from .annotation import _deep_deannotate as _deep_deannotate +from .annotation import _shallow_annotate as _shallow_annotate from .base import _expand_cloned from .base import _from_objects from .base import ColumnSet diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 111ecd32ef..0c41e440ef 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -7,43 +7,46 @@ """Visitor/traversal interface and library functions. -SQLAlchemy schema and expression constructs rely on a Python-centric -version of the classic "visitor" pattern as the primary way in which -they apply functionality. The most common use of this pattern -is statement compilation, where individual expression classes match -up to rendering methods that produce a string result. Beyond this, -the visitor system is also used to inspect expressions for various -information and patterns, as well as for the purposes of applying -transformations to expressions. - -Examples of how the visit system is used can be seen in the source code -of for example the ``sqlalchemy.sql.util`` and the ``sqlalchemy.sql.compiler`` -modules. Some background on clause adaption is also at -https://techspot.zzzeek.org/2008/01/23/expression-transformations/ . """ from __future__ import annotations from collections import deque +from enum import Enum import itertools import operator import typing from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import Iterator from typing import List +from typing import Mapping +from typing import Optional from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union from .. import exc from .. import util from ..util import langhelpers -from ..util import symbol from ..util._has_cy import HAS_CYEXTENSION -from ..util.langhelpers import _symbol +from ..util.typing import Protocol +from ..util.typing import Self if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_util import cache_anon_map as anon_map # noqa + from ._py_util import prefix_anon_map as prefix_anon_map + from ._py_util import cache_anon_map as anon_map else: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa + from sqlalchemy.cyextension.util import prefix_anon_map as prefix_anon_map + from sqlalchemy.cyextension.util import cache_anon_map as anon_map + __all__ = [ "iterate", @@ -54,57 +57,23 @@ __all__ = [ "Visitable", "ExternalTraversal", "InternalTraversal", + "anon_map", ] -_TraverseInternalsType = List[Tuple[str, _symbol]] - - -class HasTraverseInternals: - """base for classes that have a "traverse internals" element, - which defines all kinds of ways of traversing the elements of an object. - - """ - - __slots__ = () - - _traverse_internals: _TraverseInternalsType - - @util.preload_module("sqlalchemy.sql.traversals") - def get_children(self, omit_attrs=(), **kw): - r"""Return immediate child :class:`.visitors.Visitable` - elements of this :class:`.visitors.Visitable`. - - This is used for visit traversal. - - \**kw may contain flags that change the collection that is - returned, for example to return a subset of items in order to - cut down on larger traversals, or to return child items from a - different context (such as schema-level collections instead of - clause-level). - - """ - - traversals = util.preloaded.sql_traversals - - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return [] - dispatch = traversals._get_children.run_generated_dispatch - return itertools.chain.from_iterable( - meth(obj, **kw) - for attrname, obj, meth in dispatch( - self, traverse_internals, "_generated_get_children_traversal" - ) - if attrname not in omit_attrs and obj is not None - ) +class _CompilerDispatchType(Protocol): + def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any: + ... class Visitable: """Base class for visitable objects. + :class:`.Visitable` is used to implement the SQL compiler dispatch + functions. Other forms of traversal such as for cache key generation + are implemented separately using the :class:`.HasTraverseInternals` + interface. + .. versionchanged:: 2.0 The :class:`.Visitable` class was named :class:`.Traversible` in the 1.4 series; the name is changed back to :class:`.Visitable` in 2.0 which is what it was prior to 1.4. @@ -117,32 +86,20 @@ class Visitable: __visit_name__: str + _original_compiler_dispatch: _CompilerDispatchType + + if typing.TYPE_CHECKING: + + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str: + ... + def __init_subclass__(cls) -> None: if "__visit_name__" in cls.__dict__: cls._generate_compiler_dispatch() super().__init_subclass__() @classmethod - def _generate_compiler_dispatch(cls): - """Assign dispatch attributes to various kinds of - "visitable" classes. - - Attributes include: - - * The ``_compiler_dispatch`` method, corresponding to - ``__visit_name__``. This is called "external traversal" because the - caller of each visit() method is responsible for sub-traversing the - inner elements of each object. This is appropriate for string - compilers and other traversals that need to call upon the inner - elements in a specific pattern. - - * internal traversal collections ``_children_traversal``, - ``_cache_key_traversal``, ``_copy_internals_traversal``, generated - from an optional ``_traverse_internals`` collection of symbols which - comes from the :class:`.InternalTraversal` list of symbols. This is - called "internal traversal". - - """ + def _generate_compiler_dispatch(cls) -> None: visit_name = cls.__visit_name__ if "_compiler_dispatch" in cls.__dict__: @@ -161,7 +118,9 @@ class Visitable: name = "visit_%s" % visit_name getter = operator.attrgetter(name) - def _compiler_dispatch(self, visitor, **kw): + def _compiler_dispatch( + self: Visitable, visitor: Any, **kw: Any + ) -> str: """Look for an attribute named "visit_" on the visitor, and call it with the same kw params. @@ -169,105 +128,20 @@ class Visitable: try: meth = getter(visitor) except AttributeError as err: - return visitor.visit_unsupported_compilation(self, err, **kw) + return visitor.visit_unsupported_compilation(self, err, **kw) # type: ignore # noqa E501 else: - return meth(self, **kw) + return meth(self, **kw) # type: ignore # noqa E501 - cls._compiler_dispatch = ( + cls._compiler_dispatch = ( # type: ignore cls._original_compiler_dispatch ) = _compiler_dispatch - def __class_getitem__(cls, key): + def __class_getitem__(cls, key: str) -> Any: # allow generic classes in py3.9+ return cls -class _HasTraversalDispatch: - r"""Define infrastructure for the :class:`.InternalTraversal` class. - - .. versionadded:: 2.0 - - """ - - __slots__ = () - - def __init_subclass__(cls) -> None: - cls._generate_traversal_dispatch() - super().__init_subclass__() - - def dispatch(self, visit_symbol): - """Given a method from :class:`._HasTraversalDispatch`, return the - corresponding method on a subclass. - - """ - name = self._dispatch_lookup[visit_symbol] - return getattr(self, name, None) - - def run_generated_dispatch( - self, target, internal_dispatch, generate_dispatcher_name - ): - try: - dispatcher = target.__class__.__dict__[generate_dispatcher_name] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. - dispatcher = self.generate_dispatch( - target.__class__, internal_dispatch, generate_dispatcher_name - ) - return dispatcher(target, self) - - def generate_dispatch( - self, target_cls, internal_dispatch, generate_dispatcher_name - ): - dispatcher = self._generate_dispatcher( - internal_dispatch, generate_dispatcher_name - ) - # assert isinstance(target_cls, type) - setattr(target_cls, generate_dispatcher_name, dispatcher) - return dispatcher - - @classmethod - def _generate_traversal_dispatch(cls): - lookup = {} - clsdict = cls.__dict__ - for key, sym in clsdict.items(): - if key.startswith("dp_"): - visit_key = key.replace("dp_", "visit_") - sym_name = sym.name - assert sym_name not in lookup, sym_name - lookup[sym] = lookup[sym_name] = visit_key - if hasattr(cls, "_dispatch_lookup"): - lookup.update(cls._dispatch_lookup) - cls._dispatch_lookup = lookup - - def _generate_dispatcher(self, internal_dispatch, method_name): - names = [] - for attrname, visit_sym in internal_dispatch: - meth = self.dispatch(visit_sym) - if meth: - visit_name = ExtendedInternalTraversal._dispatch_lookup[ - visit_sym - ] - names.append((attrname, visit_name)) - - code = ( - (" return [\n") - + ( - ", \n".join( - " (%r, self.%s, visitor.%s)" - % (attrname, attrname, visit_name) - for attrname, visit_name in names - ) - ) - + ("\n ]\n") - ) - meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" - return langhelpers._exec_code_in_env(meth_text, {}, method_name) - - -class InternalTraversal(_HasTraversalDispatch): +class InternalTraversal(Enum): r"""Defines visitor symbols used for internal traversal. The :class:`.InternalTraversal` class is used in two ways. One is that @@ -306,18 +180,16 @@ class InternalTraversal(_HasTraversalDispatch): """ - __slots__ = () - - dp_has_cache_key = symbol("HC") + dp_has_cache_key = "HC" """Visit a :class:`.HasCacheKey` object.""" - dp_has_cache_key_list = symbol("HL") + dp_has_cache_key_list = "HL" """Visit a list of :class:`.HasCacheKey` objects.""" - dp_clauseelement = symbol("CE") + dp_clauseelement = "CE" """Visit a :class:`_expression.ClauseElement` object.""" - dp_fromclause_canonical_column_collection = symbol("FC") + dp_fromclause_canonical_column_collection = "FC" """Visit a :class:`_expression.FromClause` object in the context of the ``columns`` attribute. @@ -329,30 +201,30 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_clauseelement_tuples = symbol("CTS") + dp_clauseelement_tuples = "CTS" """Visit a list of tuples which contain :class:`_expression.ClauseElement` objects. """ - dp_clauseelement_list = symbol("CL") + dp_clauseelement_list = "CL" """Visit a list of :class:`_expression.ClauseElement` objects. """ - dp_clauseelement_tuple = symbol("CT") + dp_clauseelement_tuple = "CT" """Visit a tuple of :class:`_expression.ClauseElement` objects. """ - dp_executable_options = symbol("EO") + dp_executable_options = "EO" - dp_with_context_options = symbol("WC") + dp_with_context_options = "WC" - dp_fromclause_ordered_set = symbol("CO") + dp_fromclause_ordered_set = "CO" """Visit an ordered set of :class:`_expression.FromClause` objects. """ - dp_string = symbol("S") + dp_string = "S" """Visit a plain string value. Examples include table and column names, bound parameter keys, special @@ -363,10 +235,10 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_string_list = symbol("SL") + dp_string_list = "SL" """Visit a list of strings.""" - dp_anon_name = symbol("AN") + dp_anon_name = "AN" """Visit a potentially "anonymized" string value. The string value is considered to be significant for cache key @@ -374,7 +246,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_boolean = symbol("B") + dp_boolean = "B" """Visit a boolean value. The boolean value is considered to be significant for cache key @@ -382,7 +254,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_operator = symbol("O") + dp_operator = "O" """Visit an operator. The operator is a function from the :mod:`sqlalchemy.sql.operators` @@ -393,7 +265,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_type = symbol("T") + dp_type = "T" """Visit a :class:`.TypeEngine` object The type object is considered to be significant for cache key @@ -401,7 +273,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_plain_dict = symbol("PD") + dp_plain_dict = "PD" """Visit a dictionary with string keys. The keys of the dictionary should be strings, the values should @@ -410,22 +282,22 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_dialect_options = symbol("DO") + dp_dialect_options = "DO" """Visit a dialect options structure.""" - dp_string_clauseelement_dict = symbol("CD") + dp_string_clauseelement_dict = "CD" """Visit a dictionary of string keys to :class:`_expression.ClauseElement` objects. """ - dp_string_multi_dict = symbol("MD") + dp_string_multi_dict = "MD" """Visit a dictionary of string keys to values which may either be plain immutable/hashable or :class:`.HasCacheKey` objects. """ - dp_annotations_key = symbol("AK") + dp_annotations_key = "AK" """Visit the _annotations_cache_key element. This is a dictionary of additional information about a ClauseElement @@ -436,7 +308,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_plain_obj = symbol("PO") + dp_plain_obj = "PO" """Visit a plain python object. The value should be immutable and hashable, such as an integer. @@ -444,7 +316,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_named_ddl_element = symbol("DD") + dp_named_ddl_element = "DD" """Visit a simple named DDL element. The current object used by this method is the :class:`.Sequence`. @@ -454,57 +326,56 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_prefix_sequence = symbol("PS") + dp_prefix_sequence = "PS" """Visit the sequence represented by :class:`_expression.HasPrefixes` or :class:`_expression.HasSuffixes`. """ - dp_table_hint_list = symbol("TH") + dp_table_hint_list = "TH" """Visit the ``_hints`` collection of a :class:`_expression.Select` object. """ - dp_setup_join_tuple = symbol("SJ") + dp_setup_join_tuple = "SJ" - dp_memoized_select_entities = symbol("ME") + dp_memoized_select_entities = "ME" - dp_statement_hint_list = symbol("SH") + dp_statement_hint_list = "SH" """Visit the ``_statement_hints`` collection of a :class:`_expression.Select` object. """ - dp_unknown_structure = symbol("UK") + dp_unknown_structure = "UK" """Visit an unknown structure. """ - dp_dml_ordered_values = symbol("DML_OV") + dp_dml_ordered_values = "DML_OV" """Visit the values() ordered tuple list of an :class:`_expression.Update` object.""" - dp_dml_values = symbol("DML_V") + dp_dml_values = "DML_V" """Visit the values() dictionary of a :class:`.ValuesBase` (e.g. Insert or Update) object. """ - dp_dml_multi_values = symbol("DML_MV") + dp_dml_multi_values = "DML_MV" """Visit the values() multi-valued list of dictionaries of an :class:`_expression.Insert` object. """ - dp_propagate_attrs = symbol("PA") + dp_propagate_attrs = "PA" """Visit the propagate attrs dict. This hardcodes to the particular elements we care about right now.""" - -class ExtendedInternalTraversal(InternalTraversal): - """Defines additional symbols that are useful in caching applications. + """Symbols that follow are additional symbols that are useful in + caching applications. Traversals for :class:`_expression.ClauseElement` objects only need to use those symbols present in :class:`.InternalTraversal`. However, for @@ -513,9 +384,7 @@ class ExtendedInternalTraversal(InternalTraversal): """ - __slots__ = () - - dp_ignore = symbol("IG") + dp_ignore = "IG" """Specify an object that should be ignored entirely. This currently applies function call argument caching where some @@ -523,29 +392,235 @@ class ExtendedInternalTraversal(InternalTraversal): """ - dp_inspectable = symbol("IS") + dp_inspectable = "IS" """Visit an inspectable object where the return value is a :class:`.HasCacheKey` object.""" - dp_multi = symbol("M") + dp_multi = "M" """Visit an object that may be a :class:`.HasCacheKey` or may be a plain hashable object.""" - dp_multi_list = symbol("MT") + dp_multi_list = "MT" """Visit a tuple containing elements that may be :class:`.HasCacheKey` or may be a plain hashable object.""" - dp_has_cache_key_tuples = symbol("HT") + dp_has_cache_key_tuples = "HT" """Visit a list of tuples which contain :class:`.HasCacheKey` objects. """ - dp_inspectable_list = symbol("IL") + dp_inspectable_list = "IL" """Visit a list of inspectable objects which upon inspection are HasCacheKey objects.""" +_TraverseInternalsType = List[Tuple[str, InternalTraversal]] +"""a structure that defines how a HasTraverseInternals should be +traversed. + +This structure consists of a list of (attributename, internaltraversal) +tuples, where the "attributename" refers to the name of an attribute on an +instance of the HasTraverseInternals object, and "internaltraversal" refers +to an :class:`.InternalTraversal` enumeration symbol defining what kind +of data this attribute stores, which indicates to the traverser how it should +be handled. + +""" + + +class HasTraverseInternals: + """base for classes that have a "traverse internals" element, + which defines all kinds of ways of traversing the elements of an object. + + Compared to :class:`.Visitable`, which relies upon an external visitor to + define how the object is travered (i.e. the :class:`.SQLCompiler`), the + :class:`.HasTraverseInternals` interface allows classes to define their own + traversal, that is, what attributes are accessed and in what order. + + """ + + __slots__ = () + + _traverse_internals: _TraverseInternalsType + + @util.preload_module("sqlalchemy.sql.traversals") + def get_children( + self, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> Iterable[HasTraverseInternals]: + r"""Return immediate child :class:`.visitors.HasTraverseInternals` + elements of this :class:`.visitors.HasTraverseInternals`. + + This is used for visit traversal. + + \**kw may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + + traversals = util.preloaded.sql_traversals + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return [] + + dispatch = traversals._get_children.run_generated_dispatch + return itertools.chain.from_iterable( + meth(obj, **kw) + for attrname, obj, meth in dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ) + if attrname not in omit_attrs and obj is not None + ) + + +class _InternalTraversalDispatchType(Protocol): + def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any: + ... + + +class HasTraversalDispatch: + r"""Define infrastructure for classes that perform internal traversals + + .. versionadded:: 2.0 + + """ + + __slots__ = () + + _dispatch_lookup: ClassVar[Dict[Union[InternalTraversal, str], str]] = {} + + def dispatch(self, visit_symbol: InternalTraversal) -> Callable[..., Any]: + """Given a method from :class:`.HasTraversalDispatch`, return the + corresponding method on a subclass. + + """ + name = _dispatch_lookup[visit_symbol] + return getattr(self, name, None) # type: ignore + + def run_generated_dispatch( + self, + target: object, + internal_dispatch: _TraverseInternalsType, + generate_dispatcher_name: str, + ) -> Any: + dispatcher: _InternalTraversalDispatchType + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. + # this block will generate any remaining dispatchers. + dispatcher = self.generate_dispatch( + target.__class__, internal_dispatch, generate_dispatcher_name + ) + return dispatcher(target, self) + + def generate_dispatch( + self, + target_cls: Type[object], + internal_dispatch: _TraverseInternalsType, + generate_dispatcher_name: str, + ) -> _InternalTraversalDispatchType: + dispatcher = self._generate_dispatcher( + internal_dispatch, generate_dispatcher_name + ) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) + return dispatcher + + def _generate_dispatcher( + self, internal_dispatch: _TraverseInternalsType, method_name: str + ) -> _InternalTraversalDispatchType: + names = [] + for attrname, visit_sym in internal_dispatch: + meth = self.dispatch(visit_sym) + if meth: + visit_name = _dispatch_lookup[visit_sym] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) + ) + + ("\n ]\n") + ) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + return cast( + _InternalTraversalDispatchType, + langhelpers._exec_code_in_env(meth_text, {}, method_name), + ) + + +ExtendedInternalTraversal = InternalTraversal + + +def _generate_traversal_dispatch() -> None: + lookup = _dispatch_lookup + + for sym in InternalTraversal: + key = sym.name + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.value + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + + +_dispatch_lookup = HasTraversalDispatch._dispatch_lookup +_generate_traversal_dispatch() + + +class ExternallyTraversible(HasTraverseInternals, Visitable): + __slots__ = () + + _annotations: Collection[Any] = () + + if typing.TYPE_CHECKING: + + def get_children( + self, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> Iterable[ExternallyTraversible]: + ... + + def _clone(self: Self, **kw: Any) -> Self: + """clone this element""" + raise NotImplementedError() + + def _copy_internals( + self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> Self: + """Reassign internal elements to be clones of themselves. + + Called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy. + + The given clone function should be used, which may be applying + additional transformations to the element (i.e. replacement + traversal, cloned traversal, annotations). + + """ + raise NotImplementedError() + + +_ET = TypeVar("_ET", bound=ExternallyTraversible) +_TraverseCallableType = Callable[[_ET], None] +_TraverseTransformCallableType = Callable[ + [ExternallyTraversible], Optional[ExternallyTraversible] +] + + class ExternalTraversal: """Base class for visitor objects which can traverse externally using the :func:`.visitors.traverse` function. @@ -555,7 +630,8 @@ class ExternalTraversal: """ - __traverse_options__ = {} + __traverse_options__: Dict[str, Any] = {} + _next: Optional[ExternalTraversal] def traverse_single(self, obj: Visitable, **kw: Any) -> Any: for v in self.visitor_iterator: @@ -563,20 +639,22 @@ class ExternalTraversal: if meth: return meth(obj, **kw) - def iterate(self, obj): + def iterate( + self, obj: ExternallyTraversible + ) -> Iterator[ExternallyTraversible]: """Traverse the given expression structure, returning an iterator of all elements. """ return iterate(obj, self.__traverse_options__) - def traverse(self, obj): + def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: """Traverse and visit the given expression structure.""" return traverse(obj, self.__traverse_options__, self._visitor_dict) @util.memoized_property - def _visitor_dict(self): + def _visitor_dict(self) -> Dict[str, _TraverseCallableType[Any]]: visitors = {} for name in dir(self): @@ -585,16 +663,16 @@ class ExternalTraversal: return visitors @property - def visitor_iterator(self): + def visitor_iterator(self) -> Iterator[ExternalTraversal]: """Iterate through this visitor and each 'chained' visitor.""" - v = self + v: Optional[ExternalTraversal] = self while v: yield v v = getattr(v, "_next", None) - def chain(self, visitor): - """'Chain' an additional ClauseVisitor onto this ClauseVisitor. + def chain(self, visitor: ExternalTraversal) -> ExternalTraversal: + """'Chain' an additional ExternalTraversal onto this ExternalTraversal The chained visitor will receive all visit events after this one. @@ -614,14 +692,16 @@ class CloningExternalTraversal(ExternalTraversal): """ - def copy_and_process(self, list_): + def copy_and_process( + self, list_: List[ExternallyTraversible] + ) -> List[ExternallyTraversible]: """Apply cloned traversal to the given list of elements, and return the new list. """ return [self.traverse(x) for x in list_] - def traverse(self, obj): + def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: """Traverse and visit the given expression structure.""" return cloned_traverse( @@ -638,7 +718,9 @@ class ReplacingExternalTraversal(CloningExternalTraversal): """ - def replace(self, elem): + def replace( + self, elem: ExternallyTraversible + ) -> Optional[ExternallyTraversible]: """Receive pre-copied elements during a cloning traversal. If the method returns a new element, the element is used @@ -647,15 +729,19 @@ class ReplacingExternalTraversal(CloningExternalTraversal): """ return None - def traverse(self, obj): + def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: """Traverse and visit the given expression structure.""" - def replace(elem): + def replace( + elem: ExternallyTraversible, + ) -> Optional[ExternallyTraversible]: for v in self.visitor_iterator: - e = v.replace(elem) + e = cast(ReplacingExternalTraversal, v).replace(elem) if e is not None: return e + return None + return replacement_traverse(obj, self.__traverse_options__, replace) @@ -667,7 +753,9 @@ CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal -def iterate(obj, opts=util.immutabledict()): +def iterate( + obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT +) -> Iterator[ExternallyTraversible]: r"""Traverse the given expression structure, returning an iterator. Traversal is configured to be breadth-first. @@ -702,7 +790,11 @@ def iterate(obj, opts=util.immutabledict()): stack.append(t.get_children(**opts)) -def traverse_using(iterator, obj, visitors): +def traverse_using( + iterator: Iterable[ExternallyTraversible], + obj: ExternallyTraversible, + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> ExternallyTraversible: """Visit the given expression structure using the given iterator of objects. @@ -734,7 +826,11 @@ def traverse_using(iterator, obj, visitors): return obj -def traverse(obj, opts, visitors): +def traverse( + obj: ExternallyTraversible, + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> ExternallyTraversible: """Traverse and visit the given expression structure using the default iterator. @@ -767,7 +863,11 @@ def traverse(obj, opts, visitors): return traverse_using(iterate(obj, opts), obj, visitors) -def cloned_traverse(obj, opts, visitors): +def cloned_traverse( + obj: ExternallyTraversible, + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseTransformCallableType], +) -> ExternallyTraversible: """Clone the given expression structure, allowing modifications by visitors. @@ -794,20 +894,24 @@ def cloned_traverse(obj, opts, visitors): """ - cloned = {} + cloned: Dict[int, ExternallyTraversible] = {} stop_on = set(opts.get("stop_on", [])) - def deferred_copy_internals(obj): + def deferred_copy_internals( + obj: ExternallyTraversible, + ) -> ExternallyTraversible: return cloned_traverse(obj, opts, visitors) - def clone(elem, **kw): + def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible: if elem in stop_on: return elem else: if id(elem) not in cloned: if "replace" in kw: - newelem = kw["replace"](elem) + newelem = cast( + Optional[ExternallyTraversible], kw["replace"](elem) + ) if newelem is not None: cloned[id(elem)] = newelem return newelem @@ -823,11 +927,15 @@ def cloned_traverse(obj, opts, visitors): obj = clone( obj, deferred_copy_internals=deferred_copy_internals, **opts ) - clone = None # remove gc cycles + clone = None # type: ignore[assignment] # remove gc cycles return obj -def replacement_traverse(obj, opts, replace): +def replacement_traverse( + obj: ExternallyTraversible, + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType, +) -> ExternallyTraversible: """Clone the given expression structure, allowing element replacement by a given replacement function. @@ -854,10 +962,12 @@ def replacement_traverse(obj, opts, replace): cloned = {} stop_on = {id(x) for x in opts.get("stop_on", [])} - def deferred_copy_internals(obj): + def deferred_copy_internals( + obj: ExternallyTraversible, + ) -> ExternallyTraversible: return replacement_traverse(obj, opts, replace) - def clone(elem, **kw): + def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible: if ( id(elem) in stop_on or "no_replacement_traverse" in elem._annotations @@ -888,5 +998,5 @@ def replacement_traverse(obj, opts, replace): obj = clone( obj, deferred_copy_internals=deferred_copy_internals, **opts ) - clone = None # remove gc cycles + clone = None # type: ignore[assignment] # remove gc cycles return obj diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 8cb84f73f5..ae61155ffa 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -305,9 +305,11 @@ def _update_argspec_defaults_into_env(spec, env): return spec -def _exec_code_in_env(code, env, fn_name): +def _exec_code_in_env( + code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str +) -> Callable[..., Any]: exec(code, env) - return env[fn_name] + return env[fn_name] # type: ignore[no-any-return] _PF = TypeVar("_PF") @@ -1181,7 +1183,7 @@ class memoized_property(Generic[_T]): obj.__dict__.pop(name, None) -def memoized_instancemethod(fn): +def memoized_instancemethod(fn: _F) -> _F: """Decorate a method memoize its return value. Best applied to no-arg methods: memoization is not sensitive to @@ -1201,7 +1203,7 @@ def memoized_instancemethod(fn): self.__dict__[fn.__name__] = memo return result - return update_wrapper(oneshot, fn) + return update_wrapper(oneshot, fn) # type: ignore class HasMemoized: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 160eabd85f..c089616e4e 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -3,10 +3,11 @@ from __future__ import annotations import sys import typing from typing import Any -from typing import Callable # noqa from typing import cast from typing import Dict from typing import ForwardRef +from typing import Iterable +from typing import Tuple from typing import Type from typing import TypeVar from typing import Union @@ -16,6 +17,11 @@ from typing_extensions import NotRequired as NotRequired # noqa from . import compat _T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT") +_KT_co = TypeVar("_KT_co", covariant=True) +_KT_contra = TypeVar("_KT_contra", contravariant=True) +_VT = TypeVar("_VT") +_VT_co = TypeVar("_VT_co", covariant=True) Self = TypeVar("Self", bound=Any) @@ -45,6 +51,18 @@ else: from typing_extensions import Protocol as Protocol # noqa F401 from typing_extensions import TypedDict as TypedDict # noqa F401 +# copied from TypeShed, required in order to implement +# MutableMapping.update() + + +class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): + def keys(self) -> Iterable[_KT]: + ... + + def __getitem__(self, __k: _KT) -> _VT_co: + ... + + # work around https://github.com/microsoft/pyright/issues/3025 _LiteralStar = Literal["*"] @@ -120,7 +138,9 @@ def make_union_type(*types): return cast(Any, Union).__getitem__(types) -def expand_unions(type_, include_union=False, discard_none=False): +def expand_unions( + type_: Type[Any], include_union: bool = False, discard_none: bool = False +) -> Tuple[Type[Any], ...]: """Return a type as as a tuple of individual types, expanding for ``Union`` types.""" diff --git a/pyproject.toml b/pyproject.toml index b90feae498..407af71c3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,30 +40,12 @@ markers = [ ] [tool.pyright] -include = [ - "lib/sqlalchemy/engine/base.py", - "lib/sqlalchemy/engine/events.py", - "lib/sqlalchemy/engine/interfaces.py", - "lib/sqlalchemy/engine/_py_row.py", - "lib/sqlalchemy/engine/result.py", - "lib/sqlalchemy/engine/row.py", - "lib/sqlalchemy/engine/util.py", - "lib/sqlalchemy/engine/url.py", - "lib/sqlalchemy/pool/", - "lib/sqlalchemy/event/", - "lib/sqlalchemy/events.py", - "lib/sqlalchemy/exc.py", - "lib/sqlalchemy/log.py", - "lib/sqlalchemy/inspection.py", - "lib/sqlalchemy/schema.py", - "lib/sqlalchemy/types.py", - "lib/sqlalchemy/util/", -] + reportPrivateUsage = "none" reportUnusedClass = "none" reportUnusedFunction = "none" - +reportTypedDictNotRequiredAccess = "warning" [tool.mypy] mypy_path = "./lib/" @@ -99,6 +81,11 @@ ignore_errors = true # strict checking [[tool.mypy.overrides]] module = [ + "sqlalchemy.sql.annotation", + "sqlalchemy.sql.cache_key", + "sqlalchemy.sql.roles", + "sqlalchemy.sql.visitors", + "sqlalchemy.sql._py_util", "sqlalchemy.connectors.*", "sqlalchemy.engine.*", "sqlalchemy.ext.associationproxy", @@ -117,6 +104,12 @@ strict = true [[tool.mypy.overrides]] module = [ + "sqlalchemy.sql.coercions", + "sqlalchemy.sql.compiler", + #"sqlalchemy.sql.crud", + #"sqlalchemy.sql.default_comparator", + "sqlalchemy.sql.naming", + "sqlalchemy.sql.traversals", "sqlalchemy.util.*", "sqlalchemy.engine.cursor", "sqlalchemy.engine.default",