From: Mike Bayer Date: Sun, 20 Mar 2022 20:39:36 +0000 (-0400) Subject: pep484 - SQL internals X-Git-Tag: rel_2_0_0b1~404^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6f02d5edd88fe2475629438b0730181a2b00c5fe;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep484 - SQL internals non-strict checking for mostly internal or semi-internal code Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec --- diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index de01a1b461..4a6ae08b25 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -7,6 +7,8 @@ from __future__ import annotations +from typing import Any + from . import util as _util from .engine import AdaptedConnection as AdaptedConnection from .engine import BaseRow as BaseRow @@ -191,7 +193,6 @@ from .sql.expression import tuple_ as tuple_ from .sql.expression import type_coerce as type_coerce from .sql.expression import TypeClause as TypeClause from .sql.expression import TypeCoerce as TypeCoerce -from .sql.expression import typing as typing from .sql.expression import UnaryExpression as UnaryExpression from .sql.expression import union as union from .sql.expression import union_all as union_all @@ -254,7 +255,7 @@ from .types import VARCHAR as VARCHAR __version__ = "2.0.0b1" -def __go(lcls): +def __go(lcls: Any) -> None: from . import util as _sa_util _sa_util.preloaded.import_prefix("sqlalchemy") diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index c33a6e4a50..fe2cb94ffe 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -26,6 +26,10 @@ cdef class OrderedSet(set): cdef list _list + @classmethod + def __class_getitem__(cls, key): + return cls + def __init__(self, d=None): set.__init__(self) if d is not None: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 061794bded..714ad3c85e 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -73,6 +73,7 @@ if typing.TYPE_CHECKING: from ..sql.functions import FunctionElement from ..sql.schema import ColumnDefault from ..sql.schema import HasSchemaAttr + from ..sql.schema import SchemaItem """Defines :class:`_engine.Connection` and :class:`_engine.Engine`. @@ -2004,7 +2005,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _run_ddl_visitor( self, visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: DDLElement, + element: SchemaItem, **kwargs: Any, ) -> None: """run a DDL visitor. @@ -2749,7 +2750,7 @@ class Engine( def _run_ddl_visitor( self, visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: DDLElement, + element: SchemaItem, **kwargs: Any, ) -> None: with self.begin() as conn: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4a833d2e54..d605af3efa 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -54,7 +54,9 @@ from ..sql import expression from ..sql._typing import is_tuple_type from ..sql.compiler import DDLCompiler from ..sql.compiler import SQLCompiler +from ..sql.elements import ColumnClause from ..sql.elements import quoted_name +from ..sql.schema import default_is_scalar if typing.TYPE_CHECKING: from types import ModuleType @@ -1164,7 +1166,7 @@ class DefaultExecutionContext(ExecutionContext): return () @util.memoized_property - def returning_cols(self) -> Optional[Sequence[Column[Any]]]: + def returning_cols(self) -> Optional[Sequence[ColumnClause[Any]]]: if TYPE_CHECKING: assert isinstance(self.compiled, SQLCompiler) return self.compiled.returning @@ -1778,15 +1780,11 @@ class DefaultExecutionContext(ExecutionContext): # to avoid many calls of get_insert_default()/ # get_update_default() for c in insert_prefetch: - if c.default and not c.default.is_sequence and c.default.is_scalar: - if TYPE_CHECKING: - assert isinstance(c.default, ColumnDefault) + if c.default and default_is_scalar(c.default): scalar_defaults[c] = c.default.arg for c in update_prefetch: - if c.onupdate and c.onupdate.is_scalar: - if TYPE_CHECKING: - assert isinstance(c.onupdate, ColumnDefault) + if c.onupdate and default_is_scalar(c.onupdate): scalar_defaults[c] = c.onupdate.arg for param in self.compiled_parameters: @@ -1817,9 +1815,7 @@ class DefaultExecutionContext(ExecutionContext): ) = self.compiled_parameters[0] for c in compiled.insert_prefetch: - if c.default and not c.default.is_sequence and c.default.is_scalar: - if TYPE_CHECKING: - assert isinstance(c.default, ColumnDefault) + if c.default and default_is_scalar(c.default): val = c.default.arg else: val = self.get_insert_default(c) diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index a0ba966039..c94dd1032e 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -32,6 +32,7 @@ if typing.TYPE_CHECKING: from ..sql.ddl import SchemaDropper from ..sql.ddl import SchemaGenerator from ..sql.schema import HasSchemaAttr + from ..sql.schema import SchemaItem class MockConnection: @@ -55,7 +56,7 @@ class MockConnection: def _run_ddl_visitor( self, visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: DDLElement, + element: SchemaItem, **kwargs: Any, ) -> None: kwargs["checkfirst"] = False diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index e1281365e9..b8ece2b1d2 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -317,7 +317,6 @@ class Inspector(inspection.Inspectable["Inspector"]): with an already-given :class:`_schema.MetaData`. """ - with self._operation_context() as conn: tnames = self.dialect.get_table_names( conn, schema, info_cache=self.info_cache diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 943c71b367..7192675dfa 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -88,7 +88,7 @@ class SQLAlchemyAttribute: return cls(typ=typ, info=info, **data) -def name_is_dunder(name): +def name_is_dunder(name: str) -> bool: return bool(re.match(r"^__.+?__$", name)) diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 9a23d89d3f..2feade04ee 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -225,7 +225,7 @@ def instance_logger( else: name = _qual_logger_name_for_cls(instance.__class__) - instance._echo = echoflag + instance._echo = echoflag # type: ignore logger: Union[logging.Logger, InstanceLogger] @@ -239,7 +239,7 @@ def instance_logger( # levels by calling logger._log() logger = InstanceLogger(echoflag, name) - instance.logger = logger + instance.logger = logger # type: ignore class echo_property: diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index d4d010cbe6..dd3931faf6 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -511,7 +511,8 @@ class Composite( """ - __hash__ = None + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore @util.memoized_property def clauses(self): diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index b4697912bc..58c7c4efd5 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -476,7 +476,8 @@ class Relationship( "the set of foreign key values." ) - __hash__ = None + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore def __eq__(self, other): """Implement the ``==`` operator. diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 169ddf3dbb..2e766f9766 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import Any from .base import Executable as Executable from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS @@ -97,7 +98,7 @@ from .expression import within_group as within_group from .visitors import ClauseVisitor as ClauseVisitor -def __go(lcls): +def __go(lcls: Any) -> None: from .. import util as _sa_util from . import base diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 9043aa6d05..e9acc7e6dc 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -8,7 +8,7 @@ from __future__ import annotations from typing import Any -from typing import Union +from typing import Optional from . import coercions from . import roles @@ -23,8 +23,6 @@ from .selectable import Select from .selectable import TableClause from .selectable import TableSample from .selectable import Values -from ..util.typing import _LiteralStar -from ..util.typing import Literal def alias(selectable, name=None, flat=False): @@ -283,9 +281,7 @@ def outerjoin(left, right, onclause=None, full=False): return Join(left, right, onclause, isouter=True, full=full) -def select( - *entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement] -) -> "Select": +def select(*entities: _ColumnsClauseElement) -> Select: r"""Construct a new :class:`_expression.Select`. @@ -326,7 +322,7 @@ def select( return Select(*entities) -def table(name: str, *columns: ColumnClause, **kw: Any) -> "TableClause": +def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause: """Produce a new :class:`_expression.TableClause`. The object returned is an instance of @@ -435,7 +431,11 @@ def union_all(*selects): return CompoundSelect._create_union_all(*selects) -def values(*columns, name=None, literal_binds=False) -> "Values": +def values( + *columns: ColumnClause[Any], + name: Optional[str] = None, + literal_binds: bool = False, +) -> Values: r"""Construct a :class:`_expression.Values` construct. The column expressions and the actual data for diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 2be98b88fe..b50a7bf6a1 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -9,6 +9,7 @@ from typing import Union from . import roles from .. import util from ..inspection import Inspectable +from ..util.typing import Literal if TYPE_CHECKING: from .elements import quoted_name @@ -24,12 +25,13 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _ColumnsClauseElement = Union[ + Literal["*", 1], roles.ColumnsClauseRole, - Type, + Type[Any], Inspectable[roles.HasColumnElementClauseElement], ] _FromClauseElement = Union[ - roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement] + roles.FromClauseRole, Type[Any], Inspectable[roles.HasFromClauseElement] ] _ColumnExpression = Union[ diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6a6b389de8..8f51359155 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -12,22 +12,32 @@ from __future__ import annotations -import collections.abc as collections_abc from enum import Enum from functools import reduce import itertools from itertools import zip_longest import operator import re -import typing from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic from typing import Iterable +from typing import Iterator from typing import List +from typing import Mapping from typing import MutableMapping +from typing import NoReturn from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import roles from . import visitors @@ -36,17 +46,26 @@ from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa from .visitors import ClauseVisitor from .visitors import ExtendedInternalTraversal +from .visitors import ExternallyTraversible from .visitors import InternalTraversal +from .. import event from .. import exc from .. import util from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing +from ..util.typing import Protocol from ..util.typing import Self +from ..util.typing import TypeGuard -if typing.TYPE_CHECKING: +if TYPE_CHECKING: + from . import coercions + from . import elements + from . import type_api from .elements import BindParameter + from .elements import ColumnClause from .elements import ColumnElement + from .elements import SQLCoreOperations from ..engine import Connection from ..engine import Result from ..engine.base import _CompiledCacheType @@ -58,10 +77,12 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import CacheStats from ..engine.interfaces import Compiled from ..engine.interfaces import Dialect + from ..event import dispatcher -coercions = None -elements = None -type_api = None +if not TYPE_CHECKING: + coercions = None # noqa + elements = None # noqa + type_api = None # noqa class _NoArg(Enum): @@ -70,13 +91,24 @@ class _NoArg(Enum): NO_ARG = _NoArg.NO_ARG -# if I use sqlalchemy.util.typing, which has the exact same -# symbols, mypy reports: "error: _Fn? not callable" -_Fn = typing.TypeVar("_Fn", bound=typing.Callable) +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) _AmbiguousTableNameMap = MutableMapping[str, str] +class _EntityNamespace(Protocol): + def __getattr__(self, key: str) -> SQLCoreOperations[Any]: + ... + + +class _HasEntityNamespace(Protocol): + entity_namespace: _EntityNamespace + + +def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: + return hasattr(element, "entity_namespace") + + class Immutable: """mark a ClauseElement as 'immutable' when expressions are cloned.""" @@ -107,10 +139,14 @@ class SingletonConstant(Immutable): def __new__(cls, *arg, **kw): return cls._singleton + @util.non_memoized_property + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + raise NotImplementedError() + @classmethod def _create_singleton(cls): obj = object.__new__(cls) - obj.__init__() + obj.__init__() # type: ignore # for a long time this was an empty frozenset, meaning # a SingletonConstant would never be a "corresponding column" in @@ -139,12 +175,11 @@ def _select_iterables(elements): ) -_Self = typing.TypeVar("_Self", bound="_GenerativeType") -_Args = compat_typing.ParamSpec("_Args") +_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType") class _GenerativeType(compat_typing.Protocol): - def _generate(self: "_Self") -> "_Self": + def _generate(self: _SelfGenerativeType) -> _SelfGenerativeType: ... @@ -158,8 +193,8 @@ def _generative(fn: _Fn) -> _Fn: @util.decorator def _generative( - fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs - ) -> _Self: + fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any + ) -> _SelfGenerativeType: """Mark a method as generative.""" self = self._generate() @@ -167,9 +202,9 @@ def _generative(fn: _Fn) -> _Fn: assert x is self, "generative methods must return self" return self - decorated = _generative(fn) - decorated.non_generative = fn - return decorated + decorated = _generative(fn) # type: ignore + decorated.non_generative = fn # type: ignore + return decorated # type: ignore def _exclusive_against(*names, **kw): @@ -233,7 +268,7 @@ def _cloned_difference(a, b): ) -class _DialectArgView(collections_abc.MutableMapping): +class _DialectArgView(MutableMapping[str, Any]): """A dictionary view of dialect-level arguments in the form _. @@ -290,7 +325,7 @@ class _DialectArgView(collections_abc.MutableMapping): ) -class _DialectArgDict(collections_abc.MutableMapping): +class _DialectArgDict(MutableMapping[str, Any]): """A dictionary view of dialect-level arguments for a specific dialect. @@ -343,6 +378,8 @@ class DialectKWArgs: """ + __slots__ = () + _dialect_kwargs_traverse_internals = [ ("dialect_options", InternalTraversal.dp_dialect_options) ] @@ -534,7 +571,7 @@ class CompileState: __slots__ = ("statement", "_ambiguous_table_name_map") - plugins = {} + plugins: Dict[Tuple[str, str], Type[CompileState]] = {} _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] @@ -639,9 +676,9 @@ class InPlaceGenerative(HasMemoized): class HasCompileState(Generative): """A class that has a :class:`.CompileState` associated with it.""" - _compile_state_plugin = None + _compile_state_plugin: Optional[Type[CompileState]] = None - _attributes = util.immutabledict() + _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT _compile_state_factory = CompileState.create_for_statement @@ -655,6 +692,8 @@ class _MetaOptions(type): """ + _cache_attrs: Tuple[str, ...] + def __add__(self, other): o1 = self() @@ -674,6 +713,8 @@ class Options(metaclass=_MetaOptions): __slots__ = () + _cache_attrs: Tuple[str, ...] + def __init_subclass__(cls) -> None: dict_ = cls.__dict__ cls._cache_attrs = tuple( @@ -732,13 +773,13 @@ class Options(metaclass=_MetaOptions): return self + {name: getattr(self, name) + value} @hybridmethod - def _state_dict(self): + def _state_dict_inst(self) -> Mapping[str, Any]: return self.__dict__ - _state_dict_const = util.immutabledict() + _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT - @_state_dict.classlevel - def _state_dict(cls): + @_state_dict_inst.classlevel + def _state_dict(cls) -> Mapping[str, Any]: return cls._state_dict_const @classmethod @@ -825,10 +866,10 @@ class CacheableOptions(Options, HasCacheKey): __slots__ = () @hybridmethod - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key_inst(self, anon_map, bindparams): return HasCacheKey._gen_cache_key(self, anon_map, bindparams) - @_gen_cache_key.classlevel + @_gen_cache_key_inst.classlevel def _gen_cache_key(cls, anon_map, bindparams): return (cls, ()) @@ -849,11 +890,11 @@ class ExecutableOption(HasCopyInternals): def _clone(self, **kw): """Create a shallow copy of this ExecutableOption.""" c = self.__class__.__new__(self.__class__) - c.__dict__ = dict(self.__dict__) + c.__dict__ = dict(self.__dict__) # type: ignore return c -SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable") +SelfExecutable = TypeVar("SelfExecutable", bound="Executable") class Executable(roles.StatementRole, Generative): @@ -866,9 +907,12 @@ class Executable(roles.StatementRole, Generative): """ supports_execution: bool = True - _execution_options: _ImmutableExecuteOptions = util.immutabledict() - _with_options = () - _with_context_options = () + _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT + _with_options: Tuple[ExecutableOption, ...] = () + _with_context_options: Tuple[ + Tuple[Callable[[CompileState], None], Any], ... + ] = () + _compile_options: Optional[CacheableOptions] _executable_traverse_internals = [ ("_with_options", InternalTraversal.dp_executable_options), @@ -886,7 +930,9 @@ class Executable(roles.StatementRole, Generative): is_delete = False is_dml = False - if typing.TYPE_CHECKING: + if TYPE_CHECKING: + + __visit_name__: str def _compile_w_cache( self, @@ -916,11 +962,13 @@ class Executable(roles.StatementRole, Generative): raise NotImplementedError() @property - def _effective_plugin_target(self): + def _effective_plugin_target(self) -> str: return self.__visit_name__ @_generative - def options(self: SelfExecutable, *options) -> SelfExecutable: + def options( + self: SelfExecutable, *options: ExecutableOption + ) -> SelfExecutable: """Apply options to this statement. In the general sense, options are any kind of Python object @@ -957,7 +1005,7 @@ class Executable(roles.StatementRole, Generative): @_generative def _set_compile_options( - self: SelfExecutable, compile_options + self: SelfExecutable, compile_options: CacheableOptions ) -> SelfExecutable: """Assign the compile options to a new value. @@ -970,16 +1018,19 @@ class Executable(roles.StatementRole, Generative): @_generative def _update_compile_options( - self: SelfExecutable, options + self: SelfExecutable, options: CacheableOptions ) -> SelfExecutable: """update the _compile_options with new keys.""" + assert self._compile_options is not None self._compile_options += options return self @_generative def _add_context_option( - self: SelfExecutable, callable_, cache_args + self: SelfExecutable, + callable_: Callable[[CompileState], None], + cache_args: Any, ) -> SelfExecutable: """Add a context option to this statement. @@ -995,7 +1046,7 @@ class Executable(roles.StatementRole, Generative): return self @_generative - def execution_options(self: SelfExecutable, **kw) -> SelfExecutable: + def execution_options(self: SelfExecutable, **kw: Any) -> SelfExecutable: """Set non-SQL options for the statement which take effect during execution. @@ -1112,7 +1163,7 @@ class Executable(roles.StatementRole, Generative): self._execution_options = self._execution_options.union(kw) return self - def get_execution_options(self): + def get_execution_options(self) -> _ExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded:: 1.3 @@ -1124,7 +1175,7 @@ class Executable(roles.StatementRole, Generative): return self._execution_options -class SchemaEventTarget: +class SchemaEventTarget(event.EventTarget): """Base class for elements that are the targets of :class:`.DDLEvents` events. @@ -1132,6 +1183,8 @@ class SchemaEventTarget: """ + dispatch: dispatcher[SchemaEventTarget] + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: """Associate with this SchemaEvent's parent object.""" @@ -1149,7 +1202,10 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -class ColumnCollection: +_COL = TypeVar("_COL", bound="ColumnClause[Any]") + + +class ColumnCollection(Generic[_COL]): """Collection of :class:`_expression.ColumnElement` instances, typically for :class:`_sql.FromClause` objects. @@ -1260,32 +1316,36 @@ class ColumnCollection: __slots__ = "_collection", "_index", "_colset" - def __init__(self, columns=None): + _collection: List[Tuple[str, _COL]] + _index: Dict[Union[str, int], _COL] + _colset: Set[_COL] + + def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None): object.__setattr__(self, "_colset", set()) object.__setattr__(self, "_index", {}) object.__setattr__(self, "_collection", []) if columns: self._initial_populate(columns) - def _initial_populate(self, iter_): + def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None: self._populate_separate_keys(iter_) @property - def _all_columns(self): + def _all_columns(self) -> List[_COL]: return [col for (k, col) in self._collection] - def keys(self): + def keys(self) -> List[str]: """Return a sequence of string key names for all columns in this collection.""" return [k for (k, col) in self._collection] - def values(self): + def values(self) -> List[_COL]: """Return a sequence of :class:`_sql.ColumnClause` or :class:`_schema.Column` objects for all columns in this collection.""" return [col for (k, col) in self._collection] - def items(self): + def items(self) -> List[Tuple[str, _COL]]: """Return a sequence of (key, column) tuples for all columns in this collection each consisting of a string key name and a :class:`_sql.ColumnClause` or @@ -1294,17 +1354,17 @@ class ColumnCollection: return list(self._collection) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._collection) - def __len__(self): + def __len__(self) -> int: return len(self._collection) - def __iter__(self): + def __iter__(self) -> Iterator[_COL]: # turn to a list first to maintain over a course of changes return iter([col for k, col in self._collection]) - def __getitem__(self, key): + def __getitem__(self, key: Union[str, int]) -> _COL: try: return self._index[key] except KeyError as err: @@ -1313,13 +1373,13 @@ class ColumnCollection: else: raise - def __getattr__(self, key): + def __getattr__(self, key: str) -> _COL: try: return self._index[key] except KeyError as err: raise AttributeError(key) from err - def __contains__(self, key): + def __contains__(self, key: str) -> bool: if key not in self._index: if not isinstance(key, str): raise exc.ArgumentError( @@ -1329,7 +1389,7 @@ class ColumnCollection: else: return True - def compare(self, other): + def compare(self, other: ColumnCollection[Any]) -> bool: """Compare this :class:`_expression.ColumnCollection` to another based on the names of the keys""" @@ -1339,10 +1399,10 @@ class ColumnCollection: else: return True - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.compare(other) - def get(self, key, default=None): + def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]: """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object based on a string key name from this :class:`_expression.ColumnCollection`.""" @@ -1352,39 +1412,40 @@ class ColumnCollection: else: return default - def __str__(self): + def __str__(self) -> str: return "%s(%s)" % ( self.__class__.__name__, ", ".join(str(c) for c in self), ) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> NoReturn: raise NotImplementedError() - def __delitem__(self, key): + def __delitem__(self, key: str) -> NoReturn: raise NotImplementedError() - def __setattr__(self, key, obj): + def __setattr__(self, key: str, obj: Any) -> NoReturn: raise NotImplementedError() - def clear(self): + def clear(self) -> NoReturn: """Dictionary clear() is not implemented for :class:`_sql.ColumnCollection`.""" raise NotImplementedError() - def remove(self, column): - """Dictionary remove() is not implemented for - :class:`_sql.ColumnCollection`.""" + def remove(self, column: Any) -> None: raise NotImplementedError() - def update(self, iter_): + def update(self, iter_: Any) -> NoReturn: """Dictionary update() is not implemented for :class:`_sql.ColumnCollection`.""" raise NotImplementedError() - __hash__ = None + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore - def _populate_separate_keys(self, iter_): + def _populate_separate_keys( + self, iter_: Iterable[Tuple[str, _COL]] + ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) self._collection[:] = cols @@ -1394,7 +1455,7 @@ class ColumnCollection: ) self._index.update({k: col for k, col in reversed(self._collection)}) - def add(self, column, key=None): + def add(self, column: _COL, key: Optional[str] = None) -> None: """Add a column to this :class:`_sql.ColumnCollection`. .. note:: @@ -1416,17 +1477,17 @@ class ColumnCollection: if key not in self._index: self._index[key] = column - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"_collection": self._collection, "_index": self._index} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: object.__setattr__(self, "_index", state["_index"]) object.__setattr__(self, "_collection", state["_collection"]) object.__setattr__( self, "_colset", {col for k, col in self._collection} ) - def contains_column(self, col): + def contains_column(self, col: _COL) -> bool: """Checks if a column object exists in this collection""" if col not in self._colset: if isinstance(col, str): @@ -1438,13 +1499,15 @@ class ColumnCollection: else: return True - def as_immutable(self): + def as_immutable(self) -> ImmutableColumnCollection[_COL]: """Return an "immutable" form of this :class:`_sql.ColumnCollection`.""" return ImmutableColumnCollection(self) - def corresponding_column(self, column, require_embedded=False): + def corresponding_column( + self, column: _COL, require_embedded: bool = False + ) -> Optional[_COL]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from this :class:`_expression.ColumnCollection` @@ -1497,7 +1560,7 @@ class ColumnCollection: not require_embedded or embedded(expanded_proxy_set, target_set) ): - if col is None: + if col is None or intersect is None: # no corresponding column yet, pick this one. @@ -1542,7 +1605,7 @@ class ColumnCollection: return col -class DedupeColumnCollection(ColumnCollection): +class DedupeColumnCollection(ColumnCollection[_COL]): """A :class:`_expression.ColumnCollection` that maintains deduplicating behavior. @@ -1555,7 +1618,7 @@ class DedupeColumnCollection(ColumnCollection): """ - def add(self, column, key=None): + def add(self, column: _COL, key: Optional[str] = None) -> None: if key is not None and column.key != key: raise exc.ArgumentError( @@ -1589,7 +1652,9 @@ class DedupeColumnCollection(ColumnCollection): self._index[l] = column self._index[key] = column - def _populate_separate_keys(self, iter_): + def _populate_separate_keys( + self, iter_: Iterable[Tuple[str, _COL]] + ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1614,10 +1679,10 @@ class DedupeColumnCollection(ColumnCollection): for col in replace_col: self.replace(col) - def extend(self, iter_): + def extend(self, iter_: Iterable[_COL]) -> None: self._populate_separate_keys((col.key, col) for col in iter_) - def remove(self, column): + def remove(self, column: _COL) -> None: if column not in self._colset: raise ValueError( "Can't remove column %r; column is not in this collection" @@ -1634,7 +1699,7 @@ class DedupeColumnCollection(ColumnCollection): # delete higher index del self._index[len(self._collection)] - def replace(self, column): + def replace(self, column: _COL) -> None: """add the given column to this collection, removing unaliased versions of this column as well as existing columns with the same key. @@ -1687,7 +1752,9 @@ class DedupeColumnCollection(ColumnCollection): self._index.update(self._collection) -class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): +class ImmutableColumnCollection( + util.ImmutableContainer, ColumnCollection[_COL] +): __slots__ = ("_parent",) def __init__(self, collection): @@ -1701,12 +1768,19 @@ class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): def __setstate__(self, state): parent = state["_parent"] - self.__init__(parent) + self.__init__(parent) # type: ignore - add = extend = remove = util.ImmutableContainer._immutable + def add(self, column: Any, key: Any = ...) -> Any: + self._immutable() + def extend(self, elements: Any) -> None: + self._immutable() -class ColumnSet(util.ordered_column_set): + def remove(self, item: Any) -> None: + self._immutable() + + +class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): def contains_column(self, col): return col in self @@ -1714,9 +1788,6 @@ class ColumnSet(util.ordered_column_set): for col in cols: self.add(col) - def __add__(self, other): - return list(self) + list(other) - def __eq__(self, other): l = [] for c in other: @@ -1729,7 +1800,9 @@ class ColumnSet(util.ordered_column_set): return hash(tuple(x for x in self)) -def _entity_namespace(entity): +def _entity_namespace( + entity: Union[_HasEntityNamespace, ExternallyTraversible] +) -> _EntityNamespace: """Return the nearest .entity_namespace for the given entity. If not immediately available, does an iterate to find a sub-element @@ -1737,16 +1810,20 @@ def _entity_namespace(entity): """ try: - return entity.entity_namespace + return cast(_HasEntityNamespace, entity).entity_namespace except AttributeError: - for elem in visitors.iterate(entity): - if hasattr(elem, "entity_namespace"): + for elem in visitors.iterate(cast(ExternallyTraversible, entity)): + if _is_has_entity_namespace(elem): return elem.entity_namespace else: raise -def _entity_namespace_key(entity, key, default=NO_ARG): +def _entity_namespace_key( + entity: Union[_HasEntityNamespace, ExternallyTraversible], + key: str, + default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG, +) -> SQLCoreOperations[Any]: """Return an entry from an entity_namespace. @@ -1760,7 +1837,7 @@ def _entity_namespace_key(entity, key, default=NO_ARG): if default is not NO_ARG: return getattr(ns, key, default) else: - return getattr(ns, key) + return getattr(ns, key) # type: ignore except AttributeError as err: raise exc.InvalidRequestError( 'Entity namespace for "%s" has no property "%s"' % (entity, key) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index f8019b9c64..5ba52ae51c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -71,6 +71,7 @@ from .schema import Column from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import prefix_anon_map +from .visitors import Visitable from .. import exc from .. import util from ..util.typing import Literal @@ -614,10 +615,10 @@ class Compiled: raise NotImplementedError() - def process(self, obj, **kwargs): + def process(self, obj: Visitable, **kwargs: Any) -> str: return obj._compiler_dispatch(self, **kwargs) - def __str__(self): + def __str__(self) -> str: """Return the string text of the generated SQL or DDL.""" return self.string or "" @@ -723,7 +724,7 @@ class SQLCompiler(Compiled): """list of columns for which onupdate default values should be evaluated before an UPDATE takes place""" - returning: Optional[List[Column[Any]]] + returning: Optional[List[ColumnClause[Any]]] """list of columns that will be delivered to cursor.description or dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE @@ -1485,15 +1486,12 @@ class SQLCompiler(Compiled): self._result_columns ) - _key_getters_for_crud_column: Tuple[ - Callable[[Union[str, Column[Any]]], str], - Callable[[Column[Any]], str], - Callable[[Column[Any]], str], - ] + # assigned by crud.py for insert/update statements + _get_bind_name_for_col: _BindNameForColProtocol @util.memoized_property def _within_exec_param_key_getter(self) -> Callable[[Any], str]: - getter = self._key_getters_for_crud_column[2] + getter = self._get_bind_name_for_col if self.escaped_bind_names: def _get(obj): @@ -4098,7 +4096,9 @@ class SQLCompiler(Compiled): def for_update_clause(self, select, **kw): return " FOR UPDATE" - def returning_clause(self, stmt, returning_cols): + def returning_clause( + self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]] + ) -> str: raise exc.CompileError( "RETURNING is not supported by this " "dialect's statement compiler." @@ -4243,12 +4243,13 @@ class SQLCompiler(Compiled): } ) - crud_params = crud._get_crud_params( + crud_params_struct = crud._get_crud_params( self, insert_stmt, compile_state, **kw ) + crud_params_single = crud_params_struct.single_params if ( - not crud_params + not crud_params_single and not self.dialect.supports_default_values and not self.dialect.supports_default_metavalue and not self.dialect.supports_empty_insert @@ -4266,9 +4267,9 @@ class SQLCompiler(Compiled): "version settings does not support " "in-place multirow inserts." % self.dialect.name ) - crud_params_single = crud_params[0] + crud_params_single = crud_params_struct.single_params else: - crud_params_single = crud_params + crud_params_single = crud_params_struct.single_params preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -4293,7 +4294,7 @@ class SQLCompiler(Compiled): if crud_params_single or not supports_default_values: text += " (%s)" % ", ".join( - [expr for c, expr, value in crud_params_single] + [expr for _, expr, _ in crud_params_single] ) if self.returning or insert_stmt._returning: @@ -4323,19 +4324,24 @@ class SQLCompiler(Compiled): ) else: text += " %s" % select_text - elif not crud_params and supports_default_values: + elif not crud_params_single and supports_default_values: text += " DEFAULT VALUES" elif compile_state._has_multi_parameters: text += " VALUES %s" % ( ", ".join( "(%s)" - % (", ".join(value for c, expr, value in crud_param_set)) - for crud_param_set in crud_params + % (", ".join(value for _, _, value in crud_param_set)) + for crud_param_set in crud_params_struct.all_multi_params ) ) else: insert_single_values_expr = ", ".join( - [value for c, expr, value in crud_params] + [ + value + for _, _, value in cast( + "List[Tuple[Any, Any, str]]", crud_params_single + ) + ] ) text += " VALUES (%s)" % insert_single_values_expr if toplevel and insert_stmt._post_values_clause is None: @@ -4443,9 +4449,10 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause( update_stmt, update_stmt.table, render_extra_froms, **kw ) - crud_params = crud._get_crud_params( + crud_params_struct = crud._get_crud_params( self, update_stmt, compile_state, **kw ) + crud_params = crud_params_struct.single_params if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( @@ -4460,7 +4467,12 @@ class SQLCompiler(Compiled): text += table_text text += " SET " - text += ", ".join(expr + "=" + value for c, expr, value in crud_params) + text += ", ".join( + expr + "=" + value + for _, expr, value in cast( + "List[Tuple[Any, str, str]]", crud_params + ) + ) if self.returning or update_stmt._returning: if self.returning_precedes_values: @@ -5446,6 +5458,11 @@ class _SchemaForObjectCallable(Protocol): ... +class _BindNameForColProtocol(Protocol): + def __call__(self, col: ColumnClause[Any]) -> str: + ... + + class IdentifierPreparer: """Handle quoting and case-folding of identifiers based on options.""" diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 4292aa9162..533a2f6cd9 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -13,13 +13,44 @@ from __future__ import annotations import functools import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import MutableMapping +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union from . import coercions from . import dml from . import elements from . import roles +from .schema import default_is_clause_element +from .schema import default_is_sequence from .. import exc from .. import util +from ..util.typing import Literal + +if TYPE_CHECKING: + from .compiler import _BindNameForColProtocol + from .compiler import SQLCompiler + from .dml import DMLState + from .dml import Insert + from .dml import Update + from .dml import UpdateDMLState + from .dml import ValuesBase + from .elements import ClauseElement + from .elements import ColumnClause + from .elements import ColumnElement + from .elements import TextClause + from .schema import _SQLExprDefault + from .schema import Column + from .selectable import TableClause REQUIRED = util.symbol( "REQUIRED", @@ -36,7 +67,27 @@ values present. ) -def _get_crud_params(compiler, stmt, compile_state, **kw): +class _CrudParams(NamedTuple): + single_params: List[ + Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + ] + all_multi_params: List[ + List[ + Tuple[ + ColumnClause[Any], + str, + str, + ] + ] + ] + + +def _get_crud_params( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + **kw: Any, +) -> _CrudParams: """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -59,24 +110,32 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): _column_as_key, _getattr_col_key, _col_bind_name, - ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state) + ) = _key_getters_for_crud_column(compiler, stmt, compile_state) - compiler._key_getters_for_crud_column = getters + compiler._get_bind_name_for_col = _col_bind_name # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: - return [ - ( - c, - compiler.preparer.format_column(c), - _create_bind_param(compiler, c, None, required=True), - ) - for c in stmt.table.columns - ] + return _CrudParams( + [ + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + ) + for c in stmt.table.columns + ], + [], + ) + + stmt_parameter_tuples: Optional[List[Any]] + spd: Optional[MutableMapping[str, Any]] if compile_state._has_multi_parameters: - spd = compile_state._multi_parameters[0] + mp = compile_state._multi_parameters + assert mp is not None + spd = mp[0] stmt_parameter_tuples = list(spd.items()) elif compile_state._ordered_values: spd = compile_state._dict_parameters @@ -92,6 +151,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): if compiler.column_keys is None: parameters = {} elif stmt_parameter_tuples: + assert spd is not None parameters = dict( (_column_as_key(key), REQUIRED) for key in compiler.column_keys @@ -103,7 +163,9 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) # create a list of column assignment clauses as tuples - values = [] + values: List[ + Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + ] = [] if stmt_parameter_tuples is not None: _get_stmt_parameter_tuples_params( @@ -116,11 +178,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): kw, ) - check_columns = {} + check_columns: Dict[str, ColumnClause[Any]] = {} # special logic that only occurs for multi-table UPDATE # statements - if compile_state.isupdate and compile_state.is_multitable: + if dml.isupdate(compile_state) and compile_state.is_multitable: _get_update_multitable_params( compiler, stmt, @@ -134,6 +196,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) if compile_state.isinsert and stmt._select_names: + # is an insert from select, is not a multiparams + + assert not compile_state._has_multi_parameters + _scan_insert_from_select_cols( compiler, stmt, @@ -173,14 +239,17 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) if compile_state._has_multi_parameters: - values = _extend_values_for_multiparams( + # is a multiparams, is not an insert from a select + assert not stmt._select_names + multi_extended_values = _extend_values_for_multiparams( compiler, stmt, compile_state, - values, - _column_as_key, + cast("List[Tuple[ColumnClause[Any], str, str]]", values), + cast("Callable[..., str]", _column_as_key), kw, ) + return _CrudParams(values, multi_extended_values) elif ( not values and compiler.for_executemany @@ -198,12 +267,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) ] - return values + return _CrudParams(values, []) +@overload def _create_bind_param( - compiler, col, value, process=True, required=False, name=None, **kw -): + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: Literal[True] = ..., + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> str: + ... + + +@overload +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + **kw: Any, +) -> str: + ... + + +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: bool = True, + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> Union[str, elements.BindParameter[Any]]: if name is None: name = col.key bindparam = elements.BindParameter( @@ -211,8 +309,9 @@ def _create_bind_param( ) bindparam._is_crud = True if process: - bindparam = bindparam._compiler_dispatch(compiler, **kw) - return bindparam + return bindparam._compiler_dispatch(compiler, **kw) + else: + return bindparam def _handle_values_anonymous_param(compiler, col, value, name, **kw): @@ -253,8 +352,14 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw): return value._compiler_dispatch(compiler, **kw) -def _key_getters_for_crud_column(compiler, stmt, compile_state): - if compile_state.isupdate and compile_state._extra_froms: +def _key_getters_for_crud_column( + compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState +) -> Tuple[ + Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]], + Callable[[Column[Any]], Union[str, Tuple[str, str]]], + _BindNameForColProtocol, +]: + if dml.isupdate(compile_state) and compile_state._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. @@ -267,30 +372,36 @@ def _key_getters_for_crud_column(compiler, stmt, compile_state): coercions.expect_as_key, roles.DMLColumnRole ) - def _column_as_key(key): + def _column_as_key( + key: Union[ColumnClause[Any], str] + ) -> Union[str, Tuple[str, str]]: str_key = c_key_role(key) - if hasattr(key, "table") and key.table in _et: - return (key.table.name, str_key) + if hasattr(key, "table") and key.table in _et: # type: ignore + return (key.table.name, str_key) # type: ignore else: - return str_key + return str_key # type: ignore - def _getattr_col_key(col): + def _getattr_col_key( + col: ColumnClause[Any], + ) -> Union[str, Tuple[str, str]]: if col.table in _et: - return (col.table.name, col.key) + return (col.table.name, col.key) # type: ignore else: return col.key - def _col_bind_name(col): + def _col_bind_name(col: ColumnClause[Any]) -> str: if col.table in _et: + if TYPE_CHECKING: + assert isinstance(col.table, TableClause) return "%s_%s" % (col.table.name, col.key) else: return col.key else: - _column_as_key = functools.partial( + _column_as_key = functools.partial( # type: ignore coercions.expect_as_key, roles.DMLColumnRole ) - _getattr_col_key = _col_bind_name = operator.attrgetter("key") + _getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa E501 return _column_as_key, _getattr_col_key, _col_bind_name @@ -321,7 +432,7 @@ def _scan_insert_from_select_cols( compiler.stack[-1]["insert_from_select"] = stmt.select - add_select_cols = [] + add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = [] if stmt.include_insert_from_select_defaults: col_set = set(cols) for col in stmt.table.columns: @@ -707,16 +818,22 @@ def _append_param_insert_hasdefault( ) -def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): +def _append_param_insert_select_hasdefault( + compiler: SQLCompiler, + stmt: ValuesBase, + c: ColumnClause[Any], + values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]], + kw: Dict[str, Any], +) -> None: - if c.default.is_sequence: + if default_is_sequence(c.default): if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): values.append( (c, compiler.preparer.format_column(c), c.default.next_value()) ) - elif c.default.is_clause_element: + elif default_is_clause_element(c.default): values.append( (c, compiler.preparer.format_column(c), c.default.arg.self_group()) ) @@ -777,28 +894,76 @@ def _append_param_update( compiler.returning.append(c) +@overload def _create_insert_prefetch_bind_param( - compiler, c, process=True, name=None, **kw -): + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + +@overload +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: + ... + + +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: param = _create_bind_param( compiler, c, None, process=process, name=name, **kw ) - compiler.insert_prefetch.append(c) + compiler.insert_prefetch.append(c) # type: ignore return param +@overload def _create_update_prefetch_bind_param( - compiler, c, process=True, name=None, **kw -): + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + +@overload +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: + ... + + +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: param = _create_bind_param( compiler, c, None, process=process, name=name, **kw ) - compiler.update_prefetch.append(c) + compiler.update_prefetch.append(c) # type: ignore return param -class _multiparam_column(elements.ColumnElement): +class _multiparam_column(elements.ColumnElement[Any]): _is_multiparam_column = True def __init__(self, original, index): @@ -822,14 +987,20 @@ class _multiparam_column(elements.ColumnElement): ) -def _process_multiparam_default_bind(compiler, stmt, c, index, kw): +def _process_multiparam_default_bind( + compiler: SQLCompiler, + stmt: ValuesBase, + c: ColumnClause[Any], + index: int, + kw: Dict[str, Any], +) -> str: if not c.default: raise exc.CompileError( "INSERT value for column %s is explicitly rendered as a bound" "parameter in the VALUES clause; " "a Python-side value or SQL expression is required" % c ) - elif c.default.is_clause_element: + elif default_is_clause_element(c.default): return compiler.process(c.default.arg.self_group(), **kw) elif c.default.is_sequence: # these conditions would have been established @@ -844,9 +1015,13 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): else: col = _multiparam_column(c, index) if isinstance(stmt, dml.Insert): - return _create_insert_prefetch_bind_param(compiler, col, **kw) + return _create_insert_prefetch_bind_param( + compiler, col, process=True, **kw + ) else: - return _create_update_prefetch_bind_param(compiler, col, **kw) + return _create_update_prefetch_bind_param( + compiler, col, process=True, **kw + ) def _get_update_multitable_params( @@ -926,18 +1101,26 @@ def _get_update_multitable_params( def _extend_values_for_multiparams( - compiler, - stmt, - compile_state, - values, - _column_as_key, - kw, -): - values_0 = values - values = [values] - - for i, row in enumerate(compile_state._multi_parameters[1:]): - extension = [] + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + initial_values: List[Tuple[ColumnClause[Any], str, str]], + _column_as_key: Callable[..., str], + kw: Dict[str, Any], +) -> List[List[Tuple[ColumnClause[Any], str, str]]]: + values_0 = initial_values + values = [initial_values] + + mp = compile_state._multi_parameters + assert mp is not None + for i, row in enumerate(mp[1:]): + extension: List[ + Tuple[ + ColumnClause[Any], + str, + str, + ] + ] = [] row = {_column_as_key(key): v for key, v in row.items()} diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 91bb0a5c58..944a0a5ce6 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -26,6 +26,7 @@ from . import roles from . import type_api from .elements import and_ from .elements import BinaryExpression +from .elements import ClauseElement from .elements import ClauseList from .elements import CollationClause from .elements import CollectionAggregate @@ -43,7 +44,7 @@ _T = typing.TypeVar("_T", bound=Any) if typing.TYPE_CHECKING: from .elements import ColumnElement from .operators import custom_op - from .sqltypes import TypeEngine + from .type_api import TypeEngine def _boolean_compare( @@ -53,10 +54,10 @@ def _boolean_compare( *, negate_op: Optional[OperatorType] = None, reverse: bool = False, - _python_is_types=(util.NoneType, bool), - _any_all_expr=False, + _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), + _any_all_expr: bool = False, result_type: Optional[ - Union[Type["TypeEngine[bool]"], "TypeEngine[bool]"] + Union[Type[TypeEngine[bool]], TypeEngine[bool]] ] = None, **kwargs: Any, ) -> BinaryExpression[bool]: @@ -165,7 +166,7 @@ def _custom_op_operate( def _binary_operate( expr: ColumnElement[Any], op: OperatorType, - obj: roles.BinaryElementRole, + obj: roles.BinaryElementRole[Any], *, reverse: bool = False, result_type: Optional[ @@ -192,7 +193,7 @@ def _binary_operate( def _conjunction_operate( - expr: ColumnElement[Any], op: OperatorType, other, **kw + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any ) -> ColumnElement[Any]: if op is operators.and_: return and_(expr, other) @@ -203,7 +204,10 @@ def _conjunction_operate( def _scalar( - expr: ColumnElement[Any], op: OperatorType, fn, **kw + expr: ColumnElement[Any], + op: OperatorType, + fn: Callable[[ColumnElement[Any]], ColumnElement[Any]], + **kw: Any, ) -> ColumnElement[Any]: return fn(expr) @@ -211,9 +215,9 @@ def _scalar( def _in_impl( expr: ColumnElement[Any], op: OperatorType, - seq_or_selectable, + seq_or_selectable: ClauseElement, negate_op: OperatorType, - **kw, + **kw: Any, ) -> ColumnElement[Any]: seq_or_selectable = coercions.expect( roles.InElementRole, seq_or_selectable, expr=expr, operator=op @@ -227,7 +231,7 @@ def _in_impl( def _getitem_impl( - expr: ColumnElement[Any], op: OperatorType, other, **kw + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any ) -> ColumnElement[Any]: if isinstance(expr.type, type_api.INDEXABLE): other = coercions.expect( @@ -239,7 +243,7 @@ def _getitem_impl( def _unsupported_impl( - expr: ColumnElement[Any], op: OperatorType, *arg, **kw + expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any ) -> NoReturn: raise NotImplementedError( "Operator '%s' is not supported on " "this expression" % op.__name__ @@ -247,7 +251,7 @@ def _unsupported_impl( def _inv_impl( - expr: ColumnElement[Any], op: OperatorType, **kw + expr: ColumnElement[Any], op: OperatorType, **kw: Any ) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.__inv__`.""" @@ -260,14 +264,14 @@ def _inv_impl( def _neg_impl( - expr: ColumnElement[Any], op: OperatorType, **kw + expr: ColumnElement[Any], op: OperatorType, **kw: Any ) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.__neg__`.""" return UnaryExpression(expr, operator=operators.neg, type_=expr.type) def _match_impl( - expr: ColumnElement[Any], op: OperatorType, other, **kw + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any ) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.match`.""" @@ -289,7 +293,7 @@ def _match_impl( def _distinct_impl( - expr: ColumnElement[Any], op: OperatorType, **kw + expr: ColumnElement[Any], op: OperatorType, **kw: Any ) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.distinct`.""" return UnaryExpression( @@ -298,7 +302,11 @@ def _distinct_impl( def _between_impl( - expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw + expr: ColumnElement[Any], + op: OperatorType, + cleft: Any, + cright: Any, + **kw: Any, ) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( @@ -329,26 +337,32 @@ def _between_impl( def _collate_impl( - expr: ColumnElement[Any], op: OperatorType, collation, **kw -) -> ColumnElement[Any]: + expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any +) -> ColumnElement[str]: return CollationClause._create_collation_expression(expr, collation) def _regexp_match_impl( - expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw + expr: ColumnElement[str], + op: OperatorType, + pattern: Any, + flags: Optional[str], + **kw: Any, ) -> ColumnElement[Any]: if flags is not None: - flags = coercions.expect( + flags_expr = coercions.expect( roles.BinaryElementRole, flags, expr=expr, operator=operators.regexp_replace_op, ) + else: + flags_expr = None return _boolean_compare( expr, op, pattern, - flags=flags, + flags=flags_expr, negate_op=operators.not_regexp_match_op if op is operators.regexp_match_op else operators.regexp_match_op, @@ -359,10 +373,10 @@ def _regexp_match_impl( def _regexp_replace_impl( expr: ColumnElement[Any], op: OperatorType, - pattern, - replacement, - flags, - **kw, + pattern: Any, + replacement: Any, + flags: Optional[str], + **kw: Any, ) -> ColumnElement[Any]: replacement = coercions.expect( roles.BinaryElementRole, @@ -371,21 +385,29 @@ def _regexp_replace_impl( operator=operators.regexp_replace_op, ) if flags is not None: - flags = coercions.expect( + flags_expr = coercions.expect( roles.BinaryElementRole, flags, expr=expr, operator=operators.regexp_replace_op, ) + else: + flags_expr = None return _binary_operate( - expr, op, pattern, replacement=replacement, flags=flags, **kw + expr, op, pattern, replacement=replacement, flags=flags_expr, **kw ) # a mapping of operators with the method they use, along with # additional keyword arguments to be passed operator_lookup: Dict[ - str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict] + str, + Tuple[ + Callable[..., ColumnElement[Any]], + util.immutabledict[ + str, Union[OperatorType, Callable[..., ColumnElement[Any]]] + ], + ], ] = { "and_": (_conjunction_operate, util.EMPTY_DICT), "or_": (_conjunction_operate, util.EMPTY_DICT), diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 1271c5977c..10316dd2bb 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -12,11 +12,13 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and from __future__ import annotations import collections.abc as collections_abc +import operator import typing from typing import Any from typing import List from typing import MutableMapping from typing import Optional +from typing import TYPE_CHECKING from . import coercions from . import roles @@ -36,10 +38,29 @@ from .elements import Null from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import ReturnsRows +from .selectable import TableClause from .sqltypes import NullType from .visitors import InternalTraversal from .. import exc from .. import util +from ..util.typing import TypeGuard + + +if TYPE_CHECKING: + + def isupdate(dml) -> TypeGuard[UpdateDMLState]: + ... + + def isdelete(dml) -> TypeGuard[DeleteDMLState]: + ... + + def isinsert(dml) -> TypeGuard[InsertDMLState]: + ... + +else: + isupdate = operator.attrgetter("isupdate") + isdelete = operator.attrgetter("isdelete") + isinsert = operator.attrgetter("isinsert") class DMLState(CompileState): @@ -49,6 +70,7 @@ class DMLState(CompileState): _ordered_values = None _parameter_ordering = None _has_multi_parameters = False + isupdate = False isdelete = False isinsert = False @@ -237,6 +259,8 @@ class UpdateBase( _hints = util.immutabledict() named_with_column = False + table: TableClause + _return_defaults = False _return_defaults_columns = None _returning = () diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 48c3c3be66..691eb10ec4 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -18,6 +18,7 @@ import itertools import operator import re import typing +from typing import AbstractSet from typing import Any from typing import Callable from typing import cast @@ -83,6 +84,7 @@ if typing.TYPE_CHECKING: from .operators import OperatorType from .schema import Column from .schema import DefaultGenerator + from .schema import FetchedValue from .schema import ForeignKey from .selectable import FromClause from .selectable import NamedFromClause @@ -290,7 +292,7 @@ class ClauseElement( """ - @util.memoized_property + @util.ro_memoized_property def description(self) -> Optional[str]: return None @@ -319,7 +321,7 @@ class ClauseElement( _cache_key_traversal = None - negation_clause: ClauseElement + negation_clause: ColumnElement[bool] if typing.TYPE_CHECKING: @@ -1153,9 +1155,7 @@ class ColumnElement( primary_key: bool = False _is_clone_of: Optional[ColumnElement[_T]] - @util.memoized_property - def foreign_keys(self) -> Iterable[ForeignKey]: - return [] + foreign_keys: AbstractSet[ForeignKey] = frozenset() @util.memoized_property def _proxies(self) -> List[ColumnElement[Any]]: @@ -1494,6 +1494,8 @@ class ColumnElement( else: key = name + assert key is not None + co: ColumnClause[_T] = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name) if name_is_truncatable @@ -1506,7 +1508,6 @@ class ColumnElement( co._proxies = [self] if selectable._is_clone_of is not None: co._is_clone_of = selectable._is_clone_of.columns.get(key) - assert key is not None return key, co def cast(self, type_: TypeEngine[_T]) -> Cast[_T]: @@ -4050,13 +4051,14 @@ class NamedColumn(ColumnElement[_T]): is_literal = False table: Optional[FromClause] = None name: str + key: str def _compare_name_for_result(self, other): return (hasattr(other, "name") and self.name == other.name) or ( hasattr(other, "_label") and self._label == other._label ) - @util.memoized_property + @util.ro_memoized_property def description(self) -> str: return self.name @@ -4125,6 +4127,7 @@ class NamedColumn(ColumnElement[_T]): _selectable=selectable, is_literal=False, ) + c._propagate_attrs = selectable._propagate_attrs if name is None: c.key = self.key @@ -4192,8 +4195,8 @@ class ColumnClause( onupdate: Optional[DefaultGenerator] = None default: Optional[DefaultGenerator] = None - server_default: Optional[DefaultGenerator] = None - server_onupdate: Optional[DefaultGenerator] = None + server_default: Optional[FetchedValue] = None + server_onupdate: Optional[FetchedValue] = None _is_multiparam_column = False diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py index 1a1fc4c417..0d74e2e4c1 100644 --- a/lib/sqlalchemy/sql/events.py +++ b/lib/sqlalchemy/sql/events.py @@ -7,11 +7,23 @@ from __future__ import annotations +from typing import Any +from typing import TYPE_CHECKING + from .base import SchemaEventTarget from .. import event +if TYPE_CHECKING: + from .schema import Column + from .schema import Constraint + from .schema import SchemaItem + from .schema import Table + from ..engine.base import Connection + from ..engine.interfaces import ReflectedColumn + from ..engine.reflection import Inspector + -class DDLEvents(event.Events): +class DDLEvents(event.Events[SchemaEventTarget]): """ Define event listeners for schema objects, that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget` @@ -93,7 +105,9 @@ class DDLEvents(event.Events): _target_class_doc = "SomeSchemaClassOrObject" _dispatch_target = SchemaEventTarget - def before_create(self, target, connection, **kw): + def before_create( + self, target: SchemaEventTarget, connection: Connection, **kw: Any + ) -> None: r"""Called before CREATE statements are emitted. :param target: the :class:`_schema.MetaData` or :class:`_schema.Table` @@ -120,7 +134,9 @@ class DDLEvents(event.Events): """ - def after_create(self, target, connection, **kw): + def after_create( + self, target: SchemaEventTarget, connection: Connection, **kw: Any + ) -> None: r"""Called after CREATE statements are emitted. :param target: the :class:`_schema.MetaData` or :class:`_schema.Table` @@ -142,7 +158,9 @@ class DDLEvents(event.Events): """ - def before_drop(self, target, connection, **kw): + def before_drop( + self, target: SchemaEventTarget, connection: Connection, **kw: Any + ) -> None: r"""Called before DROP statements are emitted. :param target: the :class:`_schema.MetaData` or :class:`_schema.Table` @@ -164,7 +182,9 @@ class DDLEvents(event.Events): """ - def after_drop(self, target, connection, **kw): + def after_drop( + self, target: SchemaEventTarget, connection: Connection, **kw: Any + ) -> None: r"""Called after DROP statements are emitted. :param target: the :class:`_schema.MetaData` or :class:`_schema.Table` @@ -186,7 +206,9 @@ class DDLEvents(event.Events): """ - def before_parent_attach(self, target, parent): + def before_parent_attach( + self, target: SchemaEventTarget, parent: SchemaItem + ) -> None: """Called before a :class:`.SchemaItem` is associated with a parent :class:`.SchemaItem`. @@ -201,7 +223,9 @@ class DDLEvents(event.Events): """ - def after_parent_attach(self, target, parent): + def after_parent_attach( + self, target: SchemaEventTarget, parent: SchemaItem + ) -> None: """Called after a :class:`.SchemaItem` is associated with a parent :class:`.SchemaItem`. @@ -216,13 +240,17 @@ class DDLEvents(event.Events): """ - def _sa_event_column_added_to_pk_constraint(self, const, col): + def _sa_event_column_added_to_pk_constraint( + self, const: Constraint, col: Column[Any] + ) -> None: """internal event hook used for primary key naming convention updates. """ - def column_reflect(self, inspector, table, column_info): + def column_reflect( + self, inspector: Inspector, table: Table, column_info: ReflectedColumn + ) -> None: """Called for each unit of 'column info' retrieved when a :class:`_schema.Table` is being reflected. diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 36ddbf309b..455e74f7b0 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -43,7 +43,6 @@ from ._elements_constructors import text as text from ._elements_constructors import true as true from ._elements_constructors import tuple_ as tuple_ from ._elements_constructors import type_coerce as type_coerce -from ._elements_constructors import typing as typing from ._elements_constructors import within_group as within_group from ._selectable_constructors import alias as alias from ._selectable_constructors import cte as cte diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 4c4f49aa84..beb73c1b50 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -211,9 +211,11 @@ class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects - c: ColumnCollection + c: ColumnCollection[Any] - @property + # this should be ->str , however, working around: + # https://github.com/python/mypy/issues/12440 + @util.ro_non_memoized_property def description(self) -> str: raise NotImplementedError() diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 5cfb55603f..540b62e8aa 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -30,16 +30,22 @@ as components in SQL expressions. """ from __future__ import annotations +from abc import ABC import collections +import operator import typing from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import MutableMapping from typing import Optional from typing import overload from typing import Sequence as _typing_Sequence +from typing import Set +from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -48,6 +54,7 @@ from . import ddl from . import roles from . import type_api from . import visitors +from .base import ColumnCollection from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import Executable @@ -67,12 +74,15 @@ from .. import exc from .. import inspection from .. import util from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypeGuard if typing.TYPE_CHECKING: from .type_api import TypeEngine from ..engine import Connection from ..engine import Engine - + from ..engine.interfaces import ExecutionContext + from ..engine.mock import MockConnection _T = TypeVar("_T", bound="Any") _ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement] _TAB = TypeVar("_TAB", bound="Table") @@ -102,7 +112,7 @@ NULL_UNSPECIFIED = util.symbol( ) -def _get_table_key(name, schema): +def _get_table_key(name: str, schema: Optional[str]) -> str: if schema is None: return name else: @@ -207,7 +217,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): __visit_name__ = "table" - constraints = None + constraints: Set[Constraint] """A collection of all :class:`_schema.Constraint` objects associated with this :class:`_schema.Table`. @@ -235,7 +245,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): """ - indexes = None + indexes: Set[Index] """A collection of all :class:`_schema.Index` objects associated with this :class:`_schema.Table`. @@ -249,6 +259,14 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): ("schema", InternalTraversal.dp_string) ] + if TYPE_CHECKING: + + @util.non_memoized_property + def columns(self) -> ColumnCollection[Column[Any]]: + ... + + c: ColumnCollection[Column[Any]] + def _gen_cache_key(self, anon_map, bindparams): if self._annotations: return (self,) + self._annotations_cache_key @@ -736,11 +754,12 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): ) @property - def _sorted_constraints(self): + def _sorted_constraints(self) -> List[Constraint]: """Return the set of constraints as a list, sorted by creation order. """ + return sorted(self.constraints, key=lambda c: c._creation_order) @property @@ -801,6 +820,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): ) self.info = kwargs.pop("info", self.info) + exclude_columns: _typing_Sequence[str] + if autoload: if not autoload_replace: # don't replace columns already present. @@ -1074,8 +1095,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): return metadata.tables[key] args = [] - for c in self.columns: - args.append(c._copy(schema=schema)) + for col in self.columns: + args.append(col._copy(schema=schema)) table = Table( name, metadata, @@ -1084,28 +1105,30 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): *args, **self.kwargs, ) - for c in self.constraints: - if isinstance(c, ForeignKeyConstraint): - referred_schema = c._referred_schema + for const in self.constraints: + if isinstance(const, ForeignKeyConstraint): + referred_schema = const._referred_schema if referred_schema_fn: fk_constraint_schema = referred_schema_fn( - self, schema, c, referred_schema + self, schema, const, referred_schema ) else: fk_constraint_schema = ( schema if referred_schema == self.schema else None ) table.append_constraint( - c._copy(schema=fk_constraint_schema, target_table=table) + const._copy( + schema=fk_constraint_schema, target_table=table + ) ) - elif not c._type_bound: + elif not const._type_bound: # skip unique constraints that would be generated # by the 'unique' flag on Column - if c._column_flag: + if const._column_flag: continue table.append_constraint( - c._copy(schema=schema, target_table=table) + const._copy(schema=schema, target_table=table) ) for index in self.indexes: # skip indexes that would be generated @@ -1734,23 +1757,25 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): name = kwargs.pop("name", None) type_ = kwargs.pop("type_", None) - args = list(args) - if args: - if isinstance(args[0], str): + l_args = list(args) + del args + + if l_args: + if isinstance(l_args[0], str): if name is not None: raise exc.ArgumentError( "May not pass name positionally and as a keyword." ) - name = args.pop(0) - if args: - coltype = args[0] + name = l_args.pop(0) + if l_args: + coltype = l_args[0] if hasattr(coltype, "_sqla_type"): if type_ is not None: raise exc.ArgumentError( "May not pass type_ positionally and as a keyword." ) - type_ = args.pop(0) + type_ = l_args.pop(0) if name is not None: name = quoted_name(name, kwargs.pop("quote", None)) @@ -1772,7 +1797,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): else: self.nullable = not primary_key - self.default = kwargs.pop("default", None) + default = kwargs.pop("default", None) + onupdate = kwargs.pop("onupdate", None) + self.server_default = kwargs.pop("server_default", None) self.server_onupdate = kwargs.pop("server_onupdate", None) @@ -1784,7 +1811,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): self.system = kwargs.pop("system", False) self.doc = kwargs.pop("doc", None) - self.onupdate = kwargs.pop("onupdate", None) self.autoincrement = kwargs.pop("autoincrement", "auto") self.constraints = set() self.foreign_keys = set() @@ -1803,32 +1829,38 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): if isinstance(impl, SchemaEventTarget): impl._set_parent_with_dispatch(self) - if self.default is not None: - if isinstance(self.default, (ColumnDefault, Sequence)): - args.append(self.default) - else: - args.append(ColumnDefault(self.default)) + if default is not None: + if not isinstance(default, (ColumnDefault, Sequence)): + default = ColumnDefault(default) + + self.default = default + l_args.append(default) + else: + self.default = None + + if onupdate is not None: + if not isinstance(onupdate, (ColumnDefault, Sequence)): + onupdate = ColumnDefault(onupdate, for_update=True) + + self.onupdate = onupdate + l_args.append(onupdate) + else: + self.onpudate = None if self.server_default is not None: if isinstance(self.server_default, FetchedValue): - args.append(self.server_default._as_for_update(False)) + l_args.append(self.server_default._as_for_update(False)) else: - args.append(DefaultClause(self.server_default)) - - if self.onupdate is not None: - if isinstance(self.onupdate, (ColumnDefault, Sequence)): - args.append(self.onupdate) - else: - args.append(ColumnDefault(self.onupdate, for_update=True)) + l_args.append(DefaultClause(self.server_default)) if self.server_onupdate is not None: if isinstance(self.server_onupdate, FetchedValue): - args.append(self.server_onupdate._as_for_update(True)) + l_args.append(self.server_onupdate._as_for_update(True)) else: - args.append( + l_args.append( DefaultClause(self.server_onupdate, for_update=True) ) - self._init_items(*args) + self._init_items(*l_args) util.set_creation_order(self) @@ -1837,7 +1869,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): self._extra_kwargs(**kwargs) - foreign_keys = None + table: Table + + constraints: Set[Constraint] + + foreign_keys: Set[ForeignKey] """A collection of all :class:`_schema.ForeignKey` marker objects associated with this :class:`_schema.Column`. @@ -1850,7 +1886,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ - index = None + index: bool """The value of the :paramref:`_schema.Column.index` parameter. Does not indicate if this :class:`_schema.Column` is actually indexed @@ -1861,7 +1897,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): :attr:`_schema.Table.indexes` """ - unique = None + unique: bool """The value of the :paramref:`_schema.Column.unique` parameter. Does not indicate if this :class:`_schema.Column` is actually subject to @@ -2074,8 +2110,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): server_default = self.server_default server_onupdate = self.server_onupdate if isinstance(server_default, (Computed, Identity)): + args.append(server_default._copy(**kw)) server_default = server_onupdate = None - args.append(self.server_default._copy(**kw)) type_ = self.type if isinstance(type_, SchemaEventTarget): @@ -2203,9 +2239,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): __visit_name__ = "foreign_key" + parent: Column[Any] + def __init__( self, - column: Union[str, Column, SQLCoreOperations], + column: Union[str, Column[Any], SQLCoreOperations[Any]], _constraint: Optional["ForeignKeyConstraint"] = None, use_alter: bool = False, name: Optional[str] = None, @@ -2296,7 +2334,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._table_column = self._colspec if not isinstance( - self._table_column.table, (util.NoneType, TableClause) + self._table_column.table, (type(None), TableClause) ): raise exc.ArgumentError( "ForeignKey received Column not bound " @@ -2309,7 +2347,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): # object passes itself in when creating ForeignKey # markers. self.constraint = _constraint - self.parent = None + + # .parent is not Optional under normal use + self.parent = None # type: ignore + self.use_alter = use_alter self.name = name self.onupdate = onupdate @@ -2501,19 +2542,18 @@ class ForeignKey(DialectKWArgs, SchemaItem): return parenttable, tablekey, colname def _link_to_col_by_colstring(self, parenttable, table, colname): - if not hasattr(self.constraint, "_referred_table"): - self.constraint._referred_table = table - else: - assert self.constraint._referred_table is table - _column = None if colname is None: # colname is None in the case that ForeignKey argument # was specified as table name only, in which case we # match the column name to the same column on the # parent. - key = self.parent - _column = table.c.get(self.parent.key, None) + # this use case wasn't working in later 1.x series + # as it had no test coverage; fixed in 2.0 + parent = self.parent + assert parent is not None + key = parent.key + _column = table.c.get(key, None) elif self.link_to_name: key = colname for c in table.c: @@ -2533,10 +2573,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): key, ) - self._set_target_column(_column) + return _column def _set_target_column(self, column): - assert isinstance(self.parent.table, Table) + assert self.parent is not None # propagate TypeEngine to parent if it didn't have one if self.parent.type._isnull: @@ -2561,11 +2601,6 @@ class ForeignKey(DialectKWArgs, SchemaItem): If no target column has been established, an exception is raised. - .. versionchanged:: 0.9.0 - Foreign key target column resolution now occurs as soon as both - the ForeignKey object and the remote Column to which it refers - are both associated with the same MetaData object. - """ if isinstance(self._colspec, str): @@ -2586,14 +2621,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): "parent MetaData" % parenttable ) else: - raise exc.NoReferencedColumnError( - "Could not initialize target column for " - "ForeignKey '%s' on table '%s': " - "table '%s' has no column named '%s'" - % (self._colspec, parenttable.name, tablekey, colname), - tablekey, - colname, + table = parenttable.metadata.tables[tablekey] + return self._link_to_col_by_colstring( + parenttable, table, colname ) + elif hasattr(self._colspec, "__clause_element__"): _column = self._colspec.__clause_element__() return _column @@ -2601,18 +2633,22 @@ class ForeignKey(DialectKWArgs, SchemaItem): _column = self._colspec return _column - def _set_parent(self, column, **kw): - if self.parent is not None and self.parent is not column: + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + assert isinstance(parent, Column) + + if self.parent is not None and self.parent is not parent: raise exc.InvalidRequestError( "This ForeignKey already has a parent !" ) - self.parent = column + self.parent = parent self.parent.foreign_keys.add(self) self.parent._on_table_attach(self._set_table) def _set_remote_table(self, table): - parenttable, tablekey, colname = self._resolve_col_tokens() - self._link_to_col_by_colstring(parenttable, table, colname) + parenttable, _, colname = self._resolve_col_tokens() + _column = self._link_to_col_by_colstring(parenttable, table, colname) + self._set_target_column(_column) + assert self.constraint is not None self.constraint._validate_dest_table(table) def _remove_from_metadata(self, metadata): @@ -2651,10 +2687,15 @@ class ForeignKey(DialectKWArgs, SchemaItem): if table_key in parenttable.metadata.tables: table = parenttable.metadata.tables[table_key] try: - self._link_to_col_by_colstring(parenttable, table, colname) + _column = self._link_to_col_by_colstring( + parenttable, table, colname + ) except exc.NoReferencedColumnError: # this is OK, we'll try later pass + else: + self._set_target_column(_column) + parenttable.metadata._fk_memos[fk_key].append(self) elif hasattr(self._colspec, "__clause_element__"): _column = self._colspec.__clause_element__() @@ -2664,6 +2705,31 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._set_target_column(_column) +if TYPE_CHECKING: + + def default_is_sequence( + obj: Optional[DefaultGenerator], + ) -> TypeGuard[Sequence]: + ... + + def default_is_clause_element( + obj: Optional[DefaultGenerator], + ) -> TypeGuard[ColumnElementColumnDefault]: + ... + + def default_is_scalar( + obj: Optional[DefaultGenerator], + ) -> TypeGuard[ScalarElementColumnDefault]: + ... + +else: + default_is_sequence = operator.attrgetter("is_sequence") + + default_is_clause_element = operator.attrgetter("is_clause_element") + + default_is_scalar = operator.attrgetter("is_scalar") + + class DefaultGenerator(Executable, SchemaItem): """Base class for column *default* values.""" @@ -2671,18 +2737,18 @@ class DefaultGenerator(Executable, SchemaItem): is_sequence = False is_server_default = False + is_clause_element = False + is_callable = False is_scalar = False - column = None + column: Optional[Column[Any]] def __init__(self, for_update=False): self.for_update = for_update - @util.memoized_property - def is_callable(self): - raise NotImplementedError() - - def _set_parent(self, column, **kw): - self.column = column + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + if TYPE_CHECKING: + assert isinstance(parent, Column) + self.column = parent if self.for_update: self.column.onupdate = self else: @@ -2696,7 +2762,7 @@ class DefaultGenerator(Executable, SchemaItem): ) -class ColumnDefault(DefaultGenerator): +class ColumnDefault(DefaultGenerator, ABC): """A plain default value on a column. This could correspond to a constant, a callable function, @@ -2718,7 +2784,30 @@ class ColumnDefault(DefaultGenerator): """ - def __init__(self, arg, **kwargs): + arg: Any + + @overload + def __new__( + cls, arg: Callable[..., Any], for_update: bool = ... + ) -> CallableColumnDefault: + ... + + @overload + def __new__( + cls, arg: ColumnElement[Any], for_update: bool = ... + ) -> ColumnElementColumnDefault: + ... + + # if I return ScalarElementColumnDefault here, which is what's actually + # returned, mypy complains that + # overloads overlap w/ incompatible return types. + @overload + def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault: + ... + + def __new__( + cls, arg: Any = None, for_update: bool = False + ) -> ColumnDefault: """Construct a new :class:`.ColumnDefault`. @@ -2744,70 +2833,121 @@ class ColumnDefault(DefaultGenerator): statement and parameters. """ - super(ColumnDefault, self).__init__(**kwargs) + if isinstance(arg, FetchedValue): raise exc.ArgumentError( "ColumnDefault may not be a server-side default type." ) - if callable(arg): - arg = self._maybe_wrap_callable(arg) + elif callable(arg): + cls = CallableColumnDefault + elif isinstance(arg, ClauseElement): + cls = ColumnElementColumnDefault + elif arg is not None: + cls = ScalarElementColumnDefault + + return object.__new__(cls) + + def __repr__(self): + return f"{self.__class__.__name__}({self.arg!r})" + + +class ScalarElementColumnDefault(ColumnDefault): + """default generator for a fixed scalar Python value + + .. versionadded: 2.0 + + """ + + is_scalar = True + + def __init__(self, arg: Any, for_update: bool = False): + self.for_update = for_update self.arg = arg - @util.memoized_property - def is_callable(self): - return callable(self.arg) - @util.memoized_property - def is_clause_element(self): - return isinstance(self.arg, ClauseElement) +# _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"] +_SQLExprDefault = Union["ColumnElement[Any]", "TextClause"] - @util.memoized_property - def is_scalar(self): - return ( - not self.is_callable - and not self.is_clause_element - and not self.is_sequence - ) + +class ColumnElementColumnDefault(ColumnDefault): + """default generator for a SQL expression + + .. versionadded:: 2.0 + + """ + + is_clause_element = True + + arg: _SQLExprDefault + + def __init__( + self, + arg: _SQLExprDefault, + for_update: bool = False, + ): + self.for_update = for_update + self.arg = arg @util.memoized_property @util.preload_module("sqlalchemy.sql.sqltypes") def _arg_is_typed(self): sqltypes = util.preloaded.sql_sqltypes - if self.is_clause_element: - return not isinstance(self.arg.type, sqltypes.NullType) - else: - return False + return not isinstance(self.arg.type, sqltypes.NullType) + + +class _CallableColumnDefaultProtocol(Protocol): + def __call__(self, context: ExecutionContext) -> Any: + ... - def _maybe_wrap_callable(self, fn): + +class CallableColumnDefault(ColumnDefault): + """default generator for a callable Python function + + .. versionadded:: 2.0 + + """ + + is_callable = True + arg: _CallableColumnDefaultProtocol + + def __init__( + self, + arg: Union[_CallableColumnDefaultProtocol, Callable[[], Any]], + for_update: bool = False, + ): + self.for_update = for_update + self.arg = self._maybe_wrap_callable(arg) + + def _maybe_wrap_callable( + self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]] + ) -> _CallableColumnDefaultProtocol: """Wrap callables that don't accept a context. This is to allow easy compatibility with default callables that aren't specific to accepting of a context. """ + try: argspec = util.get_callable_argspec(fn, no_self=True) except TypeError: - return util.wrap_callable(lambda ctx: fn(), fn) + return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore defaulted = argspec[3] is not None and len(argspec[3]) or 0 positionals = len(argspec[0]) - defaulted if positionals == 0: - return util.wrap_callable(lambda ctx: fn(), fn) + return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore elif positionals == 1: - return fn + return fn # type: ignore else: raise exc.ArgumentError( "ColumnDefault Python function takes zero or one " "positional arguments" ) - def __repr__(self): - return "ColumnDefault(%r)" % (self.arg,) - class IdentityOptions: """Defines options for a named database sequence or an identity column. @@ -2899,6 +3039,8 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): is_sequence = True + column: Optional[Column[Any]] = None + def __init__( self, name, @@ -3087,14 +3229,6 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): else: self.data_type = None - @util.memoized_property - def is_callable(self): - return False - - @util.memoized_property - def is_clause_element(self): - return False - @util.preload_module("sqlalchemy.sql.functions") def next_value(self): """Return a :class:`.next_value` function element @@ -3235,6 +3369,9 @@ class Constraint(DialectKWArgs, SchemaItem): __visit_name__ = "constraint" + _creation_order: int + _column_flag: bool + def __init__( self, name=None, @@ -3316,8 +3453,6 @@ class Constraint(DialectKWArgs, SchemaItem): class ColumnCollectionMixin: - - columns = None """A :class:`_expression.ColumnCollection` of :class:`_schema.Column` objects. @@ -3326,8 +3461,17 @@ class ColumnCollectionMixin: """ + columns: ColumnCollection[Column[Any]] + _allow_multiple_tables = False + if TYPE_CHECKING: + + def _set_parent_with_dispatch( + self, parent: SchemaEventTarget, **kw: Any + ) -> None: + ... + def __init__(self, *columns, **kw): _autoattach = kw.pop("_autoattach", True) self._column_flag = kw.pop("_column_flag", False) @@ -3404,14 +3548,16 @@ class ColumnCollectionMixin: ) ) - def _col_expressions(self, table): + def _col_expressions(self, table: Table) -> List[Column[Any]]: return [ table.c[col] if isinstance(col, str) else col for col in self._pending_colargs ] - def _set_parent(self, table, **kw): - for col in self._col_expressions(table): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + if TYPE_CHECKING: + assert isinstance(parent, Table) + for col in self._col_expressions(parent): if col is not None: self.columns.add(col) @@ -3446,7 +3592,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): self, *columns, _autoattach=_autoattach, _column_flag=_column_flag ) - columns = None + columns: DedupeColumnCollection[Column[Any]] """A :class:`_expression.ColumnCollection` representing the set of columns for this constraint. @@ -3568,7 +3714,7 @@ class CheckConstraint(ColumnCollectionConstraint): """ self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext) - columns = [] + columns: List[Column[Any]] = [] visitors.traverse(self.sqltext, {}, {"column": columns.append}) super(CheckConstraint, self).__init__( @@ -3779,17 +3925,17 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): assert table is self.parent self._set_parent_with_dispatch(table) - def _append_element(self, column, fk): + def _append_element(self, column: Column[Any], fk: ForeignKey) -> None: self.columns.add(column) self.elements.append(fk) - columns = None + columns: DedupeColumnCollection[Column[Any]] """A :class:`_expression.ColumnCollection` representing the set of columns for this constraint. """ - elements = None + elements: List[ForeignKey] """A sequence of :class:`_schema.ForeignKey` objects. Each :class:`_schema.ForeignKey` @@ -4271,7 +4417,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): self._validate_dialect_kwargs(kw) - self.expressions = [] + self.expressions: List[ColumnElement[Any]] = [] # will call _set_parent() if table-bound column # objects are present ColumnCollectionMixin.__init__( @@ -4501,11 +4647,13 @@ class MetaData(HasSchemaAttr): ) if info: self.info = info - self._schemas = set() - self._sequences = {} - self._fk_memos = collections.defaultdict(list) + self._schemas: Set[str] = set() + self._sequences: Dict[str, Sequence] = {} + self._fk_memos: Dict[ + Tuple[str, str], List[ForeignKey] + ] = collections.defaultdict(list) - tables: Dict[str, Table] + tables: util.FacadeDict[str, Table] """A dictionary of :class:`_schema.Table` objects keyed to their name or "table key". @@ -4539,7 +4687,7 @@ class MetaData(HasSchemaAttr): def _remove_table(self, name, schema): key = _get_table_key(name, schema) - removed = dict.pop(self.tables, key, None) + removed = dict.pop(self.tables, key, None) # type: ignore if removed is not None: for fk in removed.foreign_keys: fk._remove_from_metadata(self) @@ -4634,12 +4782,12 @@ class MetaData(HasSchemaAttr): """ return ddl.sort_tables( - sorted(self.tables.values(), key=lambda t: t.key) + sorted(self.tables.values(), key=lambda t: t.key) # type: ignore ) def reflect( self, - bind: Union["Engine", "Connection"], + bind: Union[Engine, Connection], schema: Optional[str] = None, views: bool = False, only: Optional[_typing_Sequence[str]] = None, @@ -4647,7 +4795,7 @@ class MetaData(HasSchemaAttr): autoload_replace: bool = True, resolve_fks: bool = True, **dialect_kwargs: Any, - ): + ) -> None: r"""Load all available table definitions from the database. Automatically creates ``Table`` entries in this ``MetaData`` for any @@ -4748,12 +4896,14 @@ class MetaData(HasSchemaAttr): if schema is not None: reflect_opts["schema"] = schema - available = util.OrderedSet(insp.get_table_names(schema)) + available: util.OrderedSet[str] = util.OrderedSet( + insp.get_table_names(schema) + ) if views: available.update(insp.get_view_names(schema)) if schema is not None: - available_w_schema = util.OrderedSet( + available_w_schema: util.OrderedSet[str] = util.OrderedSet( ["%s.%s" % (schema, name) for name in available] ) else: @@ -4796,10 +4946,10 @@ class MetaData(HasSchemaAttr): def create_all( self, - bind: Union["Engine", "Connection"], + bind: Union[Engine, Connection, MockConnection], tables: Optional[_typing_Sequence[Table]] = None, checkfirst: bool = True, - ): + ) -> None: """Create all tables stored in this metadata. Conditional by default, will not attempt to recreate tables already @@ -4824,10 +4974,10 @@ class MetaData(HasSchemaAttr): def drop_all( self, - bind: Union["Engine", "Connection"], + bind: Union[Engine, Connection, MockConnection], tables: Optional[_typing_Sequence[Table]] = None, checkfirst: bool = True, - ): + ) -> None: """Drop all tables stored in this metadata. Conditional by default, will not attempt to drop tables not present in diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e143d14761..8665a74db6 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -463,7 +463,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _is_clone_of: Optional[FromClause] - schema = None + schema: Optional[str] = None """Define the 'schema' attribute for this :class:`_expression.FromClause`. This is typically ``None`` for most objects except that of @@ -673,7 +673,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return self._cloned_set.intersection(other._cloned_set) - @property + @util.non_memoized_property def description(self) -> str: """A brief description of this :class:`_expression.FromClause`. @@ -710,7 +710,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.columns @util.memoized_property - def columns(self) -> ColumnCollection: + def columns(self) -> ColumnCollection[Any]: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -796,7 +796,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): # this is awkward. maybe there's a better way if TYPE_CHECKING: - c: ColumnCollection + c: ColumnCollection[Any] else: c = property( attrgetter("columns"), @@ -2399,6 +2399,8 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): _is_table = True + fullname: str + implicit_returning = False """:class:`_expression.TableClause` doesn't support having a primary key or column diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 829c1b72e7..1a6de34b03 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -345,6 +345,12 @@ class Integer(HasExpressionLookup, TypeEngine[int]): __visit_name__ = "integer" + if TYPE_CHECKING: + + @util.ro_memoized_property + def _type_affinity(self) -> Type[Integer]: + ... + def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -1892,8 +1898,8 @@ class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]): operators.truediv: {Numeric: self.__class__}, } - @util.non_memoized_property - def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]: + @util.ro_non_memoized_property + def _type_affinity(self) -> Type[Interval]: return Interval diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 5a0aba694f..9a934a50bc 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -705,7 +705,7 @@ class TypeEngine(Visitable, Generic[_T]): """ return self - @util.memoized_property + @util.ro_memoized_property def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]: """Return a rudimental 'affinity' value expressing the general class of type.""" @@ -719,7 +719,7 @@ class TypeEngine(Visitable, Generic[_T]): else: return self.__class__ - @util.memoized_property + @util.ro_memoized_property def _generic_type_affinity( self, ) -> Type[TypeEngine[_T]]: @@ -1694,7 +1694,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): tt.impl = tt.impl_instance = typedesc return tt - @util.non_memoized_property + @util.ro_non_memoized_property def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]: return self.impl_instance._type_affinity diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index b40d22f8ed..7e616cd74f 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -130,6 +130,8 @@ from .langhelpers import ( from .langhelpers import PluginLoader as PluginLoader from .langhelpers import portable_instancemethod as portable_instancemethod from .langhelpers import quoted_token_parser as quoted_token_parser +from .langhelpers import ro_memoized_property as ro_memoized_property +from .langhelpers import ro_non_memoized_property as ro_non_memoized_property from .langhelpers import safe_reraise as safe_reraise from .langhelpers import set_creation_order as set_creation_order from .langhelpers import string_or_unprintable as string_or_unprintable diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 85ae1b65df..2d974b7372 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -135,7 +135,7 @@ def coerce_to_immutabledict(d): EMPTY_DICT: immutabledict[Any, Any] = immutabledict() -class FacadeDict(ImmutableDictBase[Any, Any]): +class FacadeDict(ImmutableDictBase[_KT, _VT]): """A dictionary that is not publicly mutable.""" def __new__(cls, *args): diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 35294715cb..8cf50c724c 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -55,8 +55,8 @@ _T_co = TypeVar("_T_co", covariant=True) _F = TypeVar("_F", bound=Callable[..., Any]) _MP = TypeVar("_MP", bound="memoized_property[Any]") _MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]") -_HP = TypeVar("_HP", bound="hybridproperty") -_HM = TypeVar("_HM", bound="hybridmethod") +_HP = TypeVar("_HP", bound="hybridproperty[Any]") +_HM = TypeVar("_HM", bound="hybridmethod[Any]") if compat.py310: @@ -1234,12 +1234,23 @@ class _memoized_property(generic_fn_descriptor[_T_co]): # superclass has non-memoized, the class hierarchy of the descriptors # would need to be reversed; "class non_memoized(memoized)". so there's no # way to achieve this. +# additional issues, RO properties: +# https://github.com/python/mypy/issues/12440 if TYPE_CHECKING: + + # allow memoized and non-memoized to be freely mixed by having them + # be the same class memoized_property = generic_fn_descriptor non_memoized_property = generic_fn_descriptor + + # for read only situations, mypy only sees @property as read only. + # read only is needed when a subtype specializes the return type + # of a property, meaning assignment needs to be disallowed + ro_memoized_property = property + ro_non_memoized_property = property else: - memoized_property = _memoized_property - non_memoized_property = _non_memoized_property + memoized_property = ro_memoized_property = _memoized_property + non_memoized_property = ro_non_memoized_property = _non_memoized_property def memoized_instancemethod(fn: _F) -> _F: @@ -1515,7 +1526,9 @@ def duck_type_collection( return default -def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any: +def assert_arg_type( + arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str +) -> Any: if isinstance(arg, argtype): return arg else: @@ -1576,37 +1589,37 @@ class classproperty(property): return self.fget(cls) # type: ignore -class hybridproperty: - def __init__(self, func): +class hybridproperty(Generic[_T]): + def __init__(self, func: Callable[..., _T]): self.func = func self.clslevel = func - def __get__(self, instance, owner): + def __get__(self, instance: Any, owner: Any) -> _T: if instance is None: clsval = self.clslevel(owner) return clsval else: return self.func(instance) - def classlevel(self, func): + def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]: self.clslevel = func return self -class hybridmethod: +class hybridmethod(Generic[_T]): """Decorate a function as cls- or instance- level.""" - def __init__(self, func): + def __init__(self, func: Callable[..., _T]): self.func = self.__func__ = func self.clslevel = func - def __get__(self, instance, owner): + def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]: if instance is None: - return self.clslevel.__get__(owner, owner.__class__) + return self.clslevel.__get__(owner, owner.__class__) # type:ignore else: - return self.func.__get__(instance, owner) + return self.func.__get__(instance, owner) # type:ignore - def classlevel(self, func): + def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]: self.clslevel = func return self diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 03621a6018..0b9d8c62c7 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -34,8 +34,10 @@ else: if compat.py310: from typing import TypeGuard as TypeGuard + from typing import TypeAlias as TypeAlias else: from typing_extensions import TypeGuard as TypeGuard + from typing_extensions import TypeAlias as TypeAlias if typing.TYPE_CHECKING or compat.py38: from typing import SupportsIndex as SupportsIndex diff --git a/pyproject.toml b/pyproject.toml index 7117c2689f..acbc69537a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,57 +51,73 @@ reportTypedDictNotRequiredAccess = "warning" [tool.mypy] mypy_path = "./lib/" show_error_codes = true -strict = false +strict = true incremental = true -# disabled checking [[tool.mypy.overrides]] -module="sqlalchemy.*" -ignore_errors = true -warn_unused_ignores = false - -strict = true - -# some debate at -# https://github.com/python/mypy/issues/8754. -# implicit_reexport = true -# individual packages or even modules should be listed here -# with strictness-specificity set up. there's no way we are going to get -# the whole library 100% strictly typed, so we have to tune this based on -# the type of module or package we are dealing with - -[[tool.mypy.overrides]] # ad-hoc ignores module = [ "sqlalchemy.engine.reflection", # interim, should be strict + + # TODO for strict: + "sqlalchemy.ext.asyncio.*", + "sqlalchemy.ext.automap", + "sqlalchemy.ext.compiler", + "sqlalchemy.ext.declarative.*", + "sqlalchemy.ext.mutable", + "sqlalchemy.ext.horizontal_shard", + + "sqlalchemy.sql._selectable_constructors", + "sqlalchemy.sql._dml_constructors", + + # TODO for non-strict: + "sqlalchemy.ext.baked", + "sqlalchemy.ext.instrumentation", + "sqlalchemy.ext.indexable", + "sqlalchemy.ext.orderinglist", + "sqlalchemy.ext.serializer", + + "sqlalchemy.sql.selectable", # would be nice as strict + "sqlalchemy.sql.ddl", + "sqlalchemy.sql.functions", # would be nice as strict + "sqlalchemy.sql.lambdas", + "sqlalchemy.sql.dml", # would be nice as strict + "sqlalchemy.sql.util", + + # not yet classified: + "sqlalchemy.orm.*", + "sqlalchemy.dialects.*", + "sqlalchemy.cyextension.*", + "sqlalchemy.future.*", + "sqlalchemy.testing.*", + ] +warn_unused_ignores = false ignore_errors = true # strict checking [[tool.mypy.overrides]] + module = [ - "sqlalchemy.sql.annotation", - "sqlalchemy.sql.cache_key", - "sqlalchemy.sql._elements_constructors", - "sqlalchemy.sql.operators", - "sqlalchemy.sql.type_api", - "sqlalchemy.sql.roles", - "sqlalchemy.sql.visitors", - "sqlalchemy.sql._py_util", + # packages "sqlalchemy.connectors.*", + "sqlalchemy.event.*", + "sqlalchemy.ext.*", + "sqlalchemy.sql.*", "sqlalchemy.engine.*", - "sqlalchemy.ext.hybrid", - "sqlalchemy.ext.associationproxy", "sqlalchemy.pool.*", - "sqlalchemy.event.*", + + # modules "sqlalchemy.events", "sqlalchemy.exc", "sqlalchemy.inspection", "sqlalchemy.schema", "sqlalchemy.types", ] + +warn_unused_ignores = false ignore_errors = false strict = true @@ -109,20 +125,24 @@ strict = true [[tool.mypy.overrides]] module = [ - #"sqlalchemy.sql.*", - "sqlalchemy.sql.sqltypes", - "sqlalchemy.sql.elements", + "sqlalchemy.engine.cursor", + "sqlalchemy.engine.default", + + "sqlalchemy.sql.base", "sqlalchemy.sql.coercions", "sqlalchemy.sql.compiler", - #"sqlalchemy.sql.default_comparator", + "sqlalchemy.sql.crud", + "sqlalchemy.sql.elements", # would be nice as strict "sqlalchemy.sql.naming", + "sqlalchemy.sql.schema", # would be nice as strict + "sqlalchemy.sql.sqltypes", # would be nice as strict "sqlalchemy.sql.traversals", + "sqlalchemy.util.*", - "sqlalchemy.engine.cursor", - "sqlalchemy.engine.default", ] +warn_unused_ignores = false ignore_errors = false # mostly strict without requiring totally untyped things to be diff --git a/test/profiles.txt b/test/profiles.txt index ec1cba7241..074b649f2e 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -1,29 +1,29 @@ # /home/classic/dev/sqlalchemy/test/profiles.txt # This file is written out on a per-environment basis. -# For each test in aaa_profiling, the corresponding function and +# For each test in aaa_profiling, the corresponding function and # environment is located within this file. If it doesn't exist, # the test is skipped. -# If a callcount does exist, it is compared to what we received. +# If a callcount does exist, it is compared to what we received. # assertions are raised if the counts do not match. -# -# To add a new callcount test, apply the function_call_count -# decorator and re-run the tests using the --write-profiles +# +# To add a new callcount test, apply the function_call_count +# decorator and re-run the tests using the --write-profiles # option - this file will be rewritten including the new count. -# +# # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 72 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 70 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 70 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 75 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_select diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index f3260f2721..52c7799695 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -354,7 +354,8 @@ class DefaultObjectTest(fixtures.TestBase): assert_raises_message( sa.exc.ArgumentError, r"SQL expression for WHERE/HAVING role expected, " - r"got (?:Sequence|ColumnDefault|DefaultClause)\('y'.*\)", + r"got (?:Sequence|(?:ScalarElement)ColumnDefault|" + r"DefaultClause)\('y'.*\)", t.select().where, const, ) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 11c3e83b7d..21fc0a6272 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -760,7 +760,10 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): "%s" ", name='someconstraint')" % repr(ck.sqltext), ), - (ColumnDefault(("foo", "bar")), "ColumnDefault(('foo', 'bar'))"), + ( + ColumnDefault(("foo", "bar")), + "ScalarElementColumnDefault(('foo', 'bar'))", + ), ): eq_(repr(const), exp) @@ -916,6 +919,46 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables): a2 = a.to_metadata(m2) assert b2.c.y.references(a2.c.x) + def test_fk_w_no_colname(self): + """test a ForeignKey that refers to table name only. the column + name is assumed to be the same col name on parent table. + + this is a little used feature from long ago that nonetheless is + still in the code. + + The feature was found to be not working but is repaired for + SQLAlchemy 2.0. + + """ + m1 = MetaData() + a = Table("a", m1, Column("x", Integer)) + b = Table("b", m1, Column("x", Integer, ForeignKey("a"))) + assert b.c.x.references(a.c.x) + + m2 = MetaData() + b2 = b.to_metadata(m2) + a2 = a.to_metadata(m2) + assert b2.c.x.references(a2.c.x) + + def test_fk_w_no_colname_name_missing(self): + """test a ForeignKey that refers to table name only. the column + name is assumed to be the same col name on parent table. + + this is a little used feature from long ago that nonetheless is + still in the code. + + """ + m1 = MetaData() + a = Table("a", m1, Column("x", Integer)) + b = Table("b", m1, Column("y", Integer, ForeignKey("a"))) + + with expect_raises_message( + exc.NoReferencedColumnError, + "Could not initialize target column for ForeignKey 'a' on " + "table 'b': table 'a' has no column named 'y'", + ): + assert b.c.y.references(a.c.x) + def test_column_collection_constraint_w_ad_hoc_columns(self): """Test ColumnCollectionConstraint that has columns that aren't part of the Table.