cls,
result_columns: List[ResultColumnsEntry],
loose_column_name_matching: bool = False,
- ) -> Dict[Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]]:
+ ) -> Dict[
+ Union[str, object], Tuple[str, Tuple[Any, ...], TypeEngine[Any], int]
+ ]:
"""when matching cursor.description to a set of names that are present
in a Compiled object, as is the case with TextualSelect, get all the
names we expect might match those in cursor.description.
"""
d: Dict[
- Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]
+ Union[str, object],
+ Tuple[str, Tuple[Any, ...], TypeEngine[Any], int],
] = {}
for ridx, elem in enumerate(result_columns):
key = elem[RM_RENDERED_NAME]
-
if key in d:
# conflicting keyname - just add the column-linked objects
# to the existing record. if there is a duplicate column
from .. import event
from .. import exc
from .. import pool
+from .. import TupleType
from .. import types as sqltypes
from .. import util
from ..sql import compiler
from ..sql.compiler import Compiled
from ..sql.compiler import ResultColumnsEntry
from ..sql.compiler import TypeCompiler
+ from ..sql.dml import DMLState
+ from ..sql.elements import BindParameter
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
cursor: DBAPICursor
compiled_parameters: List[_MutableCoreSingleExecuteParams]
parameters: _DBAPIMultiExecuteParams
- extracted_parameters: _CoreSingleExecuteParams
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]]
_empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT)
compiled: SQLCompiler,
parameters: _CoreMultiExecuteParams,
invoked_statement: Executable,
- extracted_parameters: _CoreSingleExecuteParams,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]],
cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
) -> ExecutionContext:
"""Initialize execution context for a Compiled construct."""
inputsizes, self.cursor, self.statement, self.parameters, self
)
- has_escaped_names = bool(compiled.escaped_bind_names)
- if has_escaped_names:
+ if compiled.escaped_bind_names:
escaped_bind_names = compiled.escaped_bind_names
+ else:
+ escaped_bind_names = None
if dialect.positional:
items = [
if key in self._expanded_parameters:
if bindparam.type._is_tuple_type:
- num = len(bindparam.type.types)
+ tup_type = cast(TupleType, bindparam.type)
+ num = len(tup_type.types)
dbtypes = inputsizes[bindparam]
generic_inputsizes.extend(
(
(
escaped_bind_names.get(paramname, paramname)
- if has_escaped_names
+ if escaped_bind_names is not None
else paramname
),
dbtypes[idx % num],
- bindparam.type.types[idx % num],
+ tup_type.types[idx % num],
)
for idx, paramname in enumerate(
self._expanded_parameters[key]
(
(
escaped_bind_names.get(paramname, paramname)
- if has_escaped_names
+ if escaped_bind_names is not None
else paramname
),
dbtype,
escaped_name = (
escaped_bind_names.get(key, key)
- if has_escaped_names
+ if escaped_bind_names is not None
else key
)
else:
assert column is not None
assert parameters is not None
- compile_state = cast(SQLCompiler, self.compiled).compile_state
+ compile_state = cast(
+ "DMLState", cast(SQLCompiler, self.compiled).compile_state
+ )
assert compile_state is not None
if (
isolate_multiinsert_groups
else:
d = {column.key: parameters[column.key]}
index = 0
+ assert compile_state._dict_parameters is not None
keys = compile_state._dict_parameters.keys()
d.update(
(key, parameters["%s_m%d" % (key, index)]) for key in keys
from ..sql.compiler import IdentifierPreparer
from ..sql.compiler import Linting
from ..sql.compiler import SQLCompiler
+ from ..sql.elements import BindParameter
from ..sql.elements import ClauseElement
from ..sql.schema import Column
- from ..sql.schema import ColumnDefault
+ from ..sql.schema import DefaultGenerator
from ..sql.schema import Sequence as Sequence_SchemaItem
from ..sql.sqltypes import Integer
from ..sql.type_api import TypeEngine
"""
+ _supports_statement_cache: bool
+ """internal evaluation for supports_statement_cache"""
+
bind_typing = BindTyping.NONE
"""define a means of passing typing information to the database and/or
driver for bound parameters.
compiled: SQLCompiler,
parameters: _CoreMultiExecuteParams,
invoked_statement: Executable,
- extracted_parameters: _CoreSingleExecuteParams,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]],
cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
) -> ExecutionContext:
raise NotImplementedError()
def _exec_default(
self,
column: Optional[Column[Any]],
- default: ColumnDefault,
+ default: DefaultGenerator,
type_: Optional[TypeEngine[Any]],
) -> Any:
raise NotImplementedError()
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import NoReturn
def result_tuple(
fields: Sequence[str], extra: Optional[Any] = None
-) -> Callable[[_RawRowType], Row]:
+) -> Callable[[Iterable[Any]], Row]:
parent = SimpleResultMetaData(fields, extra)
return functools.partial(
Row, parent, parent._processors, parent._keymap, Row._default_key_style
from .engine.interfaces import _DBAPIAnyExecuteParams
from .engine.interfaces import Dialect
from .sql.compiler import Compiled
+ from .sql.compiler import TypeCompiler
from .sql.elements import ClauseElement
if typing.TYPE_CHECKING:
def __init__(
self,
- compiler: "Compiled",
- element_type: Type["ClauseElement"],
+ compiler: Union[Compiled, TypeCompiler],
+ element_type: Type[ClauseElement],
message: Optional[str] = None,
):
super(UnsupportedCompilationError, self).__init__(
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import SupportsIndex
+from ..util.typing import SupportsKeysAndGetItem
if typing.TYPE_CHECKING:
from ..orm.attributes import InstrumentedAttribute
return (item[0], self._get(item[1]))
@overload
- def update(self, __m: Mapping[_KT, _VT], **kwargs: _VT) -> None:
+ def update(
+ self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT
+ ) -> None:
...
@overload
from __future__ import annotations
+import typing
+from typing import Any
from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from .cache_key import CacheConst
class prefix_anon_map(Dict[str, str]):
"""
- def __missing__(self, key):
+ def __missing__(self, key: str) -> str:
(ident, derived) = key.split(" ", 1)
anonymous_counter = self.get(derived, 1)
- self[derived] = anonymous_counter + 1
+ self[derived] = anonymous_counter + 1 # type: ignore
value = f"{derived}_{anonymous_counter}"
self[key] = value
return value
-class cache_anon_map(Dict[int, str]):
+class cache_anon_map(
+ Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]]
+):
"""A map that creates new keys for missing key access.
Produces an incrementing sequence given a series of unique keys.
_index = 0
- def get_anon(self, object_):
+ def get_anon(self, object_: Any) -> Tuple[str, bool]:
idself = id(object_)
if idself in self:
- return self[idself], True
+ s_val = self[idself]
+ assert s_val is not True
+ return s_val, True
else:
# inline of __missing__
self[idself] = id_ = str(self._index)
return id_, False
- def __missing__(self, key):
+ def __missing__(self, key: int) -> str:
self[key] = val = str(self._index)
self._index += 1
return val
from __future__ import annotations
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Mapping
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+
from . import operators
-from .base import HasCacheKey
-from .traversals import anon_map
+from .cache_key import HasCacheKey
+from .visitors import anon_map
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .. import util
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from .visitors import _TraverseInternalsType
+ from ..util.typing import Self
+
+_AnnotationDict = Mapping[str, Any]
+
+EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT
+
-EMPTY_ANNOTATIONS = util.immutabledict()
+SelfSupportsAnnotations = TypeVar(
+ "SelfSupportsAnnotations", bound="SupportsAnnotations"
+)
-class SupportsAnnotations:
+class SupportsAnnotations(ExternallyTraversible):
__slots__ = ()
- _annotations = EMPTY_ANNOTATIONS
+ _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS
+ proxy_set: Set[SupportsAnnotations]
+ _is_immutable: bool
+
+ def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations:
+ raise NotImplementedError()
+
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
+ raise NotImplementedError()
@util.memoized_property
- def _annotations_cache_key(self):
+ def _annotations_cache_key(self) -> Tuple[Any, ...]:
anon_map_ = anon_map()
return (
"_annotations",
)
+SelfSupportsCloneAnnotations = TypeVar(
+ "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations"
+)
+
+
class SupportsCloneAnnotations(SupportsAnnotations):
- __slots__ = ()
+ if not typing.TYPE_CHECKING:
+ __slots__ = ()
- _clone_annotations_traverse_internals = [
+ _clone_annotations_traverse_internals: _TraverseInternalsType = [
("_annotations", InternalTraversal.dp_annotations_key)
]
- def _annotate(self, values):
+ def _annotate(
+ self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsCloneAnnotations:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
new.__dict__.pop("_generate_cache_key", None)
return new
- def _with_annotations(self, values):
+ def _with_annotations(
+ self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsCloneAnnotations:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
new.__dict__.pop("_generate_cache_key", None)
return new
- def _deannotate(self, values=None, clone=False):
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
"""return a copy of this :class:`_expression.ClauseElement`
with annotations
removed.
return self
+SelfSupportsWrappingAnnotations = TypeVar(
+ "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations"
+)
+
+
class SupportsWrappingAnnotations(SupportsAnnotations):
__slots__ = ()
- def _annotate(self, values):
+ _constructor: Callable[..., SupportsWrappingAnnotations]
+ entity_namespace: Mapping[str, Any]
+
+ def _annotate(self, values: _AnnotationDict) -> Annotated:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
"""
- return Annotated(self, values)
+ return Annotated._as_annotated_instance(self, values)
- def _with_annotations(self, values):
+ def _with_annotations(self, values: _AnnotationDict) -> Annotated:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
"""
- return Annotated(self, values)
-
- def _deannotate(self, values=None, clone=False):
+ return Annotated._as_annotated_instance(self, values)
+
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
"""return a copy of this :class:`_expression.ClauseElement`
with annotations
removed.
return self
-class Annotated:
- """clones a SupportsAnnotated and applies an 'annotations' dictionary.
+SelfAnnotated = TypeVar("SelfAnnotated", bound="Annotated")
+
+
+class Annotated(SupportsAnnotations):
+ """clones a SupportsAnnotations and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
_is_column_operators = False
- def __new__(cls, *args):
- if not args:
- # clone constructor
- return object.__new__(cls)
- else:
- element, values = args
- # pull appropriate subclass from registry of annotated
- # classes
- try:
- cls = annotated_classes[element.__class__]
- except KeyError:
- cls = _new_annotation_type(element.__class__, cls)
- return object.__new__(cls)
-
- def __init__(self, element, values):
+ @classmethod
+ def _as_annotated_instance(
+ cls, element: SupportsWrappingAnnotations, values: _AnnotationDict
+ ) -> Annotated:
+ try:
+ cls = annotated_classes[element.__class__]
+ except KeyError:
+ cls = _new_annotation_type(element.__class__, cls)
+ return cls(element, values)
+
+ _annotations: util.immutabledict[str, Any]
+ __element: SupportsWrappingAnnotations
+ _hash: int
+
+ def __new__(cls: Type[SelfAnnotated], *args: Any) -> SelfAnnotated:
+ return object.__new__(cls)
+
+ def __init__(
+ self, element: SupportsWrappingAnnotations, values: _AnnotationDict
+ ):
self.__dict__ = element.__dict__.copy()
self.__dict__.pop("_annotations_cache_key", None)
self.__dict__.pop("_generate_cache_key", None)
self._annotations = util.immutabledict(values)
self._hash = hash(element)
- def _annotate(self, values):
+ def _annotate(
+ self: SelfAnnotated, values: _AnnotationDict
+ ) -> SelfAnnotated:
_values = self._annotations.union(values)
return self._with_annotations(_values)
- def _with_annotations(self, values):
+ def _with_annotations(
+ self: SelfAnnotated, values: util.immutabledict[str, Any]
+ ) -> SelfAnnotated:
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotations_cache_key", None)
clone._annotations = values
return clone
- def _deannotate(self, values=None, clone=True):
+ @overload
+ def _deannotate(
+ self: SelfAnnotated,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfAnnotated:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> Annotated:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = True,
+ ) -> SupportsAnnotations:
if values is None:
return self.__element
else:
)
)
- def _compiler_dispatch(self, visitor, **kw):
- return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
+ if not typing.TYPE_CHECKING:
+ # manually proxy some methods that need extra attention
+ def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any:
+ return self.__element.__class__._compiler_dispatch(
+ self, visitor, **kw
+ )
- @property
- def _constructor(self):
- return self.__element._constructor
+ @property
+ def _constructor(self):
+ return self.__element._constructor
- def _clone(self, **kw):
+ def _clone(self: SelfAnnotated, **kw: Any) -> SelfAnnotated:
clone = self.__element._clone(**kw)
if clone is self.__element:
# detect immutable, don't change anything
clone.__dict__.update(self.__dict__)
return self.__class__(clone, self._annotations)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]:
return self.__class__, (self.__element, self._annotations)
- def __hash__(self):
+ def __hash__(self) -> int:
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if self._is_column_operators:
return self.__element.__class__.__eq__(self, other)
else:
return hash(other) == hash(self)
@property
- def entity_namespace(self):
+ def entity_namespace(self) -> Mapping[str, Any]:
if "entity_namespace" in self._annotations:
- return self._annotations["entity_namespace"].entity_namespace
+ return cast(
+ SupportsWrappingAnnotations,
+ self._annotations["entity_namespace"],
+ ).entity_namespace
else:
return self.__element.entity_namespace
# so that the resulting objects are pickleable; additionally, other
# decisions can be made up front about the type of object being annotated
# just once per class rather than per-instance.
-annotated_classes = {}
+annotated_classes: Dict[
+ Type[SupportsWrappingAnnotations], Type[Annotated]
+] = {}
+
+_SA = TypeVar("_SA", bound="SupportsAnnotations")
def _deep_annotate(
- element, annotations, exclude=None, detect_subquery_cols=False
-):
+ element: _SA,
+ annotations: _AnnotationDict,
+ exclude: Optional[Sequence[SupportsAnnotations]] = None,
+ detect_subquery_cols: bool = False,
+) -> _SA:
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
# annotated objects hack the __hash__() method so if we want to
# uniquely process them we have to use id()
- cloned_ids = {}
+ cloned_ids: Dict[int, SupportsAnnotations] = {}
- def clone(elem, **kw):
+ def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
kw["detect_subquery_cols"] = detect_subquery_cols
id_ = id(elem)
return newelem
if element is not None:
- element = clone(element)
- clone = None # remove gc cycles
+ element = cast(_SA, clone(element))
+ clone = None # type: ignore # remove gc cycles
return element
-def _deep_deannotate(element, values=None):
+def _deep_deannotate(
+ element: _SA, values: Optional[Sequence[str]] = None
+) -> _SA:
"""Deep copy the given element, removing annotations."""
- cloned = {}
+ cloned: Dict[Any, SupportsAnnotations] = {}
- def clone(elem, **kw):
+ def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
+ key: Any
if values:
key = id(elem)
else:
return cloned[key]
if element is not None:
- element = clone(element)
- clone = None # remove gc cycles
+ element = cast(_SA, clone(element))
+ clone = None # type: ignore # remove gc cycles
return element
-def _shallow_annotate(element, annotations):
+def _shallow_annotate(
+ element: SupportsAnnotations, annotations: _AnnotationDict
+) -> SupportsAnnotations:
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
return element
-def _new_annotation_type(cls, base_cls):
+def _new_annotation_type(
+ cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated]
+) -> Type[Annotated]:
+ """Generates a new class that subclasses Annotated and proxies a given
+ element type.
+
+ """
if issubclass(cls, Annotated):
return cls
elif cls in annotated_classes:
base_cls = annotated_classes[super_]
break
- annotated_classes[cls] = anno_cls = type(
- "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ annotated_classes[cls] = anno_cls = cast(
+ Type[Annotated],
+ type("Annotated%s" % cls.__name__, (base_cls, cls), {}),
)
globals()["Annotated%s" % cls.__name__] = anno_cls
# some classes include this even if they have traverse_internals
# e.g. BindParameter, add it if present.
if cls.__dict__.get("inherit_cache", False):
- anno_cls.inherit_cache = True
+ anno_cls.inherit_cache = True # type: ignore
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
return anno_cls
-def _prepare_annotations(target_hierarchy, base_cls):
+def _prepare_annotations(
+ target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated]
+) -> None:
for cls in util.walk_subclasses(target_hierarchy):
_new_annotation_type(cls, base_cls)
import operator
import re
import typing
+from typing import MutableMapping
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import TypeVar
from . import roles
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
-from ..util._has_cy import HAS_CYEXTENSION
-
-if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_util import prefix_anon_map # noqa
-else:
- from sqlalchemy.cyextension.util import prefix_anon_map # noqa
if typing.TYPE_CHECKING:
+ from .elements import ColumnElement
from ..engine import Connection
from ..engine import Result
from ..engine.interfaces import _CoreMultiExecuteParams
# symbols, mypy reports: "error: _Fn? not callable"
_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_AmbiguousTableNameMap = MutableMapping[str, str]
+
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
_is_singleton_constant = True
+ _singleton: SingletonConstant
+
+ proxy_set: Set[ColumnElement]
+
def __new__(cls, *arg, **kw):
return cls._singleton
plugins = {}
+ _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
+
@classmethod
def create_for_statement(cls, statement, compiler, **kw):
# factory construction.
from itertools import zip_longest
import typing
from typing import Any
-from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
from typing import NamedTuple
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import Union
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import HasTraversalDispatch
+from .visitors import HasTraverseInternals
from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
from .. import util
from ..inspection import inspect
from ..util import HasMemoized
from ..util.typing import Literal
-
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from .elements import BindParameter
+ from .elements import ClauseElement
+ from .visitors import _TraverseInternalsType
+ from ..engine.base import _CompiledCacheType
+ from ..engine.interfaces import _CoreSingleExecuteParams
+
+
+class _CacheKeyTraversalDispatchType(Protocol):
+ def __call__(
+ s, self: HasCacheKey, visitor: _CacheKeyTraversal
+ ) -> CacheKey:
+ ...
class CacheConst(enum.Enum):
__slots__ = ()
- _cache_key_traversal = NO_CACHE
+ _cache_key_traversal: Union[
+ _TraverseInternalsType, Literal[CacheConst.NO_CACHE]
+ ] = NO_CACHE
_is_has_cache_key = True
"""
- inherit_cache = None
+ inherit_cache: Optional[bool] = None
"""Indicate if this :class:`.HasCacheKey` instance should make use of the
cache key generation scheme used by its immediate superclass.
__slots__ = ()
+ _generated_cache_key_traversal: Any
+
@classmethod
- def _generate_cache_attrs(cls):
+ def _generate_cache_attrs(
+ cls,
+ ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]:
"""generate cache key dispatcher for a new class.
This sets the _generated_cache_key_traversal attribute once called
_cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
if _cache_key_traversal is None:
try:
- # this would be HasTraverseInternals
- _cache_key_traversal = cls._traverse_internals
+ # check for _traverse_internals, which is part of
+ # HasTraverseInternals
+ _cache_key_traversal = cast(
+ "Type[HasTraverseInternals]", cls
+ )._traverse_internals
except AttributeError:
cls._generated_cache_key_traversal = NO_CACHE
return NO_CACHE
# more complicated, so for the moment this is a little less
# efficient on startup but simpler.
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
else:
_cache_key_traversal = cls.__dict__.get(
return NO_CACHE
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
@util.preload_module("sqlalchemy.sql.elements")
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key(
+ self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+ ) -> Optional[Tuple[Any, ...]]:
"""return an optional cache key.
The cache key is a tuple which can contain any series of
dispatcher: Union[
Literal[CacheConst.NO_CACHE],
- Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"],
+ _CacheKeyTraversalDispatchType,
]
try:
dispatcher = cls.__dict__["_generated_cache_key_traversal"]
except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
+ # traversals.py -> _preconfigure_traversals()
+ # may be used to run these ahead of time, but
+ # is not enabled right now.
# this block will generate any remaining dispatchers.
dispatcher = cls._generate_cache_attrs()
anon_map[NO_CACHE] = True
return None
- result = (id_, cls)
+ result: Tuple[Any, ...] = (id_, cls)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
# Columns, this should be long lived. For select()
# statements, not so much, but they usually won't have
# annotations.
- result += self._annotations_cache_key
+ result += self._annotations_cache_key # type: ignore
elif (
meth is InternalTraversal.dp_clauseelement_list
or meth is InternalTraversal.dp_clauseelement_tuple
)
return result
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
"""return a cache key.
The cache key is a tuple which can contain any series of
"""
- bindparams = []
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = self._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
@classmethod
- def _generate_cache_key_for_object(cls, obj):
- bindparams = []
+ def _generate_cache_key_for_object(
+ cls, obj: HasCacheKey
+ ) -> Optional[CacheKey]:
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = obj._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
+class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey):
+ pass
+
+
class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
__slots__ = ()
@HasMemoized.memoized_instancemethod
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
return HasCacheKey._generate_cache_key(self)
"""
key: Tuple[Any, ...]
- bindparams: Sequence[BindParameter]
+ bindparams: Sequence[BindParameter[Any]]
- def __hash__(self):
+ # can't set __hash__ attribute because it interferes
+ # with namedtuple
+ # can't use "if not TYPE_CHECKING" because mypy rejects it
+ # inside of a NamedTuple
+ def __hash__(self) -> Optional[int]: # type: ignore
"""CacheKey itself is not hashable - hash the .key portion"""
-
return None
- def to_offline_string(self, statement_cache, statement, parameters):
+ def to_offline_string(
+ self,
+ statement_cache: _CompiledCacheType,
+ statement: ClauseElement,
+ parameters: _CoreSingleExecuteParams,
+ ) -> str:
"""Generate an "offline string" form of this :class:`.CacheKey`
The "offline string" is basically the string SQL for the
return repr((sql_str, param_tuple))
- def __eq__(self, other):
- return self.key == other.key
+ def __eq__(self, other: Any) -> bool:
+ return bool(self.key == other.key)
@classmethod
- def _diff_tuples(cls, left, right):
+ def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
ck1 = CacheKey(left, [])
ck2 = CacheKey(right, [])
return ck1._diff(ck2)
- def _whats_different(self, other):
+ def _whats_different(self, other: CacheKey) -> Iterator[str]:
k1 = self.key
k2 = other.key
- stack = []
+ stack: List[int] = []
pickup_index = 0
while True:
s1, s2 = k1, k2
pickup_index = stack.pop(-1)
break
- def _diff(self, other):
+ def _diff(self, other: CacheKey) -> str:
return ", ".join(self._whats_different(other))
- def __str__(self):
- stack = [self.key]
+ def __str__(self) -> str:
+ stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key]
output = []
sentinel = object()
return "CacheKey(key=%s)" % ("\n".join(output),)
- def _generate_param_dict(self):
+ def _generate_param_dict(self) -> Dict[str, Any]:
"""used for testing"""
- from .compiler import prefix_anon_map
-
_anon_map = prefix_anon_map()
return {b.key % _anon_map: b.effective_value for b in self.bindparams}
- def _apply_params_to_element(self, original_cache_key, target_element):
+ def _apply_params_to_element(
+ self, original_cache_key: CacheKey, target_element: ClauseElement
+ ) -> ClauseElement:
translate = {
k.key: v.value
for k, v in zip(original_cache_key.bindparams, self.bindparams)
return target_element.params(translate)
-class _CacheKeyTraversal(ExtendedInternalTraversal):
+class _CacheKeyTraversal(HasTraversalDispatch):
# very common elements are inlined into the main _get_cache_key() method
# to produce a dramatic savings in Python function call overhead
visit_propagate_attrs = PROPAGATE_ATTRS
def visit_with_context_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple((fn.__code__, c_key) for fn, c_key in obj)
- def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_inspectable(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
- def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_string_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(obj)
- def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
obj._gen_cache_key(anon_map, bindparams)
else obj,
)
- def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
)
def visit_has_cache_key_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
)
def visit_has_cache_key_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
)
def visit_executable_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
)
def visit_inspectable_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_list(
attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
)
def visit_clauseelement_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_tuples(
attrname, obj, parent, anon_map, bindparams
)
def visit_fromclause_ordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
)
def visit_clauseelement_unordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
cache_keys = [
)
def visit_named_ddl_element(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, obj.name)
def visit_prefix_sequence(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
)
def visit_setup_join_tuple(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(
(
target._gen_cache_key(anon_map, bindparams),
)
def visit_table_hint_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
),
)
- def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_plain_dict(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
def visit_dialect_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
)
def visit_string_clauseelement_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
)
def visit_string_multi_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
)
def visit_fromclause_canonical_column_collection(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# inlining into the internals of ColumnCollection
return (
attrname,
)
def visit_unknown_structure(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
anon_map[NO_CACHE] = True
return ()
def visit_dml_ordered_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
),
)
- def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_dml_values(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# in py37 we can assume two dictionaries created in the same
# insert ordering will retain that sorting
return (
)
def visit_dml_multi_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# multivalues are simply not cacheable right now
anon_map[NO_CACHE] = True
return ()
import typing
from typing import Any
from typing import Any as TODO_Any
+from typing import Dict
+from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Type
from typing import TypeVar
from . import selectable
from . import traversals
from .elements import ClauseElement
+ from .elements import ColumnClause
_SR = TypeVar("_SR", bound=roles.SQLRole)
_StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
if isinstance(resolved, str):
strname = resolved = expr
else:
- cols = []
+ cols: List[ColumnClause[Any]] = []
visitors.traverse(resolved, {}, {"column": cols.append})
if cols:
column = cols[0]
def _literal_coercion(self, element, **kw):
raise NotImplementedError()
- _post_coercion = None
+ _post_coercion: Any = None
_resolve_literal_only = False
_skip_clauseelement_for_target_match = False
self._use_inspection = issubclass(role_class, roles.UsesInspection)
def _implicit_coercions(
- self, element, resolved, argname=None, **kw
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
) -> Any:
self._raise_for_expected(element, argname, resolved)
def _raise_for_expected(
self,
- element,
- argname=None,
- resolved=None,
- advice=None,
- code=None,
- err=None,
- ):
+ element: Any,
+ argname: Optional[str] = None,
+ resolved: Optional[Any] = None,
+ advice: Optional[str] = None,
+ code: Optional[str] = None,
+ err: Optional[Exception] = None,
+ **kw: Any,
+ ) -> NoReturn:
if resolved is not None and resolved is not element:
got = "%r object resolved from %r object" % (resolved, element)
else:
_resolve_literal_only = True
-class _ReturnsStringKey:
+class _ReturnsStringKey(RoleImpl):
__slots__ = ()
- def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, str):
- return original_element
+ def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ if isinstance(element, str):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
-class _ColumnCoercions:
+class _ColumnCoercions(RoleImpl):
__slots__ = ()
def _warn_for_scalar_subquery_coercion(self):
def _no_text_coercion(
- element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None
-):
+ element: Any,
+ argname: Optional[str] = None,
+ exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError,
+ extra: Optional[str] = None,
+ err: Optional[Exception] = None,
+) -> NoReturn:
raise exc_cls(
"%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
"explicitly declared as text(%(expr)r)"
) from err
-class _NoTextCoercion:
+class _NoTextCoercion(RoleImpl):
__slots__ = ()
def _literal_coercion(self, element, argname=None, **kw):
self._raise_for_expected(element, argname)
-class _CoerceLiterals:
+class _CoerceLiterals(RoleImpl):
__slots__ = ()
_coerce_consts = False
_coerce_star = False
return element
-class _SelectIsNotFrom:
+class _SelectIsNotFrom(RoleImpl):
__slots__ = ()
def _raise_for_expected(
- self, element, argname=None, resolved=None, advice=None, **kw
- ):
+ self,
+ element: Any,
+ argname: Optional[str] = None,
+ resolved: Optional[Any] = None,
+ advice: Optional[str] = None,
+ code: Optional[str] = None,
+ err: Optional[Exception] = None,
+ **kw: Any,
+ ) -> NoReturn:
if (
not advice
and isinstance(element, roles.SelectStatementRole)
else:
code = None
- return super(_SelectIsNotFrom, self)._raise_for_expected(
+ super()._raise_for_expected(
element,
argname=argname,
resolved=resolved,
advice=advice,
code=code,
+ err=err,
**kw,
)
+ # never reached
+ assert False
class HasCacheKeyImpl(RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, traversals.HasCacheKey):
- return original_element
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, HasCacheKey):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, ExecutableOption):
- return original_element
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, ExecutableOption):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
self._warn_for_implicit_coercion(resolved)
return self._post_coercion(resolved.select(), **kw)
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _warn_for_implicit_coercion(self, elem):
util.warn(
if isinstance(element, collections_abc.Iterable) and not isinstance(
element, str
):
- non_literal_expressions = {}
+ non_literal_expressions: Dict[
+ Optional[operators.ColumnOperators[Any]],
+ operators.ColumnOperators[Any],
+ ] = {}
element = list(element)
for o in element:
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
self._raise_for_expected(element, **kw)
+
else:
non_literal_expressions[o] = o
elif o is None:
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if isinstance(resolved, roles.StrictFromClauseRole):
return elements.ClauseList(*resolved.c)
else:
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, str):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, str):
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, argname=None, **kw):
"""coerce the given value to :class:`._truncated_label`.
class LimitOffsetImpl(RoleImpl):
__slots__ = ()
- def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ def _implicit_coercions(
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved is None:
return None
else:
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if isinstance(resolved, roles.ExpressionElementRole):
return resolved.label(None)
else:
new = super(LabeledColumnExprImpl, self)._implicit_coercions(
- original_element, resolved, argname=argname, **kw
+ element, resolved, argname=argname, **kw
)
if isinstance(new, roles.ExpressionElementRole):
return new.label(None)
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
return resolved
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_lambda_element:
return resolved
else:
- return super(StatementImpl, self)._implicit_coercions(
- original_element, resolved, argname=argname, **kw
+ return super()._implicit_coercions(
+ element, resolved, argname=argname, **kw
)
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_text_clause:
return resolved.columns()
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class HasCTEImpl(ReturnsRowsImpl):
self._raise_for_expected(element, argname)
def _implicit_coercions(
- self, original_element, resolved, argname=None, legacy=False, **kw
- ):
- if isinstance(original_element, roles.JoinTargetRole):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ legacy: bool = False,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, roles.JoinTargetRole):
# note that this codepath no longer occurs as of
# #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
# were set to False.
- return original_element
+ return element
elif legacy and resolved._is_select_statement:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT "
# in _ORMJoin->Join
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
def _implicit_coercions(
self,
- original_element,
- resolved,
- argname=None,
- explicit_subquery=False,
- allow_select=True,
- **kw,
- ):
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ explicit_subquery: bool = False,
+ allow_select: bool = True,
+ **kw: Any,
+ ) -> Any:
if resolved._is_select_statement:
if explicit_subquery:
return resolved.subquery()
elif resolved._is_text_clause:
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _post_coercion(self, element, deannotate=False, **kw):
if deannotate:
def _implicit_coercions(
self,
- original_element,
- resolved,
- argname=None,
- allow_select=False,
- **kw,
- ):
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ explicit_subquery: bool = False,
+ allow_select: bool = False,
+ **kw: Any,
+ ) -> Any:
if resolved._is_select_statement and allow_select:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT constructs "
)
return resolved._implicit_subquery
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class AnonymizedFromClauseImpl(StrictFromClauseImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
else:
return resolved.select()
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class CompoundElementImpl(_NoTextCoercion, RoleImpl):
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import FrozenSet
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
from typing import Union
from . import base
from . import schema
from . import selectable
from . import sqltypes
+from .base import _from_objects
from .base import NO_ARG
-from .base import prefix_anon_map
from .elements import quoted_name
from .schema import Column
+from .sqltypes import TupleType
from .type_api import TypeEngine
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .base import _AmbiguousTableNameMap
+ from .base import CompileState
+ from .cache_key import CacheKey
+ from .elements import BindParameter
+ from .elements import ColumnClause
+ from .elements import Label
+ from .functions import Function
+ from .selectable import Alias
+ from .selectable import AliasedReturnsRows
+ from .selectable import CompoundSelectState
from .selectable import CTE
from .selectable import FromClause
+ from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
+ from .selectable import Select
+ from .selectable import SelectState
+ from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _MutableCoreSingleExecuteParams
+ from ..engine.interfaces import _SchemaTranslateMapType
from ..engine.result import _ProcessorType
_FromHintsType = Dict["FromClause", str]
operators.nulls_last_op: " NULLS LAST",
}
-FUNCTIONS = {
+FUNCTIONS: Dict[Type[Function], str] = {
functions.coalesce: "coalesce",
functions.current_date: "CURRENT_DATE",
functions.current_time: "CURRENT_TIME",
name: str
"""column name, may be labeled"""
- objects: List[Any]
- """list of objects that should be able to locate this column
+ objects: Tuple[Any, ...]
+ """sequence of objects that should be able to locate this column
in a RowMapping. This is typically string names and aliases
as well as Column objects.
"""
+class _ResultMapAppender(Protocol):
+ def __call__(
+ self,
+ keyname: str,
+ name: str,
+ objects: Sequence[Any],
+ type_: TypeEngine[Any],
+ ) -> None:
+ ...
+
+
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0
RM_TYPE: Literal[3] = 3
+class _BaseCompilerStackEntry(TypedDict):
+ asfrom_froms: Set[FromClause]
+ correlate_froms: Set[FromClause]
+ selectable: ReturnsRows
+
+
+class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
+ compile_state: CompileState
+ need_result_map_for_nested: bool
+ need_result_map_for_compound: bool
+ select_0: ReturnsRows
+ insert_from_select: Select
+
+
class ExpandedState(NamedTuple):
statement: str
additional_parameters: _CoreSingleExecuteParams
defaults.
"""
- _cached_metadata = None
+ _cached_metadata: Optional[CursorResultMetaData] = None
_result_columns: Optional[List[ResultColumnsEntry]] = None
- schema_translate_map = None
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None
- execution_options = util.EMPTY_DICT
+ execution_options: _ExecuteOptions = util.EMPTY_DICT
"""
Execution options propagated from the statement. In some cases,
sub-elements of the statement can modify these.
"""
- _annotations = util.EMPTY_DICT
+ preparer: IdentifierPreparer
+
+ _annotations: _AnnotationDict = util.EMPTY_DICT
- compile_state = None
+ compile_state: Optional[CompileState] = None
"""Optional :class:`.CompileState` object that maintains additional
state used by the compiler.
"""
- cache_key = None
+ cache_key: Optional[CacheKey] = None
+ """The :class:`.CacheKey` that was generated ahead of creating this
+ :class:`.Compiled` object.
+
+ This is used for routines that need access to the original
+ :class:`.CacheKey` instance generated when the :class:`.Compiled`
+ instance was first cached, typically in order to reconcile
+ the original list of :class:`.BindParameter` objects with a
+ per-statement list that's generated on each call.
+
+ """
_gen_time: float
+ """Generation time of this :class:`.Compiled`, used for reporting
+ cache stats."""
def __init__(
self,
return self.string or ""
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
"""Return the bind params for this compiled object.
:param params: a dict of string/object pairs whose values will
isplaintext: bool = False
+ binds: Dict[str, BindParameter[Any]]
+ """a dictionary of bind parameter keys to BindParameter instances."""
+
+ bind_names: Dict[BindParameter[Any], str]
+ """a dictionary of BindParameter instances to "compiled" names
+ that are actually present in the generated SQL"""
+
+ stack: List[_CompilerStackEntry]
+ """major statements such as SELECT, INSERT, UPDATE, DELETE are
+ tracked in this stack using an entry format."""
+
result_columns: List[ResultColumnsEntry]
"""relates label names in the final SQL to a tuple of local
column/label name, ColumnElement object (if any) and
"""
- insert_single_values_expr = None
+ insert_single_values_expr: Optional[str] = None
"""When an INSERT is compiled with a single set of parameters inside
a VALUES expression, the string is assigned here, where it can be
used for insert batching schemes to rewrite the VALUES expression.
"""
- literal_execute_params = frozenset()
+ literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as literal values at statement
execution time.
"""
- post_compile_params = frozenset()
+ post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as bound parameter placeholders
at statement execution time.
"""
- escaped_bind_names = util.EMPTY_DICT
+ escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
"""Late escaping of bound parameter names that has to be converted
to the original name when looking in the parameter dictionary.
"""if True, and this in insert, use cursor.lastrowid to populate
result.inserted_primary_key. """
- _cache_key_bind_match = None
+ _cache_key_bind_match: Optional[
+ Tuple[
+ Dict[
+ BindParameter[Any],
+ List[BindParameter[Any]],
+ ],
+ Dict[
+ str,
+ BindParameter[Any],
+ ],
+ ]
+ ] = None
"""a mapping that will relate the BindParameter object we compile
to those that are part of the extracted collection of parameters
in the cache key, if we were given a cache key.
"""
- positiontup: Optional[Sequence[str]] = None
+ positiontup: Optional[List[str]] = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
inline: bool = False
+ ctes: Optional[MutableMapping[CTE, str]]
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ ctes_by_level_name: Dict[Tuple[int, str], CTE]
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name, cte_opts)]
+ level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
+
+ ctes_recursive: bool
+ cte_positional: Dict[CTE, List[str]]
+
def __init__(
self,
dialect,
self.cache_key = cache_key
if cache_key:
- self._cache_key_bind_match = ckbm = {
- b.key: b for b in cache_key[1]
- }
- ckbm.update({b: [b] for b in cache_key[1]})
+ cksm = {b.key: b for b in cache_key[1]}
+ ckbm = {b: [b] for b in cache_key[1]}
+ self._cache_key_bind_match = (ckbm, cksm)
# compile INSERT/UPDATE defaults/sequences to expect executemany
# style execution, which may mean no pre-execute of defaults,
@property
def prefetch(self):
- return list(self.insert_prefetch + self.update_prefetch)
+ return list(self.insert_prefetch) + list(self.update_prefetch)
@util.memoized_property
def _global_attributes(self):
return {}
@util.memoized_instancemethod
- def _init_cte_state(self) -> None:
+ def _init_cte_state(self) -> MutableMapping[CTE, str]:
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
"""
# collect CTEs to tack on top of a SELECT
# To store the query to print - Dict[cte, text_query]
- self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ self.ctes = ctes
# Detect same CTE references - Dict[(level, name), cte]
# Level is required for supporting nesting
- self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
+ self.ctes_by_level_name = {}
# To retrieve key/level in ctes_by_level_name -
# Dict[cte_reference, (level, cte_name, cte_opts)]
- self.level_name_by_cte: Dict[
- CTE, Tuple[int, str, selectable._CTEOpts]
- ] = {}
+ self.level_name_by_cte = {}
- self.ctes_recursive: bool = False
+ self.ctes_recursive = False
if self.positional:
- self.cte_positional: Dict[CTE, List[str]] = {}
+ self.cte_positional = {}
+
+ return ctes
@contextlib.contextmanager
def _nested_result(self):
if not bindparam.type._is_tuple_type
else tuple(
elem_type._cached_bind_processor(self.dialect)
- for elem_type in bindparam.type.types
+ for elem_type in cast(TupleType, bindparam.type).types
),
)
for bindparam in self.bind_names
def construct_params(
self,
- params=None,
- _group_number=None,
- _check=True,
- extracted_parameters=None,
- ):
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ _group_number: Optional[int] = None,
+ _check: bool = True,
+ ) -> _MutableCoreSingleExecuteParams:
"""return a dictionary of bind parameter keys and values"""
has_escaped_names = bool(self.escaped_bind_names)
# way. The parameters present in self.bind_names may be clones of
# these original cache key params in the case of DML but the .key
# will be guaranteed to match.
- try:
- orig_extracted = self.cache_key[1]
- except TypeError as err:
+ if self.cache_key is None:
raise exc.CompileError(
"This compiled object has no original cache key; "
"can't pass extracted_parameters to construct_params"
- ) from err
+ )
+ else:
+ orig_extracted = self.cache_key[1]
- ckbm = self._cache_key_bind_match
+ ckbm_tuple = self._cache_key_bind_match
+ assert ckbm_tuple is not None
+ ckbm, _ = ckbm_tuple
resolved_extracted = {
bind: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
if bindparam.type._is_tuple_type:
inputsizes[bindparam] = [
- lookup_type(typ) for typ in bindparam.type.types
+ lookup_type(typ)
+ for typ in cast(TupleType, bindparam.type).types
]
else:
inputsizes[bindparam] = lookup_type(bindparam.type)
def _process_parameters_for_postcompile(
self,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_MutableCoreSingleExecuteParams] = None,
_populate_self: bool = False,
) -> ExpandedState:
"""handle special post compile parameters.
parameters = self.construct_params()
expanded_parameters = {}
+ positiontup: Optional[List[str]]
+
if self.positional:
positiontup = []
else:
positiontup = None
processors = self._bind_processors
+ single_processors = cast("Mapping[str, _ProcessorType]", processors)
+ tuple_processors = cast(
+ "Mapping[str, Sequence[_ProcessorType]]", processors
+ )
- new_processors = {}
+ new_processors: Dict[str, _ProcessorType] = {}
if self.positional and self._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'.
"the 'numeric' paramstyle at this time."
)
- replacement_expressions = {}
- to_update_sets = {}
+ replacement_expressions: Dict[str, Any] = {}
+ to_update_sets: Dict[str, Any] = {}
# notes:
# *unescaped* parameter names in:
# *escaped* parameter names in:
# construct_params(), replacement_expressions
- for name in (
- self.positiontup if self.positional else self.bind_names.values()
- ):
+ if self.positional and self.positiontup is not None:
+ names: Iterable[str] = self.positiontup
+ else:
+ names = self.bind_names.values()
+
+ for name in names:
escaped_name = (
self.escaped_bind_names.get(name, name)
if self.escaped_bind_names
if parameter in self.post_compile_params:
if escaped_name in replacement_expressions:
to_update = to_update_sets[escaped_name]
+ values = None
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
if not parameter.literal_execute:
parameters.update(to_update)
if parameter.type._is_tuple_type:
+ assert values is not None
new_processors.update(
(
"%s_%s_%s" % (name, i, j),
- processors[name][j - 1],
+ tuple_processors[name][j - 1],
)
for i, tuple_element in enumerate(values, 1)
- for j, value in enumerate(tuple_element, 1)
- if name in processors
- and processors[name][j - 1] is not None
+ for j, _ in enumerate(tuple_element, 1)
+ if name in tuple_processors
+ and tuple_processors[name][j - 1] is not None
)
else:
new_processors.update(
- (key, processors[name])
- for key, value in to_update
- if name in processors
+ (key, single_processors[name])
+ for key, _ in to_update
+ if name in single_processors
)
- if self.positional:
- positiontup.extend(name for name, value in to_update)
+ if positiontup is not None:
+ positiontup.extend(name for name, _ in to_update)
expanded_parameters[name] = [
- expand_key for expand_key, value in to_update
+ expand_key for expand_key, _ in to_update
]
- elif self.positional:
+ elif positiontup is not None:
positiontup.append(name)
def process_expanding(m):
# special use cases.
self.string = expanded_state.statement
self._bind_processors.update(expanded_state.processors)
- self.positiontup = expanded_state.positiontup
+ self.positiontup = list(expanded_state.positiontup or ())
self.post_compile_params = frozenset()
for key in expanded_state.parameter_expansion:
bind = self.binds.pop(key)
self._result_columns
)
+ _key_getters_for_crud_column: Tuple[
+ Callable[[Union[str, Column[Any]]], str],
+ Callable[[Column[Any]], str],
+ Callable[[Column[Any]], str],
+ ]
+
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
getter = self._key_getters_for_crud_column[2]
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_returning_getter(self):
- result = util.preloaded.engine_result
+ if typing.TYPE_CHECKING:
+ from ..engine import result
+ else:
+ result = util.preloaded.engine_result
param_key_getter = self._within_exec_param_key_getter
table = self.statement.table
- ret = {col: idx for idx, col in enumerate(self.returning)}
+ returning = self.returning
+ assert returning is not None
+ ret = {col: idx for idx, col in enumerate(returning)}
- getters = [
- (operator.itemgetter(ret[col]), True)
- if col in ret
- else (
- operator.methodcaller("get", param_key_getter(col), None),
- False,
- )
- for col in table.primary_key
- ]
+ getters = cast(
+ "List[Tuple[Callable[[Any], Any], bool]]",
+ [
+ (operator.itemgetter(ret[col]), True)
+ if col in ret
+ else (
+ operator.methodcaller("get", param_key_getter(col), None),
+ False,
+ )
+ for col in table.primary_key
+ ],
+ )
row_fn = result.result_tuple([col.key for col in table.primary_key])
self, element, within_columns_clause=False, **kwargs
):
if self.stack and self.dialect.supports_simple_order_by_label:
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ raise exc.CompileError(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ) from ke
(
with_cols,
# compiling the element outside of the context of a SELECT
return self.process(element._text_clause)
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ coercions._no_text_coercion(
+ element.element,
+ extra=(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ),
+ exc_cls=exc.CompileError,
+ err=ke,
+ )
+
with_cols, only_froms, only_cols = compile_state._label_resolve_dict
try:
if within_columns_clause:
def visit_column(
self,
- column,
- add_to_result_map=None,
- include_table=True,
- result_map_targets=(),
- ambiguous_table_name_map=None,
- **kwargs,
- ):
+ column: ColumnClause[Any],
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ include_table: bool = True,
+ result_map_targets: Tuple[Any, ...] = (),
+ ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+ **kwargs: Any,
+ ) -> str:
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
)
else:
schema_prefix = ""
- tablename = table.name
+
+ tablename = cast("NamedFromClause", table).name
if (
not effective_schema
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"correlate_froms": set(),
"asfrom_froms": set(),
"selectable": taf,
compiled_col = self.visit_column(element, **kw)
return "(%s).%s" % (compiled_fn, compiled_col)
- def visit_function(self, func, add_to_result_map=None, **kwargs):
+ def visit_function(
+ self,
+ func: Function,
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ **kwargs: Any,
+ ) -> str:
if add_to_result_map is not None:
add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+
+ text: str
+
if disp:
text = disp(func, **kwargs)
else:
if compound_stmt._independent_ctes:
self._dispatch_independent_ctes(compound_stmt, kwargs)
- keyword = self.compound_keywords.get(cs.keyword)
+ keyword = self.compound_keywords[cs.keyword]
text = (" " + keyword + " ").join(
(
# a different set of parameter values. here, we accommodate for
# parameters that may have been cloned both before and after the cache
# key was been generated.
- ckbm = self._cache_key_bind_match
- if ckbm:
+ ckbm_tuple = self._cache_key_bind_match
+
+ if ckbm_tuple:
+ ckbm, cksm = ckbm_tuple
for bp in bindparam._cloned_set:
- if bp.key in ckbm:
- cb = ckbm[bp.key]
+ if bp.key in cksm:
+ cb = cksm[bp.key]
ckbm[cb].append(bindparam)
if bindparam.isoutparam:
if positional_names is not None:
positional_names.append(name)
else:
- self.positiontup.append(name)
+ self.positiontup.append(name) # type: ignore[union-attr]
elif not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
name = new_name
if escaped_from:
- if not self.escaped_bind_names:
- self.escaped_bind_names = {}
- self.escaped_bind_names[escaped_from] = name
+ self.escaped_bind_names = self.escaped_bind_names.union(
+ {escaped_from: name}
+ )
if post_compile:
return "__[POSTCOMPILE_%s]" % name
cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
**kwargs: Any,
) -> Optional[str]:
- self._init_cte_state()
+ self_ctes = self._init_cte_state()
+ assert self_ctes is self.ctes
kwargs["visiting_cte"] = cte
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
- del self.ctes[existing_cte]
+ del self_ctes[existing_cte]
existing_cte_reference_cte = existing_cte._get_reference_cte()
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
- if not cte_pre_alias_name and cte not in self.ctes:
+ if not cte_pre_alias_name and cte not in self_ctes:
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
cte, cte._suffixes, **kwargs
)
- self.ctes[cte] = text
+ self_ctes[cte] = text
if asfrom:
if from_linter:
from_linter.froms[cte] = cte_name
if not is_new_cte and embedded_in_current_named_cte:
- return self.preparer.format_alias(cte, cte_name)
+ return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501
if cte_pre_alias_name:
text = self.preparer.format_alias(cte, cte_pre_alias_name)
else:
return self.preparer.format_alias(cte, cte_name)
+ return None
+
def visit_table_valued_alias(self, element, **kw):
if element._is_lateral:
return self.visit_lateral(element, **kw)
self,
keyname: str,
name: str,
- objects: List[Any],
+ objects: Tuple[Any, ...],
type_: TypeEngine[Any],
) -> None:
if keyname is None or keyname == "*":
def get_statement_hint_text(self, hint_texts):
return " ".join(hint_texts)
- _default_stack_entry = util.immutabledict(
- [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
- )
+ _default_stack_entry: _CompilerStackEntry
+
+ if not typing.TYPE_CHECKING:
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(
self, select_stmt, asfrom, lateral=False, **kw
)
return froms
- translate_select_structure = None
+ translate_select_structure: Any = None
"""if not ``None``, should be a callable which accepts ``(select_stmt,
**kw)`` and returns a select object. this is used for structural changes
mostly to accommodate for LIMIT/OFFSET schemes
)
self._result_columns = [
- (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ ResultColumnsEntry(
+ key, name, tuple(translate.get(o, o) for o in obj), type_
+ )
for key, name, obj, type_ in self._result_columns
]
implicit_correlate_froms=asfrom_froms,
)
- new_correlate_froms = set(selectable._from_objects(*froms))
+ new_correlate_froms = set(_from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"asfrom_froms": new_correlate_froms,
"correlate_froms": all_correlate_froms,
"selectable": select,
text += " \nWHERE " + t
if warn_linting:
+ assert from_linter is not None
from_linter.warn()
if select._group_by_clauses:
if not self.ctes:
return ""
+ ctes: MutableMapping[CTE, str]
+
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
+ assert self.positiontup is not None
self.positiontup = (
- sum([self.cte_positional[cte] for cte in ctes], [])
+ list(
+ itertools.chain.from_iterable(
+ self.cte_positional[cte] for cte in ctes
+ )
+ )
+ self.positiontup
)
+
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
if is_multitable:
# main table might be a JOIN
- main_froms = set(selectable._from_objects(update_stmt.table))
+ main_froms = set(_from_objects(update_stmt.table))
render_extra_froms = [
f for f in extra_froms if f not in main_froms
]
def type_compiler(self):
return self.dialect.type_compiler
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
return None
def visit_ddl(self, ddl, **kwargs):
return get_col_spec(**kw)
+class _SchemaForObjectCallable(Protocol):
+ def __call__(self, obj: Any) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- schema_for_object = operator.attrgetter("schema")
+ initial_quote: str
+
+ final_quote: str
+
+ _strings: MutableMapping[str, str]
+
+ schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
"""Return the .schema attribute for an object.
For the default IdentifierPreparer, the schema for an object is always
return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
- def _escape_identifier(self, value):
+ def _escape_identifier(self, value: str) -> str:
"""Escape an identifier.
Subclasses should override this to provide database-dependent
value = value.replace("%", "%%")
return value
- def _unescape_identifier(self, value):
+ def _unescape_identifier(self, value: str) -> str:
"""Canonicalize an escaped identifier.
Subclasses should override this to provide database-dependent
)
return element
- def quote_identifier(self, value):
+ def quote_identifier(self, value: str) -> str:
"""Quote an identifier.
Subclasses should override this to provide database-dependent
+ self.final_quote
)
- def _requires_quotes(self, value):
+ def _requires_quotes(self, value: str) -> bool:
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (
not taking case convention into account."""
return not self.legal_characters.match(str(value))
- def quote_schema(self, schema, force=None):
+ def quote_schema(self, schema: str, force: Any = None) -> str:
"""Conditionally quote a schema name.
return self.quote(schema)
- def quote(self, ident, force=None):
+ def quote(self, ident: str, force: Any = None) -> str:
"""Conditionally quote an identifier.
The identifier is quoted if it is a reserved word, contains
name = self.quote_schema(effective_schema) + "." + name
return name
- def format_label(self, label, name=None):
+ def format_label(
+ self, label: Label[Any], name: Optional[str] = None
+ ) -> str:
return self.quote(name or label.name)
- def format_alias(self, alias, name=None):
- return self.quote(name or alias.name)
+ def format_alias(
+ self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
+ ) -> str:
+ if name is None:
+ assert alias is not None
+ return self.quote(alias.name)
+ else:
+ return self.quote(name)
def format_savepoint(self, savepoint, name=None):
# Running the savepoint name through quoting is unnecessary
import collections.abc as collections_abc
import typing
+from typing import Any
+from typing import List
+from typing import MutableMapping
+from typing import Optional
from . import coercions
from . import roles
class DMLState(CompileState):
_no_parameters = True
- _dict_parameters = None
- _multi_parameters = None
+ _dict_parameters: Optional[MutableMapping[str, Any]] = None
+ _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None
_ordered_values = None
_parameter_ordering = None
_has_multi_parameters = False
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import Generic
+from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from .operators import ColumnOperators
from .traversals import HasCopyInternals
from .visitors import cloned_traverse
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .visitors import traverse
from .visitors import Visitable
from ..engine import Connection
from ..engine import Dialect
from ..engine import Engine
+ from ..engine.base import _CompiledCacheType
+ from ..engine.base import _SchemaTranslateMapType
_NUMERIC = Union[complex, "Decimal"]
SupportsWrappingAnnotations,
MemoizedHasCacheKey,
HasCopyInternals,
+ ExternallyTraversible,
CompilerElement,
):
"""Base class for elements of a programmatically constructed SQL
"""
return self._replace_params(True, optionaldict, kwargs)
- def params(self, *optionaldict, **kwargs):
+ def params(
+ self, *optionaldict: Dict[str, Any], **kwargs: Any
+ ) -> ClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
replaced.
"""
return self._replace_params(False, optionaldict, kwargs)
- def _replace_params(self, unique, optionaldict, kwargs):
+ def _replace_params(
+ self,
+ unique: bool,
+ optionaldict: Optional[Dict[str, Any]],
+ kwargs: Dict[str, Any],
+ ) -> ClauseElement:
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
def _compile_w_cache(
self,
- dialect,
- compiled_cache=None,
- column_keys=None,
- for_executemany=False,
- schema_translate_map=None,
- **kw,
+ dialect: Dialect,
+ compiled_cache: Optional[_CompiledCacheType] = None,
+ column_keys: Optional[List[str]] = None,
+ for_executemany: bool = False,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ **kw: Any,
):
if compiled_cache is not None and dialect._supports_statement_cache:
elem_cache_key = self._generate_cache_key()
"""
return Cast(self, type_)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[_T]:
"""Produce a column label, i.e. ``<columnname> AS <name>``.
This is a shortcut to the :func:`_expression.label` function.
("value", InternalTraversal.dp_plain_obj),
]
+ key: str
+ type: TypeEngine
+
_is_crud = False
_is_bind_parameter = True
_key_is_anon = False
from __future__ import annotations
from typing import Any
+from typing import Sequence
from typing import TypeVar
from . import annotation
identifier: str
+ packagenames: Sequence[str]
+
type: TypeEngine = sqltypes.NULLTYPE
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
from __future__ import annotations
import typing
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
-from sqlalchemy.util.langhelpers import TypingOnly
from .. import util
-
+from ..util import TypingOnly
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
+ from .base import ColumnCollection
from .elements import ClauseElement
+ from .elements import Label
from .selectable import FromClause
+ from .selectable import Subquery
class SQLRole:
class UsesInspection:
__slots__ = ()
- _post_inspect = None
+ _post_inspect: Literal[None] = None
uses_inspection = True
_role_name = "Column expression or FROM clause"
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> Sequence[ColumnsClauseRole]:
raise NotImplementedError()
__slots__ = ()
_role_name = "SQL expression element"
+ def label(self, name: Optional[str]) -> Label[Any]:
+ raise NotImplementedError()
+
class ConstExprRole(ExpressionElementRole):
__slots__ = ()
_is_subquery = False
@property
- def _hide_froms(self):
+ def _hide_froms(self) -> Iterable[FromClause]:
raise NotImplementedError()
__slots__ = ()
# does not allow text() or select() objects
+ c: ColumnCollection
+
@property
- def description(self):
+ def description(self) -> str:
raise NotImplementedError()
__slots__ = ()
# calls .alias() as a post processor
- def _anonymous_fromclause(self, name=None, flat=False):
+ def _anonymous_fromclause(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> FromClause:
raise NotImplementedError()
__slots__ = ()
_role_name = "Executable SQL or text() construct"
- _propagate_attrs = util.immutabledict()
+ _propagate_attrs: Mapping[str, Any] = util.immutabledict()
class SelectStatementRole(StatementRole, ReturnsRowsRole):
__slots__ = ()
_role_name = "SELECT construct or equivalent text() construct"
- def subquery(self):
+ def subquery(self) -> Subquery:
raise NotImplementedError(
"All SelectStatementRole objects should implement a "
".subquery() method."
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
-from .base import SchemaEventTarget
+from .base import SchemaEventTarget as SchemaEventTarget
from .coercions import _document_text_coercion
from .elements import ClauseElement
from .elements import ColumnClause
def __init__(self, for_update=False):
self.for_update = for_update
+ @util.memoized_property
+ def is_callable(self):
+ raise NotImplementedError()
+
def _set_parent(self, column, **kw):
self.column = column
if self.for_update:
from .base import HasCompileState
from .base import HasMemoized
from .base import Immutable
-from .base import prefix_anon_map
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import BindParameter
from .elements import TableValuedColumn
from .elements import UnaryExpression
from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
-
and_ = BooleanClauseList.and_
_T = TypeVar("_T", bound=Any)
return self.alias(name=name)
+class NamedFromClause(FromClause):
+ named_with_column = True
+
+ name: str
+
+
class SelectLabelStyle(Enum):
"""Label style constants that may be passed to
:meth:`_sql.Select.set_label_style`."""
# -> Lateral -> FromClause, but we accept SelectBase
# w/ non-deprecated coercion
# -> TableSample -> only for FromClause
-class AliasedReturnsRows(NoInit, FromClause):
+class AliasedReturnsRows(NoInit, NamedFromClause):
"""Base class of aliases against tables, subqueries, and other
selectables."""
_is_from_container = True
- named_with_column = True
_supports_derived_columns = False
+ element: ClauseElement
+
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
inherit_cache = True
+ element: FromClause
+
@classmethod
def _factory(cls, selectable, name=None, flat=False):
return coercions.expect(
+ HasSuffixes._has_suffixes_traverse_internals
)
+ element: HasCTE
+
@classmethod
def _factory(cls, selectable, name=None, recursive=False):
r"""Return a new :class:`_expression.CTE`,
nesting: bool
-class HasCTE(roles.HasCTERole):
+class HasCTE(roles.HasCTERole, ClauseElement):
"""Mixin that declares a class to include CTE support.
.. versionadded:: 1.1
inherit_cache = True
+ element: Select
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`.Subquery` object."""
self.element = state["element"]
-class TableClause(roles.DMLTableRole, Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
"""Represents a minimal "table" construct.
This is a lightweight table object that has only a name, a
("name", InternalTraversal.dp_string),
]
- named_with_column = True
-
_is_table = True
implicit_returning = False
SelfValues = typing.TypeVar("SelfValues", bound="Values")
-class Values(Generative, FromClause):
+class Values(Generative, NamedFromClause):
"""Represent a ``VALUES`` construct that can be used as a FROM element
in a statement.
"""
- named_with_column = True
__visit_name__ = "values"
_data = ()
from .elements import quoted_name
from .elements import Slice
from .elements import TypeCoerce as type_coerce # noqa
-from .traversals import InternalTraversal
from .type_api import Emulated
from .type_api import NativeForEmulated # noqa
from .type_api import to_instance
from .type_api import TypeDecorator
from .type_api import TypeEngine
from .type_api import Variant # noqa
+from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import inspection
import typing
from typing import Any
from typing import Callable
+from typing import Deque
from typing import Dict
+from typing import Set
+from typing import Tuple
from typing import Type
from typing import TypeVar
from .cache_key import HasCacheKey
from .visitors import _TraverseInternalsType
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
+from .visitors import HasTraversalDispatch
from .visitors import HasTraverseInternals
-from .visitors import InternalTraversal
from .. import util
from ..util import langhelpers
def compare(obj1, obj2, **kw):
+ strategy: TraversalComparatorStrategy
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
else:
def _preconfigure_traversals(target_hierarchy):
for cls in util.walk_subclasses(target_hierarchy):
- if hasattr(cls, "_traverse_internals"):
- cls._generate_cache_attrs()
+ if hasattr(cls, "_generate_cache_attrs") and hasattr(
+ cls, "_traverse_internals"
+ ):
+ cls._generate_cache_attrs() # type: ignore
_copy_internals.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_copy_internals_traversal",
)
_get_children.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_get_children_traversal",
)
meth_text = f"def {method_name}(self, d):\n{code}\n"
return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- def _shallow_from_dict(self, d: Dict) -> None:
+ def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
cls = self.__class__
+ shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
try:
shallow_from_dict = cls.__dict__[
"_generated_shallow_from_dict_traversal"
]
except KeyError:
- shallow_from_dict = (
- cls._generated_shallow_from_dict_traversal # type: ignore
- ) = self._generate_shallow_from_dict(
+ shallow_from_dict = self._generate_shallow_from_dict(
cls._traverse_internals,
"_generated_shallow_from_dict_traversal",
)
+ cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa E501
+
shallow_from_dict(self, d)
def _shallow_to_dict(self) -> Dict[str, Any]:
cls = self.__class__
+ shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]
+
try:
shallow_to_dict = cls.__dict__[
"_generated_shallow_to_dict_traversal"
]
except KeyError:
- shallow_to_dict = (
- cls._generated_shallow_to_dict_traversal # type: ignore
- ) = self._generate_shallow_to_dict(
+ shallow_to_dict = self._generate_shallow_to_dict(
cls._traverse_internals, "_generated_shallow_to_dict_traversal"
)
+ cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa E501
return shallow_to_dict(self)
- def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy):
+ def _shallow_copy_to(
+ self: SelfHasShallowCopy, other: SelfHasShallowCopy
+ ) -> None:
cls = self.__class__
+ shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]
try:
shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
except KeyError:
- shallow_copy = (
- cls._generated_shallow_copy_traversal # type: ignore
- ) = self._generate_shallow_copy(
+ shallow_copy = self._generate_shallow_copy(
cls._traverse_internals, "_generated_shallow_copy_traversal"
)
+ cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501
shallow_copy(self, other)
- def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy:
+ def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy:
"""Create a shallow copy"""
c = self.__class__.__new__(self.__class__)
self._shallow_copy_to(c)
setattr(self, attrname, result)
-class _CopyInternalsTraversal(InternalTraversal):
+class _CopyInternalsTraversal(HasTraversalDispatch):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
return element
-class _GetChildrenTraversal(InternalTraversal):
+class _GetChildrenTraversal(HasTraversalDispatch):
"""Generate a _children_traversal internal traversal dispatch for classes
with a _traverse_internals collection."""
return name
-class TraversalComparatorStrategy(
- ExtendedInternalTraversal, util.MemoizedSlots
-):
+class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
__slots__ = "stack", "cache", "anon_map"
def __init__(self):
- self.stack = deque()
+ self.stack: Deque[
+ Tuple[ExternallyTraversible, ExternallyTraversible]
+ ] = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
if seq1 is None:
return seq2 is None
- completed = set()
+ completed: Set[object] = set()
for clause in seq1:
for other_clause in set(seq2).difference(completed):
if self.compare_inner(clause, other_clause, **kw):
from . import operators
from . import roles
from . import visitors
-from .annotation import _deep_annotate # noqa
-from .annotation import _deep_deannotate # noqa
-from .annotation import _shallow_annotate # noqa
+from .annotation import _deep_annotate as _deep_annotate
+from .annotation import _deep_deannotate as _deep_deannotate
+from .annotation import _shallow_annotate as _shallow_annotate
from .base import _expand_cloned
from .base import _from_objects
from .base import ColumnSet
"""Visitor/traversal interface and library functions.
-SQLAlchemy schema and expression constructs rely on a Python-centric
-version of the classic "visitor" pattern as the primary way in which
-they apply functionality. The most common use of this pattern
-is statement compilation, where individual expression classes match
-up to rendering methods that produce a string result. Beyond this,
-the visitor system is also used to inspect expressions for various
-information and patterns, as well as for the purposes of applying
-transformations to expressions.
-
-Examples of how the visit system is used can be seen in the source code
-of for example the ``sqlalchemy.sql.util`` and the ``sqlalchemy.sql.compiler``
-modules. Some background on clause adaption is also at
-https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
"""
from __future__ import annotations
from collections import deque
+from enum import Enum
import itertools
import operator
import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import ClassVar
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
+from typing import Optional
from typing import Tuple
+from typing import Type
+from typing import TypeVar
+from typing import Union
from .. import exc
from .. import util
from ..util import langhelpers
-from ..util import symbol
from ..util._has_cy import HAS_CYEXTENSION
-from ..util.langhelpers import _symbol
+from ..util.typing import Protocol
+from ..util.typing import Self
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_util import cache_anon_map as anon_map # noqa
+ from ._py_util import prefix_anon_map as prefix_anon_map
+ from ._py_util import cache_anon_map as anon_map
else:
- from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
+ from sqlalchemy.cyextension.util import prefix_anon_map as prefix_anon_map
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map
+
__all__ = [
"iterate",
"Visitable",
"ExternalTraversal",
"InternalTraversal",
+ "anon_map",
]
-_TraverseInternalsType = List[Tuple[str, _symbol]]
-
-
-class HasTraverseInternals:
- """base for classes that have a "traverse internals" element,
- which defines all kinds of ways of traversing the elements of an object.
-
- """
-
- __slots__ = ()
-
- _traverse_internals: _TraverseInternalsType
-
- @util.preload_module("sqlalchemy.sql.traversals")
- def get_children(self, omit_attrs=(), **kw):
- r"""Return immediate child :class:`.visitors.Visitable`
- elements of this :class:`.visitors.Visitable`.
-
- This is used for visit traversal.
-
- \**kw may contain flags that change the collection that is
- returned, for example to return a subset of items in order to
- cut down on larger traversals, or to return child items from a
- different context (such as schema-level collections instead of
- clause-level).
-
- """
-
- traversals = util.preloaded.sql_traversals
-
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return []
- dispatch = traversals._get_children.run_generated_dispatch
- return itertools.chain.from_iterable(
- meth(obj, **kw)
- for attrname, obj, meth in dispatch(
- self, traverse_internals, "_generated_get_children_traversal"
- )
- if attrname not in omit_attrs and obj is not None
- )
+class _CompilerDispatchType(Protocol):
+ def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any:
+ ...
class Visitable:
"""Base class for visitable objects.
+ :class:`.Visitable` is used to implement the SQL compiler dispatch
+ functions. Other forms of traversal such as for cache key generation
+ are implemented separately using the :class:`.HasTraverseInternals`
+ interface.
+
.. versionchanged:: 2.0 The :class:`.Visitable` class was named
:class:`.Traversible` in the 1.4 series; the name is changed back
to :class:`.Visitable` in 2.0 which is what it was prior to 1.4.
__visit_name__: str
+ _original_compiler_dispatch: _CompilerDispatchType
+
+ if typing.TYPE_CHECKING:
+
+ def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str:
+ ...
+
def __init_subclass__(cls) -> None:
if "__visit_name__" in cls.__dict__:
cls._generate_compiler_dispatch()
super().__init_subclass__()
@classmethod
- def _generate_compiler_dispatch(cls):
- """Assign dispatch attributes to various kinds of
- "visitable" classes.
-
- Attributes include:
-
- * The ``_compiler_dispatch`` method, corresponding to
- ``__visit_name__``. This is called "external traversal" because the
- caller of each visit() method is responsible for sub-traversing the
- inner elements of each object. This is appropriate for string
- compilers and other traversals that need to call upon the inner
- elements in a specific pattern.
-
- * internal traversal collections ``_children_traversal``,
- ``_cache_key_traversal``, ``_copy_internals_traversal``, generated
- from an optional ``_traverse_internals`` collection of symbols which
- comes from the :class:`.InternalTraversal` list of symbols. This is
- called "internal traversal".
-
- """
+ def _generate_compiler_dispatch(cls) -> None:
visit_name = cls.__visit_name__
if "_compiler_dispatch" in cls.__dict__:
name = "visit_%s" % visit_name
getter = operator.attrgetter(name)
- def _compiler_dispatch(self, visitor, **kw):
+ def _compiler_dispatch(
+ self: Visitable, visitor: Any, **kw: Any
+ ) -> str:
"""Look for an attribute named "visit_<visit_name>" on the
visitor, and call it with the same kw params.
try:
meth = getter(visitor)
except AttributeError as err:
- return visitor.visit_unsupported_compilation(self, err, **kw)
+ return visitor.visit_unsupported_compilation(self, err, **kw) # type: ignore # noqa E501
else:
- return meth(self, **kw)
+ return meth(self, **kw) # type: ignore # noqa E501
- cls._compiler_dispatch = (
+ cls._compiler_dispatch = ( # type: ignore
cls._original_compiler_dispatch
) = _compiler_dispatch
- def __class_getitem__(cls, key):
+ def __class_getitem__(cls, key: str) -> Any:
# allow generic classes in py3.9+
return cls
-class _HasTraversalDispatch:
- r"""Define infrastructure for the :class:`.InternalTraversal` class.
-
- .. versionadded:: 2.0
-
- """
-
- __slots__ = ()
-
- def __init_subclass__(cls) -> None:
- cls._generate_traversal_dispatch()
- super().__init_subclass__()
-
- def dispatch(self, visit_symbol):
- """Given a method from :class:`._HasTraversalDispatch`, return the
- corresponding method on a subclass.
-
- """
- name = self._dispatch_lookup[visit_symbol]
- return getattr(self, name, None)
-
- def run_generated_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
- ):
- try:
- dispatcher = target.__class__.__dict__[generate_dispatcher_name]
- except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
- # this block will generate any remaining dispatchers.
- dispatcher = self.generate_dispatch(
- target.__class__, internal_dispatch, generate_dispatcher_name
- )
- return dispatcher(target, self)
-
- def generate_dispatch(
- self, target_cls, internal_dispatch, generate_dispatcher_name
- ):
- dispatcher = self._generate_dispatcher(
- internal_dispatch, generate_dispatcher_name
- )
- # assert isinstance(target_cls, type)
- setattr(target_cls, generate_dispatcher_name, dispatcher)
- return dispatcher
-
- @classmethod
- def _generate_traversal_dispatch(cls):
- lookup = {}
- clsdict = cls.__dict__
- for key, sym in clsdict.items():
- if key.startswith("dp_"):
- visit_key = key.replace("dp_", "visit_")
- sym_name = sym.name
- assert sym_name not in lookup, sym_name
- lookup[sym] = lookup[sym_name] = visit_key
- if hasattr(cls, "_dispatch_lookup"):
- lookup.update(cls._dispatch_lookup)
- cls._dispatch_lookup = lookup
-
- def _generate_dispatcher(self, internal_dispatch, method_name):
- names = []
- for attrname, visit_sym in internal_dispatch:
- meth = self.dispatch(visit_sym)
- if meth:
- visit_name = ExtendedInternalTraversal._dispatch_lookup[
- visit_sym
- ]
- names.append((attrname, visit_name))
-
- code = (
- (" return [\n")
- + (
- ", \n".join(
- " (%r, self.%s, visitor.%s)"
- % (attrname, attrname, visit_name)
- for attrname, visit_name in names
- )
- )
- + ("\n ]\n")
- )
- meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
-
-
-class InternalTraversal(_HasTraversalDispatch):
+class InternalTraversal(Enum):
r"""Defines visitor symbols used for internal traversal.
The :class:`.InternalTraversal` class is used in two ways. One is that
"""
- __slots__ = ()
-
- dp_has_cache_key = symbol("HC")
+ dp_has_cache_key = "HC"
"""Visit a :class:`.HasCacheKey` object."""
- dp_has_cache_key_list = symbol("HL")
+ dp_has_cache_key_list = "HL"
"""Visit a list of :class:`.HasCacheKey` objects."""
- dp_clauseelement = symbol("CE")
+ dp_clauseelement = "CE"
"""Visit a :class:`_expression.ClauseElement` object."""
- dp_fromclause_canonical_column_collection = symbol("FC")
+ dp_fromclause_canonical_column_collection = "FC"
"""Visit a :class:`_expression.FromClause` object in the context of the
``columns`` attribute.
"""
- dp_clauseelement_tuples = symbol("CTS")
+ dp_clauseelement_tuples = "CTS"
"""Visit a list of tuples which contain :class:`_expression.ClauseElement`
objects.
"""
- dp_clauseelement_list = symbol("CL")
+ dp_clauseelement_list = "CL"
"""Visit a list of :class:`_expression.ClauseElement` objects.
"""
- dp_clauseelement_tuple = symbol("CT")
+ dp_clauseelement_tuple = "CT"
"""Visit a tuple of :class:`_expression.ClauseElement` objects.
"""
- dp_executable_options = symbol("EO")
+ dp_executable_options = "EO"
- dp_with_context_options = symbol("WC")
+ dp_with_context_options = "WC"
- dp_fromclause_ordered_set = symbol("CO")
+ dp_fromclause_ordered_set = "CO"
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
- dp_string = symbol("S")
+ dp_string = "S"
"""Visit a plain string value.
Examples include table and column names, bound parameter keys, special
"""
- dp_string_list = symbol("SL")
+ dp_string_list = "SL"
"""Visit a list of strings."""
- dp_anon_name = symbol("AN")
+ dp_anon_name = "AN"
"""Visit a potentially "anonymized" string value.
The string value is considered to be significant for cache key
"""
- dp_boolean = symbol("B")
+ dp_boolean = "B"
"""Visit a boolean value.
The boolean value is considered to be significant for cache key
"""
- dp_operator = symbol("O")
+ dp_operator = "O"
"""Visit an operator.
The operator is a function from the :mod:`sqlalchemy.sql.operators`
"""
- dp_type = symbol("T")
+ dp_type = "T"
"""Visit a :class:`.TypeEngine` object
The type object is considered to be significant for cache key
"""
- dp_plain_dict = symbol("PD")
+ dp_plain_dict = "PD"
"""Visit a dictionary with string keys.
The keys of the dictionary should be strings, the values should
"""
- dp_dialect_options = symbol("DO")
+ dp_dialect_options = "DO"
"""Visit a dialect options structure."""
- dp_string_clauseelement_dict = symbol("CD")
+ dp_string_clauseelement_dict = "CD"
"""Visit a dictionary of string keys to :class:`_expression.ClauseElement`
objects.
"""
- dp_string_multi_dict = symbol("MD")
+ dp_string_multi_dict = "MD"
"""Visit a dictionary of string keys to values which may either be
plain immutable/hashable or :class:`.HasCacheKey` objects.
"""
- dp_annotations_key = symbol("AK")
+ dp_annotations_key = "AK"
"""Visit the _annotations_cache_key element.
This is a dictionary of additional information about a ClauseElement
"""
- dp_plain_obj = symbol("PO")
+ dp_plain_obj = "PO"
"""Visit a plain python object.
The value should be immutable and hashable, such as an integer.
"""
- dp_named_ddl_element = symbol("DD")
+ dp_named_ddl_element = "DD"
"""Visit a simple named DDL element.
The current object used by this method is the :class:`.Sequence`.
"""
- dp_prefix_sequence = symbol("PS")
+ dp_prefix_sequence = "PS"
"""Visit the sequence represented by :class:`_expression.HasPrefixes`
or :class:`_expression.HasSuffixes`.
"""
- dp_table_hint_list = symbol("TH")
+ dp_table_hint_list = "TH"
"""Visit the ``_hints`` collection of a :class:`_expression.Select`
object.
"""
- dp_setup_join_tuple = symbol("SJ")
+ dp_setup_join_tuple = "SJ"
- dp_memoized_select_entities = symbol("ME")
+ dp_memoized_select_entities = "ME"
- dp_statement_hint_list = symbol("SH")
+ dp_statement_hint_list = "SH"
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`
object.
"""
- dp_unknown_structure = symbol("UK")
+ dp_unknown_structure = "UK"
"""Visit an unknown structure.
"""
- dp_dml_ordered_values = symbol("DML_OV")
+ dp_dml_ordered_values = "DML_OV"
"""Visit the values() ordered tuple list of an
:class:`_expression.Update` object."""
- dp_dml_values = symbol("DML_V")
+ dp_dml_values = "DML_V"
"""Visit the values() dictionary of a :class:`.ValuesBase`
(e.g. Insert or Update) object.
"""
- dp_dml_multi_values = symbol("DML_MV")
+ dp_dml_multi_values = "DML_MV"
"""Visit the values() multi-valued list of dictionaries of an
:class:`_expression.Insert` object.
"""
- dp_propagate_attrs = symbol("PA")
+ dp_propagate_attrs = "PA"
"""Visit the propagate attrs dict. This hardcodes to the particular
elements we care about right now."""
-
-class ExtendedInternalTraversal(InternalTraversal):
- """Defines additional symbols that are useful in caching applications.
+ """Symbols that follow are additional symbols that are useful in
+ caching applications.
Traversals for :class:`_expression.ClauseElement` objects only need to use
those symbols present in :class:`.InternalTraversal`. However, for
"""
- __slots__ = ()
-
- dp_ignore = symbol("IG")
+ dp_ignore = "IG"
"""Specify an object that should be ignored entirely.
This currently applies function call argument caching where some
"""
- dp_inspectable = symbol("IS")
+ dp_inspectable = "IS"
"""Visit an inspectable object where the return value is a
:class:`.HasCacheKey` object."""
- dp_multi = symbol("M")
+ dp_multi = "M"
"""Visit an object that may be a :class:`.HasCacheKey` or may be a
plain hashable object."""
- dp_multi_list = symbol("MT")
+ dp_multi_list = "MT"
"""Visit a tuple containing elements that may be :class:`.HasCacheKey` or
may be a plain hashable object."""
- dp_has_cache_key_tuples = symbol("HT")
+ dp_has_cache_key_tuples = "HT"
"""Visit a list of tuples which contain :class:`.HasCacheKey`
objects.
"""
- dp_inspectable_list = symbol("IL")
+ dp_inspectable_list = "IL"
"""Visit a list of inspectable objects which upon inspection are
HasCacheKey objects."""
+_TraverseInternalsType = List[Tuple[str, InternalTraversal]]
+"""a structure that defines how a HasTraverseInternals should be
+traversed.
+
+This structure consists of a list of (attributename, internaltraversal)
+tuples, where the "attributename" refers to the name of an attribute on an
+instance of the HasTraverseInternals object, and "internaltraversal" refers
+to an :class:`.InternalTraversal` enumeration symbol defining what kind
+of data this attribute stores, which indicates to the traverser how it should
+be handled.
+
+"""
+
+
+class HasTraverseInternals:
+ """base for classes that have a "traverse internals" element,
+ which defines all kinds of ways of traversing the elements of an object.
+
+ Compared to :class:`.Visitable`, which relies upon an external visitor to
+ define how the object is travered (i.e. the :class:`.SQLCompiler`), the
+ :class:`.HasTraverseInternals` interface allows classes to define their own
+ traversal, that is, what attributes are accessed and in what order.
+
+ """
+
+ __slots__ = ()
+
+ _traverse_internals: _TraverseInternalsType
+
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(
+ self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Iterable[HasTraverseInternals]:
+ r"""Return immediate child :class:`.visitors.HasTraverseInternals`
+ elements of this :class:`.visitors.HasTraverseInternals`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
+
+class _InternalTraversalDispatchType(Protocol):
+ def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any:
+ ...
+
+
+class HasTraversalDispatch:
+ r"""Define infrastructure for classes that perform internal traversals
+
+ .. versionadded:: 2.0
+
+ """
+
+ __slots__ = ()
+
+ _dispatch_lookup: ClassVar[Dict[Union[InternalTraversal, str], str]] = {}
+
+ def dispatch(self, visit_symbol: InternalTraversal) -> Callable[..., Any]:
+ """Given a method from :class:`.HasTraversalDispatch`, return the
+ corresponding method on a subclass.
+
+ """
+ name = _dispatch_lookup[visit_symbol]
+ return getattr(self, name, None) # type: ignore
+
+ def run_generated_dispatch(
+ self,
+ target: object,
+ internal_dispatch: _TraverseInternalsType,
+ generate_dispatcher_name: str,
+ ) -> Any:
+ dispatcher: _InternalTraversalDispatchType
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ # traversals.py -> _preconfigure_traversals()
+ # may be used to run these ahead of time, but
+ # is not enabled right now.
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.generate_dispatch(
+ target.__class__, internal_dispatch, generate_dispatcher_name
+ )
+ return dispatcher(target, self)
+
+ def generate_dispatch(
+ self,
+ target_cls: Type[object],
+ internal_dispatch: _TraverseInternalsType,
+ generate_dispatcher_name: str,
+ ) -> _InternalTraversalDispatchType:
+ dispatcher = self._generate_dispatcher(
+ internal_dispatch, generate_dispatcher_name
+ )
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
+ return dispatcher
+
+ def _generate_dispatcher(
+ self, internal_dispatch: _TraverseInternalsType, method_name: str
+ ) -> _InternalTraversalDispatchType:
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = self.dispatch(visit_sym)
+ if meth:
+ visit_name = _dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ return cast(
+ _InternalTraversalDispatchType,
+ langhelpers._exec_code_in_env(meth_text, {}, method_name),
+ )
+
+
+ExtendedInternalTraversal = InternalTraversal
+
+
+def _generate_traversal_dispatch() -> None:
+ lookup = _dispatch_lookup
+
+ for sym in InternalTraversal:
+ key = sym.name
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.value
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+
+
+_dispatch_lookup = HasTraversalDispatch._dispatch_lookup
+_generate_traversal_dispatch()
+
+
+class ExternallyTraversible(HasTraverseInternals, Visitable):
+ __slots__ = ()
+
+ _annotations: Collection[Any] = ()
+
+ if typing.TYPE_CHECKING:
+
+ def get_children(
+ self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Iterable[ExternallyTraversible]:
+ ...
+
+ def _clone(self: Self, **kw: Any) -> Self:
+ """clone this element"""
+ raise NotImplementedError()
+
+ def _copy_internals(
+ self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Self:
+ """Reassign internal elements to be clones of themselves.
+
+ Called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy.
+
+ The given clone function should be used, which may be applying
+ additional transformations to the element (i.e. replacement
+ traversal, cloned traversal, annotations).
+
+ """
+ raise NotImplementedError()
+
+
+_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_TraverseCallableType = Callable[[_ET], None]
+_TraverseTransformCallableType = Callable[
+ [ExternallyTraversible], Optional[ExternallyTraversible]
+]
+
+
class ExternalTraversal:
"""Base class for visitor objects which can traverse externally using
the :func:`.visitors.traverse` function.
"""
- __traverse_options__ = {}
+ __traverse_options__: Dict[str, Any] = {}
+ _next: Optional[ExternalTraversal]
def traverse_single(self, obj: Visitable, **kw: Any) -> Any:
for v in self.visitor_iterator:
if meth:
return meth(obj, **kw)
- def iterate(self, obj):
+ def iterate(
+ self, obj: ExternallyTraversible
+ ) -> Iterator[ExternallyTraversible]:
"""Traverse the given expression structure, returning an iterator
of all elements.
"""
return iterate(obj, self.__traverse_options__)
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@util.memoized_property
- def _visitor_dict(self):
+ def _visitor_dict(self) -> Dict[str, _TraverseCallableType[Any]]:
visitors = {}
for name in dir(self):
return visitors
@property
- def visitor_iterator(self):
+ def visitor_iterator(self) -> Iterator[ExternalTraversal]:
"""Iterate through this visitor and each 'chained' visitor."""
- v = self
+ v: Optional[ExternalTraversal] = self
while v:
yield v
v = getattr(v, "_next", None)
- def chain(self, visitor):
- """'Chain' an additional ClauseVisitor onto this ClauseVisitor.
+ def chain(self, visitor: ExternalTraversal) -> ExternalTraversal:
+ """'Chain' an additional ExternalTraversal onto this ExternalTraversal
The chained visitor will receive all visit events after this one.
"""
- def copy_and_process(self, list_):
+ def copy_and_process(
+ self, list_: List[ExternallyTraversible]
+ ) -> List[ExternallyTraversible]:
"""Apply cloned traversal to the given list of elements, and return
the new list.
"""
return [self.traverse(x) for x in list_]
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
return cloned_traverse(
"""
- def replace(self, elem):
+ def replace(
+ self, elem: ExternallyTraversible
+ ) -> Optional[ExternallyTraversible]:
"""Receive pre-copied elements during a cloning traversal.
If the method returns a new element, the element is used
"""
return None
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
- def replace(elem):
+ def replace(
+ elem: ExternallyTraversible,
+ ) -> Optional[ExternallyTraversible]:
for v in self.visitor_iterator:
- e = v.replace(elem)
+ e = cast(ReplacingExternalTraversal, v).replace(elem)
if e is not None:
return e
+ return None
+
return replacement_traverse(obj, self.__traverse_options__, replace)
ReplacingCloningVisitor = ReplacingExternalTraversal
-def iterate(obj, opts=util.immutabledict()):
+def iterate(
+ obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT
+) -> Iterator[ExternallyTraversible]:
r"""Traverse the given expression structure, returning an iterator.
Traversal is configured to be breadth-first.
stack.append(t.get_children(**opts))
-def traverse_using(iterator, obj, visitors):
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: ExternallyTraversible,
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
"""Visit the given expression structure using the given iterator of
objects.
return obj
-def traverse(obj, opts, visitors):
+def traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
"""Traverse and visit the given expression structure using the default
iterator.
return traverse_using(iterate(obj, opts), obj, visitors)
-def cloned_traverse(obj, opts, visitors):
+def cloned_traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseTransformCallableType],
+) -> ExternallyTraversible:
"""Clone the given expression structure, allowing modifications by
visitors.
"""
- cloned = {}
+ cloned: Dict[int, ExternallyTraversible] = {}
stop_on = set(opts.get("stop_on", []))
- def deferred_copy_internals(obj):
+ def deferred_copy_internals(
+ obj: ExternallyTraversible,
+ ) -> ExternallyTraversible:
return cloned_traverse(obj, opts, visitors)
- def clone(elem, **kw):
+ def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
if "replace" in kw:
- newelem = kw["replace"](elem)
+ newelem = cast(
+ Optional[ExternallyTraversible], kw["replace"](elem)
+ )
if newelem is not None:
cloned[id(elem)] = newelem
return newelem
obj = clone(
obj, deferred_copy_internals=deferred_copy_internals, **opts
)
- clone = None # remove gc cycles
+ clone = None # type: ignore[assignment] # remove gc cycles
return obj
-def replacement_traverse(obj, opts, replace):
+def replacement_traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> ExternallyTraversible:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
cloned = {}
stop_on = {id(x) for x in opts.get("stop_on", [])}
- def deferred_copy_internals(obj):
+ def deferred_copy_internals(
+ obj: ExternallyTraversible,
+ ) -> ExternallyTraversible:
return replacement_traverse(obj, opts, replace)
- def clone(elem, **kw):
+ def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
if (
id(elem) in stop_on
or "no_replacement_traverse" in elem._annotations
obj = clone(
obj, deferred_copy_internals=deferred_copy_internals, **opts
)
- clone = None # remove gc cycles
+ clone = None # type: ignore[assignment] # remove gc cycles
return obj
return spec
-def _exec_code_in_env(code, env, fn_name):
+def _exec_code_in_env(
+ code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
+) -> Callable[..., Any]:
exec(code, env)
- return env[fn_name]
+ return env[fn_name] # type: ignore[no-any-return]
_PF = TypeVar("_PF")
obj.__dict__.pop(name, None)
-def memoized_instancemethod(fn):
+def memoized_instancemethod(fn: _F) -> _F:
"""Decorate a method memoize its return value.
Best applied to no-arg methods: memoization is not sensitive to
self.__dict__[fn.__name__] = memo
return result
- return update_wrapper(oneshot, fn)
+ return update_wrapper(oneshot, fn) # type: ignore
class HasMemoized:
import sys
import typing
from typing import Any
-from typing import Callable # noqa
from typing import cast
from typing import Dict
from typing import ForwardRef
+from typing import Iterable
+from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from . import compat
_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT")
+_KT_co = TypeVar("_KT_co", covariant=True)
+_KT_contra = TypeVar("_KT_contra", contravariant=True)
+_VT = TypeVar("_VT")
+_VT_co = TypeVar("_VT_co", covariant=True)
Self = TypeVar("Self", bound=Any)
from typing_extensions import Protocol as Protocol # noqa F401
from typing_extensions import TypedDict as TypedDict # noqa F401
+# copied from TypeShed, required in order to implement
+# MutableMapping.update()
+
+
+class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
+ def keys(self) -> Iterable[_KT]:
+ ...
+
+ def __getitem__(self, __k: _KT) -> _VT_co:
+ ...
+
+
# work around https://github.com/microsoft/pyright/issues/3025
_LiteralStar = Literal["*"]
return cast(Any, Union).__getitem__(types)
-def expand_unions(type_, include_union=False, discard_none=False):
+def expand_unions(
+ type_: Type[Any], include_union: bool = False, discard_none: bool = False
+) -> Tuple[Type[Any], ...]:
"""Return a type as as a tuple of individual types, expanding for
``Union`` types."""
]
[tool.pyright]
-include = [
- "lib/sqlalchemy/engine/base.py",
- "lib/sqlalchemy/engine/events.py",
- "lib/sqlalchemy/engine/interfaces.py",
- "lib/sqlalchemy/engine/_py_row.py",
- "lib/sqlalchemy/engine/result.py",
- "lib/sqlalchemy/engine/row.py",
- "lib/sqlalchemy/engine/util.py",
- "lib/sqlalchemy/engine/url.py",
- "lib/sqlalchemy/pool/",
- "lib/sqlalchemy/event/",
- "lib/sqlalchemy/events.py",
- "lib/sqlalchemy/exc.py",
- "lib/sqlalchemy/log.py",
- "lib/sqlalchemy/inspection.py",
- "lib/sqlalchemy/schema.py",
- "lib/sqlalchemy/types.py",
- "lib/sqlalchemy/util/",
-]
+
reportPrivateUsage = "none"
reportUnusedClass = "none"
reportUnusedFunction = "none"
-
+reportTypedDictNotRequiredAccess = "warning"
[tool.mypy]
mypy_path = "./lib/"
# strict checking
[[tool.mypy.overrides]]
module = [
+ "sqlalchemy.sql.annotation",
+ "sqlalchemy.sql.cache_key",
+ "sqlalchemy.sql.roles",
+ "sqlalchemy.sql.visitors",
+ "sqlalchemy.sql._py_util",
"sqlalchemy.connectors.*",
"sqlalchemy.engine.*",
"sqlalchemy.ext.associationproxy",
[[tool.mypy.overrides]]
module = [
+ "sqlalchemy.sql.coercions",
+ "sqlalchemy.sql.compiler",
+ #"sqlalchemy.sql.crud",
+ #"sqlalchemy.sql.default_comparator",
+ "sqlalchemy.sql.naming",
+ "sqlalchemy.sql.traversals",
"sqlalchemy.util.*",
"sqlalchemy.engine.cursor",
"sqlalchemy.engine.default",