pep484 - SQL internals
author Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Mar 2022 20:39:36 +0000 (16:39 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 24 Mar 2022 20:57:30 +0000 (16:57 -0400)
non-strict checking for mostly internal or semi-internal
code

Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec

34 files changed:
lib/sqlalchemy/__init__.py
lib/sqlalchemy/cyextension/collections.pyx
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/mock.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/ext/mypy/util.py
lib/sqlalchemy/log.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/_selectable_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/events.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/profiles.txt
test/sql/test_defaults.py
test/sql/test_metadata.py

lib/sqlalchemy/__init__.py
index de01a1b46104727ee29aa62d7ab75f976db7cf7d..4a6ae08b2530366934ed5c4422062a79a851971e 100644
@@ -7,6 +7,8 @@
 
 from __future__ import annotations
 
+from typing import Any
+
 from . import util as _util
 from .engine import AdaptedConnection as AdaptedConnection
 from .engine import BaseRow as BaseRow
@@ -191,7 +193,6 @@ from .sql.expression import tuple_ as tuple_
 from .sql.expression import type_coerce as type_coerce
 from .sql.expression import TypeClause as TypeClause
 from .sql.expression import TypeCoerce as TypeCoerce
-from .sql.expression import typing as typing
 from .sql.expression import UnaryExpression as UnaryExpression
 from .sql.expression import union as union
 from .sql.expression import union_all as union_all
@@ -254,7 +255,7 @@ from .types import VARCHAR as VARCHAR
 __version__ = "2.0.0b1"
 
 
-def __go(lcls):
+def __go(lcls: Any) -> None:
     from . import util as _sa_util
 
     _sa_util.preloaded.import_prefix("sqlalchemy")
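
The `import X as X` form used throughout this `lib/sqlalchemy/__init__.py` hunk is the PEP 484 / mypy convention for an explicit re-export: under `--no-implicit-reexport` (part of `--strict`), only names imported as `X as X` or listed in `__all__` count as a module's public interface, and the accidental `from .sql.expression import typing as typing` re-export is dropped here. A minimal, self-contained illustration of the idiom using a stdlib module:

    # explicit re-export: strict mypy lets other modules use `thismodule.OrderedDict`
    from collections import OrderedDict as OrderedDict

    # plain import: under --no-implicit-reexport, `thismodule.defaultdict`
    # would be flagged when accessed from another module
    from collections import defaultdict

    print(OrderedDict(a=1), defaultdict(int)["missing"])
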
lib/sqlalchemy/cyextension/collections.pyx
index c33a6e4a508621c51ddb6e07838aed2352395f7c..fe2cb94ffeb6f4d8edb052fa5619d0df7e734e53 100644
@@ -26,6 +26,10 @@ cdef class OrderedSet(set):
 
     cdef list _list
 
+    @classmethod
+    def __class_getitem__(cls, key):
+        return cls
+
     def __init__(self, d=None):
         set.__init__(self)
         if d is not None:
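
The `__class_getitem__` added to the Cython `OrderedSet` simply returns `cls`, which is enough to make the class subscriptable at runtime, so expressions such as `util.OrderedSet["ColumnClause[Any]"]` (used for `ColumnSet` later in this patch) do not raise `TypeError`; static checkers take the generic parameters from the stubs instead. A rough pure-Python sketch of the same trick:

    class OrderedSet(set):
        # runtime no-op: OrderedSet[X] just returns OrderedSet, which is all
        # that is needed for subscription in base-class lists and annotations
        @classmethod
        def __class_getitem__(cls, key):
            return cls

    class ColumnSet(OrderedSet["ColumnClause"]):   # works at runtime
        pass

    print(ColumnSet([1, 2, 3]))
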
lib/sqlalchemy/engine/base.py
index 061794bded92743f35980d430721020c7dbf8ad7..714ad3c85e41c17969b1da46be0e8e9bcc162720 100644
@@ -73,6 +73,7 @@ if typing.TYPE_CHECKING:
     from ..sql.functions import FunctionElement
     from ..sql.schema import ColumnDefault
     from ..sql.schema import HasSchemaAttr
+    from ..sql.schema import SchemaItem
 
 """Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
 
@@ -2004,7 +2005,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
     def _run_ddl_visitor(
         self,
         visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
-        element: DDLElement,
+        element: SchemaItem,
         **kwargs: Any,
     ) -> None:
         """run a DDL visitor.
@@ -2749,7 +2750,7 @@ class Engine(
     def _run_ddl_visitor(
         self,
         visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
-        element: DDLElement,
+        element: SchemaItem,
         **kwargs: Any,
     ) -> None:
         with self.begin() as conn:
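
`SchemaItem` joins the block of imports guarded by `typing.TYPE_CHECKING`: the name is needed only for the `_run_ddl_visitor()` annotations, and importing it unconditionally from `..sql.schema` would risk an import cycle at runtime. Because these modules use `from __future__ import annotations`, the annotations stay unevaluated strings, so the guarded import is sufficient. A generic, runnable sketch of the pattern (the imported name here is arbitrary, not SQLAlchemy's):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # needed only for annotations; never imported at runtime
        from decimal import Decimal

    def scale(value: Decimal, factor: int) -> Decimal:
        # PEP 563 keeps annotations as strings, so Decimal need not
        # be importable at runtime in this module
        return value * factor
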
lib/sqlalchemy/engine/default.py
index 4a833d2e549fb0804da38ef709362927cc149b48..d605af3efa063326f7b4ec351679945c77485ee2 100644
@@ -54,7 +54,9 @@ from ..sql import expression
 from ..sql._typing import is_tuple_type
 from ..sql.compiler import DDLCompiler
 from ..sql.compiler import SQLCompiler
+from ..sql.elements import ColumnClause
 from ..sql.elements import quoted_name
+from ..sql.schema import default_is_scalar
 
 if typing.TYPE_CHECKING:
     from types import ModuleType
@@ -1164,7 +1166,7 @@ class DefaultExecutionContext(ExecutionContext):
             return ()
 
     @util.memoized_property
-    def returning_cols(self) -> Optional[Sequence[Column[Any]]]:
+    def returning_cols(self) -> Optional[Sequence[ColumnClause[Any]]]:
         if TYPE_CHECKING:
             assert isinstance(self.compiled, SQLCompiler)
         return self.compiled.returning
@@ -1778,15 +1780,11 @@ class DefaultExecutionContext(ExecutionContext):
         # to avoid many calls of get_insert_default()/
         # get_update_default()
         for c in insert_prefetch:
-            if c.default and not c.default.is_sequence and c.default.is_scalar:
-                if TYPE_CHECKING:
-                    assert isinstance(c.default, ColumnDefault)
+            if c.default and default_is_scalar(c.default):
                 scalar_defaults[c] = c.default.arg
 
         for c in update_prefetch:
-            if c.onupdate and c.onupdate.is_scalar:
-                if TYPE_CHECKING:
-                    assert isinstance(c.onupdate, ColumnDefault)
+            if c.onupdate and default_is_scalar(c.onupdate):
                 scalar_defaults[c] = c.onupdate.arg
 
         for param in self.compiled_parameters:
@@ -1817,9 +1815,7 @@ class DefaultExecutionContext(ExecutionContext):
         ) = self.compiled_parameters[0]
 
         for c in compiled.insert_prefetch:
-            if c.default and not c.default.is_sequence and c.default.is_scalar:
-                if TYPE_CHECKING:
-                    assert isinstance(c.default, ColumnDefault)
+            if c.default and default_is_scalar(c.default):
                 val = c.default.arg
             else:
                 val = self.get_insert_default(c)
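
The repeated `is_sequence` / `is_scalar` checks plus `TYPE_CHECKING`-only `isinstance` asserts are collapsed into `default_is_scalar()`, imported from `sql/schema.py` (also touched by this commit). Its definition is not shown in this hunk, but the shape of such a helper is a TypeGuard-style predicate: a boolean check whose True branch narrows the argument for the type checker, so `.arg` can be accessed without an assert. A sketch of the idea, assuming `typing_extensions` for `TypeGuard` (SQLAlchemy routes this through `sqlalchemy.util.typing`); an illustration, not the literal implementation:

    from typing import Any

    from typing_extensions import TypeGuard

    class ColumnDefault:
        is_scalar = True
        is_sequence = False

        def __init__(self, arg: Any) -> None:
            self.arg = arg

    def default_is_scalar(default: Any) -> TypeGuard[ColumnDefault]:
        # True only for plain scalar defaults; in the True branch the checker
        # treats `default` as ColumnDefault, so `.arg` is known to exist
        return bool(
            getattr(default, "is_scalar", False)
            and not getattr(default, "is_sequence", False)
        )

    d: Any = ColumnDefault(5)
    if default_is_scalar(d):
        print(d.arg)   # no isinstance/assert dance required
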
lib/sqlalchemy/engine/mock.py
index a0ba96603977d1f20e288b59c4261c24fdf271e8..c94dd1032ef97aa75e5815e29408091da0b8d1fd 100644
@@ -32,6 +32,7 @@ if typing.TYPE_CHECKING:
     from ..sql.ddl import SchemaDropper
     from ..sql.ddl import SchemaGenerator
     from ..sql.schema import HasSchemaAttr
+    from ..sql.schema import SchemaItem
 
 
 class MockConnection:
@@ -55,7 +56,7 @@ class MockConnection:
     def _run_ddl_visitor(
         self,
         visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
-        element: DDLElement,
+        element: SchemaItem,
         **kwargs: Any,
     ) -> None:
         kwargs["checkfirst"] = False
lib/sqlalchemy/engine/reflection.py
index e1281365e9c196115a9be7bb15bc9c5226d88239..b8ece2b1d2bd53090aba86c64b0c56123cac6db7 100644
@@ -317,7 +317,6 @@ class Inspector(inspection.Inspectable["Inspector"]):
             with an already-given :class:`_schema.MetaData`.
 
         """
-
         with self._operation_context() as conn:
             tnames = self.dialect.get_table_names(
                 conn, schema, info_cache=self.info_cache
lib/sqlalchemy/ext/mypy/util.py
index 943c71b367ddf609086b0c12bc89c02250926a5e..7192675dfa86310970072a4781dbbd533a5e82ac 100644
@@ -88,7 +88,7 @@ class SQLAlchemyAttribute:
         return cls(typ=typ, info=info, **data)
 
 
-def name_is_dunder(name):
+def name_is_dunder(name: str) -> bool:
     return bool(re.match(r"^__.+?__$", name))
 
 
lib/sqlalchemy/log.py
index 9a23d89d3f71be1cc9149ddbd5c5b561aa4e2c19..2feade04eef8136f87fbfb6f681e18e1c224b92f 100644
@@ -225,7 +225,7 @@ def instance_logger(
     else:
         name = _qual_logger_name_for_cls(instance.__class__)
 
-    instance._echo = echoflag
+    instance._echo = echoflag  # type: ignore
 
     logger: Union[logging.Logger, InstanceLogger]
 
@@ -239,7 +239,7 @@ def instance_logger(
         # levels by calling logger._log()
         logger = InstanceLogger(echoflag, name)
 
-    instance.logger = logger
+    instance.logger = logger  # type: ignore
 
 
 class echo_property:
lib/sqlalchemy/orm/descriptor_props.py
index d4d010cbe686346a560c71d450ed05b211b600f3..dd3931faf61af9c678e18504c8cb8d008697582d 100644
@@ -511,7 +511,8 @@ class Composite(
 
         """
 
-        __hash__ = None
+        # https://github.com/python/mypy/issues/4266
+        __hash__ = None  # type: ignore
 
         @util.memoized_property
         def clauses(self):
lib/sqlalchemy/orm/relationships.py
index b4697912bce9ee5acc8836c34133435e95dd5573..58c7c4efd546984d24a8e1ec138103f00a66747d 100644
@@ -476,7 +476,8 @@ class Relationship(
                 "the set of foreign key values."
             )
 
-        __hash__ = None
+        # https://github.com/python/mypy/issues/4266
+        __hash__ = None  # type: ignore
 
         def __eq__(self, other):
             """Implement the ``==`` operator.
lib/sqlalchemy/sql/__init__.py
index 169ddf3dbb83701101cb9ef99c65ab2462ff320e..2e766f9766e7f7434ca3cb3d07aa9e1793aa874f 100644
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import Any
 
 from .base import Executable as Executable
 from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
@@ -97,7 +98,7 @@ from .expression import within_group as within_group
 from .visitors import ClauseVisitor as ClauseVisitor
 
 
-def __go(lcls):
+def __go(lcls: Any) -> None:
     from .. import util as _sa_util
 
     from . import base
lib/sqlalchemy/sql/_selectable_constructors.py
index 9043aa6d05f5bd74ce99c9d52593adcc30684202..e9acc7e6dc495b02a6bbe906f7b760a7b3ec65ef 100644
@@ -8,7 +8,7 @@
 from __future__ import annotations
 
 from typing import Any
-from typing import Union
+from typing import Optional
 
 from . import coercions
 from . import roles
@@ -23,8 +23,6 @@ from .selectable import Select
 from .selectable import TableClause
 from .selectable import TableSample
 from .selectable import Values
-from ..util.typing import _LiteralStar
-from ..util.typing import Literal
 
 
 def alias(selectable, name=None, flat=False):
@@ -283,9 +281,7 @@ def outerjoin(left, right, onclause=None, full=False):
     return Join(left, right, onclause, isouter=True, full=full)
 
 
-def select(
-    *entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement]
-) -> "Select":
+def select(*entities: _ColumnsClauseElement) -> Select:
     r"""Construct a new :class:`_expression.Select`.
 
 
@@ -326,7 +322,7 @@ def select(
     return Select(*entities)
 
 
-def table(name: str, *columns: ColumnClause, **kw: Any) -> "TableClause":
+def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause:
     """Produce a new :class:`_expression.TableClause`.
 
     The object returned is an instance of
@@ -435,7 +431,11 @@ def union_all(*selects):
     return CompoundSelect._create_union_all(*selects)
 
 
-def values(*columns, name=None, literal_binds=False) -> "Values":
+def values(
+    *columns: ColumnClause[Any],
+    name: Optional[str] = None,
+    literal_binds: bool = False,
+) -> Values:
     r"""Construct a :class:`_expression.Values` construct.
 
     The column expressions and the actual data for
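
With these signatures, `table()` takes `ColumnClause[Any]` positionals and `values()` additionally takes keyword-only `name` and `literal_binds`. For reference, a call these annotations are meant to accept, using ordinary public SQLAlchemy constructs (shown for illustration; not part of the diff):

    from sqlalchemy import Integer, String, column, table, values

    accounts = table("accounts", column("id"), column("name"))

    seed_rows = values(
        column("id", Integer),
        column("name", String),
        name="seed_rows",        # Optional[str]
        literal_binds=False,     # bool
    ).data([(1, "alice"), (2, "bob")])
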
lib/sqlalchemy/sql/_typing.py
index 2be98b88fef98d3e2f2841618f93ab0a5457f584..b50a7bf6a14f8550ae74280ba5e6943ef9b5572a 100644
@@ -9,6 +9,7 @@ from typing import Union
 from . import roles
 from .. import util
 from ..inspection import Inspectable
+from ..util.typing import Literal
 
 if TYPE_CHECKING:
     from .elements import quoted_name
@@ -24,12 +25,13 @@ if TYPE_CHECKING:
 _T = TypeVar("_T", bound=Any)
 
 _ColumnsClauseElement = Union[
+    Literal["*", 1],
     roles.ColumnsClauseRole,
-    Type,
+    Type[Any],
     Inspectable[roles.HasColumnElementClauseElement],
 ]
 _FromClauseElement = Union[
-    roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement]
+    roles.FromClauseRole, Type[Any], Inspectable[roles.HasFromClauseElement]
 ]
 
 _ColumnExpression = Union[
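
Folding `Literal["*", 1]` into `_ColumnsClauseElement` is what lets `select("*")` and `select(1)` keep type-checking after `select()` was reduced to plain `*entities: _ColumnsClauseElement` (replacing the removed `_LiteralStar` spelling). A small sketch of how a Literal member in a union behaves, assuming Python 3.8+ `typing.Literal` (names here are illustrative, not SQLAlchemy's):

    from typing import Literal, Union

    class ColumnExpr:
        pass

    _ColumnsClauseArg = Union[Literal["*", 1], ColumnExpr]

    def select(*entities: _ColumnsClauseArg) -> None:
        pass

    select("*")             # ok: matches Literal["*"]
    select(1)               # ok: matches Literal[1]
    select(ColumnExpr())    # ok: matches ColumnExpr
    # select(2)             # a checker would reject this: 2 is not Literal[1]
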
lib/sqlalchemy/sql/base.py
index 6a6b389de8430d76c123e056ff548cbb64ccf949..8f51359155825a62bf5f3f213e4a51e2d58a3d3c 100644
 
 from __future__ import annotations
 
-import collections.abc as collections_abc
 from enum import Enum
 from functools import reduce
 import itertools
 from itertools import zip_longest
 import operator
 import re
-import typing
 from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
+from typing import Generic
 from typing import Iterable
+from typing import Iterator
 from typing import List
+from typing import Mapping
 from typing import MutableMapping
+from typing import NoReturn
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from . import roles
 from . import visitors
@@ -36,17 +46,26 @@ from .cache_key import MemoizedHasCacheKey  # noqa
 from .traversals import HasCopyInternals  # noqa
 from .visitors import ClauseVisitor
 from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
 from .visitors import InternalTraversal
+from .. import event
 from .. import exc
 from .. import util
 from ..util import HasMemoized as HasMemoized
 from ..util import hybridmethod
 from ..util import typing as compat_typing
+from ..util.typing import Protocol
 from ..util.typing import Self
+from ..util.typing import TypeGuard
 
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+    from . import coercions
+    from . import elements
+    from . import type_api
     from .elements import BindParameter
+    from .elements import ColumnClause
     from .elements import ColumnElement
+    from .elements import SQLCoreOperations
     from ..engine import Connection
     from ..engine import Result
     from ..engine.base import _CompiledCacheType
@@ -58,10 +77,12 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import CacheStats
     from ..engine.interfaces import Compiled
     from ..engine.interfaces import Dialect
+    from ..event import dispatcher
 
-coercions = None
-elements = None
-type_api = None
+if not TYPE_CHECKING:
+    coercions = None  # noqa
+    elements = None  # noqa
+    type_api = None  # noqa
 
 
 class _NoArg(Enum):
@@ -70,13 +91,24 @@ class _NoArg(Enum):
 
 NO_ARG = _NoArg.NO_ARG
 
-# if I use sqlalchemy.util.typing, which has the exact same
-# symbols, mypy reports: "error: _Fn? not callable"
-_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Fn = TypeVar("_Fn", bound=Callable[..., Any])
 
 _AmbiguousTableNameMap = MutableMapping[str, str]
 
 
+class _EntityNamespace(Protocol):
+    def __getattr__(self, key: str) -> SQLCoreOperations[Any]:
+        ...
+
+
+class _HasEntityNamespace(Protocol):
+    entity_namespace: _EntityNamespace
+
+
+def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
+    return hasattr(element, "entity_namespace")
+
+
 class Immutable:
     """mark a ClauseElement as 'immutable' when expressions are cloned."""
 
@@ -107,10 +139,14 @@ class SingletonConstant(Immutable):
     def __new__(cls, *arg, **kw):
         return cls._singleton
 
+    @util.non_memoized_property
+    def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
+        raise NotImplementedError()
+
     @classmethod
     def _create_singleton(cls):
         obj = object.__new__(cls)
-        obj.__init__()
+        obj.__init__()  # type: ignore
 
         # for a long time this was an empty frozenset, meaning
         # a SingletonConstant would never be a "corresponding column" in
@@ -139,12 +175,11 @@ def _select_iterables(elements):
     )
 
 
-_Self = typing.TypeVar("_Self", bound="_GenerativeType")
-_Args = compat_typing.ParamSpec("_Args")
+_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType")
 
 
 class _GenerativeType(compat_typing.Protocol):
-    def _generate(self: "_Self") -> "_Self":
+    def _generate(self: _SelfGenerativeType) -> _SelfGenerativeType:
         ...
 
 
@@ -158,8 +193,8 @@ def _generative(fn: _Fn) -> _Fn:
 
     @util.decorator
     def _generative(
-        fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
-    ) -> _Self:
+        fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
+    ) -> _SelfGenerativeType:
         """Mark a method as generative."""
 
         self = self._generate()
@@ -167,9 +202,9 @@ def _generative(fn: _Fn) -> _Fn:
         assert x is self, "generative methods must return self"
         return self
 
-    decorated = _generative(fn)
-    decorated.non_generative = fn
-    return decorated
+    decorated = _generative(fn)  # type: ignore
+    decorated.non_generative = fn  # type: ignore
+    return decorated  # type: ignore
 
 
 def _exclusive_against(*names, **kw):
@@ -233,7 +268,7 @@ def _cloned_difference(a, b):
     )
 
 
-class _DialectArgView(collections_abc.MutableMapping):
+class _DialectArgView(MutableMapping[str, Any]):
     """A dictionary view of dialect-level arguments in the form
     <dialectname>_<argument_name>.
 
@@ -290,7 +325,7 @@ class _DialectArgView(collections_abc.MutableMapping):
         )
 
 
-class _DialectArgDict(collections_abc.MutableMapping):
+class _DialectArgDict(MutableMapping[str, Any]):
     """A dictionary view of dialect-level arguments for a specific
     dialect.
 
@@ -343,6 +378,8 @@ class DialectKWArgs:
 
     """
 
+    __slots__ = ()
+
     _dialect_kwargs_traverse_internals = [
         ("dialect_options", InternalTraversal.dp_dialect_options)
     ]
@@ -534,7 +571,7 @@ class CompileState:
 
     __slots__ = ("statement", "_ambiguous_table_name_map")
 
-    plugins = {}
+    plugins: Dict[Tuple[str, str], Type[CompileState]] = {}
 
     _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
 
@@ -639,9 +676,9 @@ class InPlaceGenerative(HasMemoized):
 class HasCompileState(Generative):
     """A class that has a :class:`.CompileState` associated with it."""
 
-    _compile_state_plugin = None
+    _compile_state_plugin: Optional[Type[CompileState]] = None
 
-    _attributes = util.immutabledict()
+    _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT
 
     _compile_state_factory = CompileState.create_for_statement
 
@@ -655,6 +692,8 @@ class _MetaOptions(type):
 
     """
 
+    _cache_attrs: Tuple[str, ...]
+
     def __add__(self, other):
         o1 = self()
 
@@ -674,6 +713,8 @@ class Options(metaclass=_MetaOptions):
 
     __slots__ = ()
 
+    _cache_attrs: Tuple[str, ...]
+
     def __init_subclass__(cls) -> None:
         dict_ = cls.__dict__
         cls._cache_attrs = tuple(
@@ -732,13 +773,13 @@ class Options(metaclass=_MetaOptions):
         return self + {name: getattr(self, name) + value}
 
     @hybridmethod
-    def _state_dict(self):
+    def _state_dict_inst(self) -> Mapping[str, Any]:
         return self.__dict__
 
-    _state_dict_const = util.immutabledict()
+    _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT
 
-    @_state_dict.classlevel
-    def _state_dict(cls):
+    @_state_dict_inst.classlevel
+    def _state_dict(cls) -> Mapping[str, Any]:
         return cls._state_dict_const
 
     @classmethod
@@ -825,10 +866,10 @@ class CacheableOptions(Options, HasCacheKey):
     __slots__ = ()
 
     @hybridmethod
-    def _gen_cache_key(self, anon_map, bindparams):
+    def _gen_cache_key_inst(self, anon_map, bindparams):
         return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
 
-    @_gen_cache_key.classlevel
+    @_gen_cache_key_inst.classlevel
     def _gen_cache_key(cls, anon_map, bindparams):
         return (cls, ())
 
@@ -849,11 +890,11 @@ class ExecutableOption(HasCopyInternals):
     def _clone(self, **kw):
         """Create a shallow copy of this ExecutableOption."""
         c = self.__class__.__new__(self.__class__)
-        c.__dict__ = dict(self.__dict__)
+        c.__dict__ = dict(self.__dict__)  # type: ignore
         return c
 
 
-SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable")
+SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
 
 
 class Executable(roles.StatementRole, Generative):
@@ -866,9 +907,12 @@ class Executable(roles.StatementRole, Generative):
     """
 
     supports_execution: bool = True
-    _execution_options: _ImmutableExecuteOptions = util.immutabledict()
-    _with_options = ()
-    _with_context_options = ()
+    _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
+    _with_options: Tuple[ExecutableOption, ...] = ()
+    _with_context_options: Tuple[
+        Tuple[Callable[[CompileState], None], Any], ...
+    ] = ()
+    _compile_options: Optional[CacheableOptions]
 
     _executable_traverse_internals = [
         ("_with_options", InternalTraversal.dp_executable_options),
@@ -886,7 +930,9 @@ class Executable(roles.StatementRole, Generative):
     is_delete = False
     is_dml = False
 
-    if typing.TYPE_CHECKING:
+    if TYPE_CHECKING:
+
+        __visit_name__: str
 
         def _compile_w_cache(
             self,
@@ -916,11 +962,13 @@ class Executable(roles.StatementRole, Generative):
         raise NotImplementedError()
 
     @property
-    def _effective_plugin_target(self):
+    def _effective_plugin_target(self) -> str:
         return self.__visit_name__
 
     @_generative
-    def options(self: SelfExecutable, *options) -> SelfExecutable:
+    def options(
+        self: SelfExecutable, *options: ExecutableOption
+    ) -> SelfExecutable:
         """Apply options to this statement.
 
         In the general sense, options are any kind of Python object
@@ -957,7 +1005,7 @@ class Executable(roles.StatementRole, Generative):
 
     @_generative
     def _set_compile_options(
-        self: SelfExecutable, compile_options
+        self: SelfExecutable, compile_options: CacheableOptions
     ) -> SelfExecutable:
         """Assign the compile options to a new value.
 
@@ -970,16 +1018,19 @@ class Executable(roles.StatementRole, Generative):
 
     @_generative
     def _update_compile_options(
-        self: SelfExecutable, options
+        self: SelfExecutable, options: CacheableOptions
     ) -> SelfExecutable:
         """update the _compile_options with new keys."""
 
+        assert self._compile_options is not None
         self._compile_options += options
         return self
 
     @_generative
     def _add_context_option(
-        self: SelfExecutable, callable_, cache_args
+        self: SelfExecutable,
+        callable_: Callable[[CompileState], None],
+        cache_args: Any,
     ) -> SelfExecutable:
         """Add a context option to this statement.
 
@@ -995,7 +1046,7 @@ class Executable(roles.StatementRole, Generative):
         return self
 
     @_generative
-    def execution_options(self: SelfExecutable, **kw) -> SelfExecutable:
+    def execution_options(self: SelfExecutable, **kw: Any) -> SelfExecutable:
         """Set non-SQL options for the statement which take effect during
         execution.
 
@@ -1112,7 +1163,7 @@ class Executable(roles.StatementRole, Generative):
         self._execution_options = self._execution_options.union(kw)
         return self
 
-    def get_execution_options(self):
+    def get_execution_options(self) -> _ExecuteOptions:
         """Get the non-SQL options which will take effect during execution.
 
         .. versionadded:: 1.3
@@ -1124,7 +1175,7 @@ class Executable(roles.StatementRole, Generative):
         return self._execution_options
 
 
-class SchemaEventTarget:
+class SchemaEventTarget(event.EventTarget):
     """Base class for elements that are the targets of :class:`.DDLEvents`
     events.
 
@@ -1132,6 +1183,8 @@ class SchemaEventTarget:
 
     """
 
+    dispatch: dispatcher[SchemaEventTarget]
+
     def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
         """Associate with this SchemaEvent's parent object."""
 
@@ -1149,7 +1202,10 @@ class SchemaVisitor(ClauseVisitor):
     __traverse_options__ = {"schema_visitor": True}
 
 
-class ColumnCollection:
+_COL = TypeVar("_COL", bound="ColumnClause[Any]")
+
+
+class ColumnCollection(Generic[_COL]):
     """Collection of :class:`_expression.ColumnElement` instances,
     typically for
     :class:`_sql.FromClause` objects.
@@ -1260,32 +1316,36 @@ class ColumnCollection:
 
     __slots__ = "_collection", "_index", "_colset"
 
-    def __init__(self, columns=None):
+    _collection: List[Tuple[str, _COL]]
+    _index: Dict[Union[str, int], _COL]
+    _colset: Set[_COL]
+
+    def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None):
         object.__setattr__(self, "_colset", set())
         object.__setattr__(self, "_index", {})
         object.__setattr__(self, "_collection", [])
         if columns:
             self._initial_populate(columns)
 
-    def _initial_populate(self, iter_):
+    def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None:
         self._populate_separate_keys(iter_)
 
     @property
-    def _all_columns(self):
+    def _all_columns(self) -> List[_COL]:
         return [col for (k, col) in self._collection]
 
-    def keys(self):
+    def keys(self) -> List[str]:
         """Return a sequence of string key names for all columns in this
         collection."""
         return [k for (k, col) in self._collection]
 
-    def values(self):
+    def values(self) -> List[_COL]:
         """Return a sequence of :class:`_sql.ColumnClause` or
         :class:`_schema.Column` objects for all columns in this
         collection."""
         return [col for (k, col) in self._collection]
 
-    def items(self):
+    def items(self) -> List[Tuple[str, _COL]]:
         """Return a sequence of (key, column) tuples for all columns in this
         collection each consisting of a string key name and a
         :class:`_sql.ColumnClause` or
@@ -1294,17 +1354,17 @@ class ColumnCollection:
 
         return list(self._collection)
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self._collection)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._collection)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_COL]:
         # turn to a list first to maintain over a course of changes
         return iter([col for k, col in self._collection])
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: Union[str, int]) -> _COL:
         try:
             return self._index[key]
         except KeyError as err:
@@ -1313,13 +1373,13 @@ class ColumnCollection:
             else:
                 raise
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> _COL:
         try:
             return self._index[key]
         except KeyError as err:
             raise AttributeError(key) from err
 
-    def __contains__(self, key):
+    def __contains__(self, key: str) -> bool:
         if key not in self._index:
             if not isinstance(key, str):
                 raise exc.ArgumentError(
@@ -1329,7 +1389,7 @@ class ColumnCollection:
         else:
             return True
 
-    def compare(self, other):
+    def compare(self, other: ColumnCollection[Any]) -> bool:
         """Compare this :class:`_expression.ColumnCollection` to another
         based on the names of the keys"""
 
@@ -1339,10 +1399,10 @@ class ColumnCollection:
         else:
             return True
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return self.compare(other)
 
-    def get(self, key, default=None):
+    def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]:
         """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
         based on a string key name from this
         :class:`_expression.ColumnCollection`."""
@@ -1352,39 +1412,40 @@ class ColumnCollection:
         else:
             return default
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "%s(%s)" % (
             self.__class__.__name__,
             ", ".join(str(c) for c in self),
         )
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any) -> NoReturn:
         raise NotImplementedError()
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: str) -> NoReturn:
         raise NotImplementedError()
 
-    def __setattr__(self, key, obj):
+    def __setattr__(self, key: str, obj: Any) -> NoReturn:
         raise NotImplementedError()
 
-    def clear(self):
+    def clear(self) -> NoReturn:
         """Dictionary clear() is not implemented for
         :class:`_sql.ColumnCollection`."""
         raise NotImplementedError()
 
-    def remove(self, column):
-        """Dictionary remove() is not implemented for
-        :class:`_sql.ColumnCollection`."""
+    def remove(self, column: Any) -> None:
         raise NotImplementedError()
 
-    def update(self, iter_):
+    def update(self, iter_: Any) -> NoReturn:
         """Dictionary update() is not implemented for
         :class:`_sql.ColumnCollection`."""
         raise NotImplementedError()
 
-    __hash__ = None
+    # https://github.com/python/mypy/issues/4266
+    __hash__ = None  # type: ignore
 
-    def _populate_separate_keys(self, iter_):
+    def _populate_separate_keys(
+        self, iter_: Iterable[Tuple[str, _COL]]
+    ) -> None:
         """populate from an iterator of (key, column)"""
         cols = list(iter_)
         self._collection[:] = cols
@@ -1394,7 +1455,7 @@ class ColumnCollection:
         )
         self._index.update({k: col for k, col in reversed(self._collection)})
 
-    def add(self, column, key=None):
+    def add(self, column: _COL, key: Optional[str] = None) -> None:
         """Add a column to this :class:`_sql.ColumnCollection`.
 
         .. note::
@@ -1416,17 +1477,17 @@ class ColumnCollection:
         if key not in self._index:
             self._index[key] = column
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return {"_collection": self._collection, "_index": self._index}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         object.__setattr__(self, "_index", state["_index"])
         object.__setattr__(self, "_collection", state["_collection"])
         object.__setattr__(
             self, "_colset", {col for k, col in self._collection}
         )
 
-    def contains_column(self, col):
+    def contains_column(self, col: _COL) -> bool:
         """Checks if a column object exists in this collection"""
         if col not in self._colset:
             if isinstance(col, str):
@@ -1438,13 +1499,15 @@ class ColumnCollection:
         else:
             return True
 
-    def as_immutable(self):
+    def as_immutable(self) -> ImmutableColumnCollection[_COL]:
         """Return an "immutable" form of this
         :class:`_sql.ColumnCollection`."""
 
         return ImmutableColumnCollection(self)
 
-    def corresponding_column(self, column, require_embedded=False):
+    def corresponding_column(
+        self, column: _COL, require_embedded: bool = False
+    ) -> Optional[_COL]:
         """Given a :class:`_expression.ColumnElement`, return the exported
         :class:`_expression.ColumnElement` object from this
         :class:`_expression.ColumnCollection`
@@ -1497,7 +1560,7 @@ class ColumnCollection:
                 not require_embedded
                 or embedded(expanded_proxy_set, target_set)
             ):
-                if col is None:
+                if col is None or intersect is None:
 
                     # no corresponding column yet, pick this one.
 
@@ -1542,7 +1605,7 @@ class ColumnCollection:
         return col
 
 
-class DedupeColumnCollection(ColumnCollection):
+class DedupeColumnCollection(ColumnCollection[_COL]):
     """A :class:`_expression.ColumnCollection`
     that maintains deduplicating behavior.
 
@@ -1555,7 +1618,7 @@ class DedupeColumnCollection(ColumnCollection):
 
     """
 
-    def add(self, column, key=None):
+    def add(self, column: _COL, key: Optional[str] = None) -> None:
 
         if key is not None and column.key != key:
             raise exc.ArgumentError(
@@ -1589,7 +1652,9 @@ class DedupeColumnCollection(ColumnCollection):
             self._index[l] = column
             self._index[key] = column
 
-    def _populate_separate_keys(self, iter_):
+    def _populate_separate_keys(
+        self, iter_: Iterable[Tuple[str, _COL]]
+    ) -> None:
         """populate from an iterator of (key, column)"""
         cols = list(iter_)
 
@@ -1614,10 +1679,10 @@ class DedupeColumnCollection(ColumnCollection):
         for col in replace_col:
             self.replace(col)
 
-    def extend(self, iter_):
+    def extend(self, iter_: Iterable[_COL]) -> None:
         self._populate_separate_keys((col.key, col) for col in iter_)
 
-    def remove(self, column):
+    def remove(self, column: _COL) -> None:
         if column not in self._colset:
             raise ValueError(
                 "Can't remove column %r; column is not in this collection"
@@ -1634,7 +1699,7 @@ class DedupeColumnCollection(ColumnCollection):
         # delete higher index
         del self._index[len(self._collection)]
 
-    def replace(self, column):
+    def replace(self, column: _COL) -> None:
         """add the given column to this collection, removing unaliased
         versions of this column  as well as existing columns with the
         same key.
@@ -1687,7 +1752,9 @@ class DedupeColumnCollection(ColumnCollection):
         self._index.update(self._collection)
 
 
-class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+class ImmutableColumnCollection(
+    util.ImmutableContainer, ColumnCollection[_COL]
+):
     __slots__ = ("_parent",)
 
     def __init__(self, collection):
@@ -1701,12 +1768,19 @@ class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
 
     def __setstate__(self, state):
         parent = state["_parent"]
-        self.__init__(parent)
+        self.__init__(parent)  # type: ignore
 
-    add = extend = remove = util.ImmutableContainer._immutable
+    def add(self, column: Any, key: Any = ...) -> Any:
+        self._immutable()
 
+    def extend(self, elements: Any) -> None:
+        self._immutable()
 
-class ColumnSet(util.ordered_column_set):
+    def remove(self, item: Any) -> None:
+        self._immutable()
+
+
+class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
     def contains_column(self, col):
         return col in self
 
@@ -1714,9 +1788,6 @@ class ColumnSet(util.ordered_column_set):
         for col in cols:
             self.add(col)
 
-    def __add__(self, other):
-        return list(self) + list(other)
-
     def __eq__(self, other):
         l = []
         for c in other:
@@ -1729,7 +1800,9 @@ class ColumnSet(util.ordered_column_set):
         return hash(tuple(x for x in self))
 
 
-def _entity_namespace(entity):
+def _entity_namespace(
+    entity: Union[_HasEntityNamespace, ExternallyTraversible]
+) -> _EntityNamespace:
     """Return the nearest .entity_namespace for the given entity.
 
     If not immediately available, does an iterate to find a sub-element
@@ -1737,16 +1810,20 @@ def _entity_namespace(entity):
 
     """
     try:
-        return entity.entity_namespace
+        return cast(_HasEntityNamespace, entity).entity_namespace
     except AttributeError:
-        for elem in visitors.iterate(entity):
-            if hasattr(elem, "entity_namespace"):
+        for elem in visitors.iterate(cast(ExternallyTraversible, entity)):
+            if _is_has_entity_namespace(elem):
                 return elem.entity_namespace
         else:
             raise
 
 
-def _entity_namespace_key(entity, key, default=NO_ARG):
+def _entity_namespace_key(
+    entity: Union[_HasEntityNamespace, ExternallyTraversible],
+    key: str,
+    default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG,
+) -> SQLCoreOperations[Any]:
     """Return an entry from an entity_namespace.
 
 
@@ -1760,7 +1837,7 @@ def _entity_namespace_key(entity, key, default=NO_ARG):
         if default is not NO_ARG:
             return getattr(ns, key, default)
         else:
-            return getattr(ns, key)
+            return getattr(ns, key)  # type: ignore
     except AttributeError as err:
         raise exc.InvalidRequestError(
             'Entity namespace for "%s" has no property "%s"' % (entity, key)
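
Among the many annotations added to `sql/base.py`, the `_HasEntityNamespace` Protocol plus the `_is_has_entity_namespace()` TypeGuard formalize the old `hasattr(elem, "entity_namespace")` probe in `_entity_namespace()`: the runtime behavior is unchanged, but inside the `if` branch the checker now knows the element exposes `.entity_namespace`. Likewise, `ColumnCollection` becoming `Generic[_COL]` lets collections declare what kind of column they contain. A standalone sketch of the Protocol-plus-TypeGuard half, assuming `typing_extensions` (SQLAlchemy uses its own `util.typing` shims):

    from typing import Any, Dict

    from typing_extensions import Protocol, TypeGuard

    class _HasEntityNamespace(Protocol):
        # structural requirement only: any object with this attribute matches
        entity_namespace: Dict[str, str]

    def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
        # identical runtime test to hasattr(); the TypeGuard return type is
        # what lets the checker narrow `element` in the True branch
        return hasattr(element, "entity_namespace")

    class MappedThing:
        entity_namespace = {"id": "Column('id')"}

    def lookup(element: Any, key: str) -> str:
        if _is_has_entity_namespace(element):
            return element.entity_namespace[key]   # attribute access is typed
        raise AttributeError(key)

    print(lookup(MappedThing(), "id"))
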
lib/sqlalchemy/sql/compiler.py
index f8019b9c64e242f46509503fcdfdddc733b8f08d..5ba52ae51c03695a80af7670585a69073b1cb157 100644
@@ -71,6 +71,7 @@ from .schema import Column
 from .sqltypes import TupleType
 from .type_api import TypeEngine
 from .visitors import prefix_anon_map
+from .visitors import Visitable
 from .. import exc
 from .. import util
 from ..util.typing import Literal
@@ -614,10 +615,10 @@ class Compiled:
 
         raise NotImplementedError()
 
-    def process(self, obj, **kwargs):
+    def process(self, obj: Visitable, **kwargs: Any) -> str:
         return obj._compiler_dispatch(self, **kwargs)
 
-    def __str__(self):
+    def __str__(self) -> str:
         """Return the string text of the generated SQL or DDL."""
 
         return self.string or ""
@@ -723,7 +724,7 @@ class SQLCompiler(Compiled):
     """list of columns for which onupdate default values should be evaluated
     before an UPDATE takes place"""
 
-    returning: Optional[List[Column[Any]]]
+    returning: Optional[List[ColumnClause[Any]]]
     """list of columns that will be delivered to cursor.description or
     dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
 
@@ -1485,15 +1486,12 @@ class SQLCompiler(Compiled):
             self._result_columns
         )
 
-    _key_getters_for_crud_column: Tuple[
-        Callable[[Union[str, Column[Any]]], str],
-        Callable[[Column[Any]], str],
-        Callable[[Column[Any]], str],
-    ]
+    # assigned by crud.py for insert/update statements
+    _get_bind_name_for_col: _BindNameForColProtocol
 
     @util.memoized_property
     def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
-        getter = self._key_getters_for_crud_column[2]
+        getter = self._get_bind_name_for_col
         if self.escaped_bind_names:
 
             def _get(obj):
@@ -4098,7 +4096,9 @@ class SQLCompiler(Compiled):
     def for_update_clause(self, select, **kw):
         return " FOR UPDATE"
 
-    def returning_clause(self, stmt, returning_cols):
+    def returning_clause(
+        self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]]
+    ) -> str:
         raise exc.CompileError(
             "RETURNING is not supported by this "
             "dialect's statement compiler."
@@ -4243,12 +4243,13 @@ class SQLCompiler(Compiled):
             }
         )
 
-        crud_params = crud._get_crud_params(
+        crud_params_struct = crud._get_crud_params(
             self, insert_stmt, compile_state, **kw
         )
+        crud_params_single = crud_params_struct.single_params
 
         if (
-            not crud_params
+            not crud_params_single
             and not self.dialect.supports_default_values
             and not self.dialect.supports_default_metavalue
             and not self.dialect.supports_empty_insert
@@ -4266,9 +4267,9 @@ class SQLCompiler(Compiled):
                     "version settings does not support "
                     "in-place multirow inserts." % self.dialect.name
                 )
-            crud_params_single = crud_params[0]
+            crud_params_single = crud_params_struct.single_params
         else:
-            crud_params_single = crud_params
+            crud_params_single = crud_params_struct.single_params
 
         preparer = self.preparer
         supports_default_values = self.dialect.supports_default_values
@@ -4293,7 +4294,7 @@ class SQLCompiler(Compiled):
 
         if crud_params_single or not supports_default_values:
             text += " (%s)" % ", ".join(
-                [expr for c, expr, value in crud_params_single]
+                [expr for _, expr, _ in crud_params_single]
             )
 
         if self.returning or insert_stmt._returning:
@@ -4323,19 +4324,24 @@ class SQLCompiler(Compiled):
                 )
             else:
                 text += " %s" % select_text
-        elif not crud_params and supports_default_values:
+        elif not crud_params_single and supports_default_values:
             text += " DEFAULT VALUES"
         elif compile_state._has_multi_parameters:
             text += " VALUES %s" % (
                 ", ".join(
                     "(%s)"
-                    % (", ".join(value for c, expr, value in crud_param_set))
-                    for crud_param_set in crud_params
+                    % (", ".join(value for _, _, value in crud_param_set))
+                    for crud_param_set in crud_params_struct.all_multi_params
                 )
             )
         else:
             insert_single_values_expr = ", ".join(
-                [value for c, expr, value in crud_params]
+                [
+                    value
+                    for _, _, value in cast(
+                        "List[Tuple[Any, Any, str]]", crud_params_single
+                    )
+                ]
             )
             text += " VALUES (%s)" % insert_single_values_expr
             if toplevel and insert_stmt._post_values_clause is None:
@@ -4443,9 +4449,10 @@ class SQLCompiler(Compiled):
         table_text = self.update_tables_clause(
             update_stmt, update_stmt.table, render_extra_froms, **kw
         )
-        crud_params = crud._get_crud_params(
+        crud_params_struct = crud._get_crud_params(
             self, update_stmt, compile_state, **kw
         )
+        crud_params = crud_params_struct.single_params
 
         if update_stmt._hints:
             dialect_hints, table_text = self._setup_crud_hints(
@@ -4460,7 +4467,12 @@ class SQLCompiler(Compiled):
         text += table_text
 
         text += " SET "
-        text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
+        text += ", ".join(
+            expr + "=" + value
+            for _, expr, value in cast(
+                "List[Tuple[Any, str, str]]", crud_params
+            )
+        )
 
         if self.returning or update_stmt._returning:
             if self.returning_precedes_values:
@@ -5446,6 +5458,11 @@ class _SchemaForObjectCallable(Protocol):
         ...
 
 
+class _BindNameForColProtocol(Protocol):
+    def __call__(self, col: ColumnClause[Any]) -> str:
+        ...
+
+
 class IdentifierPreparer:
 
     """Handle quoting and case-folding of identifiers based on options."""
lib/sqlalchemy/sql/crud.py
index 4292aa9162a4550929bd2e6d84a83f4065a423c4..533a2f6cd98d63dd2fe11fe93bdd9424d21909d5 100644
@@ -13,13 +13,44 @@ from __future__ import annotations
 
 import functools
 import operator
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import NamedTuple
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
 
 from . import coercions
 from . import dml
 from . import elements
 from . import roles
+from .schema import default_is_clause_element
+from .schema import default_is_sequence
 from .. import exc
 from .. import util
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+    from .compiler import _BindNameForColProtocol
+    from .compiler import SQLCompiler
+    from .dml import DMLState
+    from .dml import Insert
+    from .dml import Update
+    from .dml import UpdateDMLState
+    from .dml import ValuesBase
+    from .elements import ClauseElement
+    from .elements import ColumnClause
+    from .elements import ColumnElement
+    from .elements import TextClause
+    from .schema import _SQLExprDefault
+    from .schema import Column
+    from .selectable import TableClause
 
 REQUIRED = util.symbol(
     "REQUIRED",
@@ -36,7 +67,27 @@ values present.
 )
 
 
-def _get_crud_params(compiler, stmt, compile_state, **kw):
+class _CrudParams(NamedTuple):
+    single_params: List[
+        Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+    ]
+    all_multi_params: List[
+        List[
+            Tuple[
+                ColumnClause[Any],
+                str,
+                str,
+            ]
+        ]
+    ]
+
+
+def _get_crud_params(
+    compiler: SQLCompiler,
+    stmt: ValuesBase,
+    compile_state: DMLState,
+    **kw: Any,
+) -> _CrudParams:
     """create a set of tuples representing column/string pairs for use
     in an INSERT or UPDATE statement.
 
@@ -59,24 +110,32 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
         _column_as_key,
         _getattr_col_key,
         _col_bind_name,
-    ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+    ) = _key_getters_for_crud_column(compiler, stmt, compile_state)
 
-    compiler._key_getters_for_crud_column = getters
+    compiler._get_bind_name_for_col = _col_bind_name
 
     # no parameters in the statement, no parameters in the
     # compiled params - return binds for all columns
     if compiler.column_keys is None and compile_state._no_parameters:
-        return [
-            (
-                c,
-                compiler.preparer.format_column(c),
-                _create_bind_param(compiler, c, None, required=True),
-            )
-            for c in stmt.table.columns
-        ]
+        return _CrudParams(
+            [
+                (
+                    c,
+                    compiler.preparer.format_column(c),
+                    _create_bind_param(compiler, c, None, required=True),
+                )
+                for c in stmt.table.columns
+            ],
+            [],
+        )
+
+    stmt_parameter_tuples: Optional[List[Any]]
+    spd: Optional[MutableMapping[str, Any]]
 
     if compile_state._has_multi_parameters:
-        spd = compile_state._multi_parameters[0]
+        mp = compile_state._multi_parameters
+        assert mp is not None
+        spd = mp[0]
         stmt_parameter_tuples = list(spd.items())
     elif compile_state._ordered_values:
         spd = compile_state._dict_parameters
@@ -92,6 +151,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
     if compiler.column_keys is None:
         parameters = {}
     elif stmt_parameter_tuples:
+        assert spd is not None
         parameters = dict(
             (_column_as_key(key), REQUIRED)
             for key in compiler.column_keys
@@ -103,7 +163,9 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
         )
 
     # create a list of column assignment clauses as tuples
-    values = []
+    values: List[
+        Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+    ] = []
 
     if stmt_parameter_tuples is not None:
         _get_stmt_parameter_tuples_params(
@@ -116,11 +178,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
             kw,
         )
 
-    check_columns = {}
+    check_columns: Dict[str, ColumnClause[Any]] = {}
 
     # special logic that only occurs for multi-table UPDATE
     # statements
-    if compile_state.isupdate and compile_state.is_multitable:
+    if dml.isupdate(compile_state) and compile_state.is_multitable:
         _get_update_multitable_params(
             compiler,
             stmt,
@@ -134,6 +196,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
         )
 
     if compile_state.isinsert and stmt._select_names:
+        # is an insert from select, is not a multiparams
+
+        assert not compile_state._has_multi_parameters
+
         _scan_insert_from_select_cols(
             compiler,
             stmt,
@@ -173,14 +239,17 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
             )
 
     if compile_state._has_multi_parameters:
-        values = _extend_values_for_multiparams(
+        # is a multiparams, is not an insert from a select
+        assert not stmt._select_names
+        multi_extended_values = _extend_values_for_multiparams(
             compiler,
             stmt,
             compile_state,
-            values,
-            _column_as_key,
+            cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+            cast("Callable[..., str]", _column_as_key),
             kw,
         )
+        return _CrudParams(values, multi_extended_values)
     elif (
         not values
         and compiler.for_executemany
@@ -198,12 +267,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
             )
         ]
 
-    return values
+    return _CrudParams(values, [])
 
 
+@overload
 def _create_bind_param(
-    compiler, col, value, process=True, required=False, name=None, **kw
-):
+    compiler: SQLCompiler,
+    col: ColumnElement[Any],
+    value: Any,
+    process: Literal[True] = ...,
+    required: bool = False,
+    name: Optional[str] = None,
+    **kw: Any,
+) -> str:
+    ...
+
+
+@overload
+def _create_bind_param(
+    compiler: SQLCompiler,
+    col: ColumnElement[Any],
+    value: Any,
+    **kw: Any,
+) -> str:
+    ...
+
+
+def _create_bind_param(
+    compiler: SQLCompiler,
+    col: ColumnElement[Any],
+    value: Any,
+    process: bool = True,
+    required: bool = False,
+    name: Optional[str] = None,
+    **kw: Any,
+) -> Union[str, elements.BindParameter[Any]]:
     if name is None:
         name = col.key
     bindparam = elements.BindParameter(
@@ -211,8 +309,9 @@ def _create_bind_param(
     )
     bindparam._is_crud = True
     if process:
-        bindparam = bindparam._compiler_dispatch(compiler, **kw)
-    return bindparam
+        return bindparam._compiler_dispatch(compiler, **kw)
+    else:
+        return bindparam
 
 
 def _handle_values_anonymous_param(compiler, col, value, name, **kw):
@@ -253,8 +352,14 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw):
     return value._compiler_dispatch(compiler, **kw)
 
 
-def _key_getters_for_crud_column(compiler, stmt, compile_state):
-    if compile_state.isupdate and compile_state._extra_froms:
+def _key_getters_for_crud_column(
+    compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState
+) -> Tuple[
+    Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]],
+    Callable[[Column[Any]], Union[str, Tuple[str, str]]],
+    _BindNameForColProtocol,
+]:
+    if dml.isupdate(compile_state) and compile_state._extra_froms:
         # when extra tables are present, refer to the columns
         # in those extra tables as table-qualified, including in
         # dictionaries and when rendering bind param names.
@@ -267,30 +372,36 @@ def _key_getters_for_crud_column(compiler, stmt, compile_state):
             coercions.expect_as_key, roles.DMLColumnRole
         )
 
-        def _column_as_key(key):
+        def _column_as_key(
+            key: Union[ColumnClause[Any], str]
+        ) -> Union[str, Tuple[str, str]]:
             str_key = c_key_role(key)
-            if hasattr(key, "table") and key.table in _et:
-                return (key.table.name, str_key)
+            if hasattr(key, "table") and key.table in _et:  # type: ignore
+                return (key.table.name, str_key)  # type: ignore
             else:
-                return str_key
+                return str_key  # type: ignore
 
-        def _getattr_col_key(col):
+        def _getattr_col_key(
+            col: ColumnClause[Any],
+        ) -> Union[str, Tuple[str, str]]:
             if col.table in _et:
-                return (col.table.name, col.key)
+                return (col.table.name, col.key)  # type: ignore
             else:
                 return col.key
 
-        def _col_bind_name(col):
+        def _col_bind_name(col: ColumnClause[Any]) -> str:
             if col.table in _et:
+                if TYPE_CHECKING:
+                    assert isinstance(col.table, TableClause)
                 return "%s_%s" % (col.table.name, col.key)
             else:
                 return col.key
 
     else:
-        _column_as_key = functools.partial(
+        _column_as_key = functools.partial(  # type: ignore
             coercions.expect_as_key, roles.DMLColumnRole
         )
-        _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+        _getattr_col_key = _col_bind_name = operator.attrgetter("key")  # type: ignore  # noqa E501
 
     return _column_as_key, _getattr_col_key, _col_bind_name
 
@@ -321,7 +432,7 @@ def _scan_insert_from_select_cols(
 
     compiler.stack[-1]["insert_from_select"] = stmt.select
 
-    add_select_cols = []
+    add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
     if stmt.include_insert_from_select_defaults:
         col_set = set(cols)
         for col in stmt.table.columns:
@@ -707,16 +818,22 @@ def _append_param_insert_hasdefault(
         )
 
 
-def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(
+    compiler: SQLCompiler,
+    stmt: ValuesBase,
+    c: ColumnClause[Any],
+    values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
+    kw: Dict[str, Any],
+) -> None:
 
-    if c.default.is_sequence:
+    if default_is_sequence(c.default):
         if compiler.dialect.supports_sequences and (
             not c.default.optional or not compiler.dialect.sequences_optional
         ):
             values.append(
                 (c, compiler.preparer.format_column(c), c.default.next_value())
             )
-    elif c.default.is_clause_element:
+    elif default_is_clause_element(c.default):
         values.append(
             (c, compiler.preparer.format_column(c), c.default.arg.self_group())
         )
@@ -777,28 +894,76 @@ def _append_param_update(
         compiler.returning.append(c)
 
 
+@overload
 def _create_insert_prefetch_bind_param(
-    compiler, c, process=True, name=None, **kw
-):
+    compiler: SQLCompiler,
+    c: ColumnElement[Any],
+    process: Literal[True] = ...,
+    **kw: Any,
+) -> str:
+    ...
+
+
+@overload
+def _create_insert_prefetch_bind_param(
+    compiler: SQLCompiler,
+    c: ColumnElement[Any],
+    process: Literal[False],
+    **kw: Any,
+) -> elements.BindParameter[Any]:
+    ...
+
+
+def _create_insert_prefetch_bind_param(
+    compiler: SQLCompiler,
+    c: ColumnElement[Any],
+    process: bool = True,
+    name: Optional[str] = None,
+    **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
 
     param = _create_bind_param(
         compiler, c, None, process=process, name=name, **kw
     )
-    compiler.insert_prefetch.append(c)
+    compiler.insert_prefetch.append(c)  # type: ignore
     return param
 
 
+@overload
 def _create_update_prefetch_bind_param(
-    compiler, c, process=True, name=None, **kw
-):
+    compiler: SQLCompiler,
+    c: ColumnElement[Any],
+    process: Literal[True] = ...,
+    **kw: Any,
+) -> str:
+    ...
+
+
+@overload
+def _create_update_prefetch_bind_param(
+    compiler: SQLCompiler,
+    c: ColumnElement[Any],
+    process: Literal[False],
+    **kw: Any,
+) -> elements.BindParameter[Any]:
+    ...
+
+
+def _create_update_prefetch_bind_param(
+    compiler: SQLCompiler,
+    c: ColumnElement[Any],
+    process: bool = True,
+    name: Optional[str] = None,
+    **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
     param = _create_bind_param(
         compiler, c, None, process=process, name=name, **kw
     )
-    compiler.update_prefetch.append(c)
+    compiler.update_prefetch.append(c)  # type: ignore
     return param
 
 
-class _multiparam_column(elements.ColumnElement):
+class _multiparam_column(elements.ColumnElement[Any]):
     _is_multiparam_column = True
 
     def __init__(self, original, index):
@@ -822,14 +987,20 @@ class _multiparam_column(elements.ColumnElement):
         )
 
 
-def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
+def _process_multiparam_default_bind(
+    compiler: SQLCompiler,
+    stmt: ValuesBase,
+    c: ColumnClause[Any],
+    index: int,
+    kw: Dict[str, Any],
+) -> str:
     if not c.default:
         raise exc.CompileError(
             "INSERT value for column %s is explicitly rendered as a bound"
             "parameter in the VALUES clause; "
             "a Python-side value or SQL expression is required" % c
         )
-    elif c.default.is_clause_element:
+    elif default_is_clause_element(c.default):
         return compiler.process(c.default.arg.self_group(), **kw)
     elif c.default.is_sequence:
         # these conditions would have been established
@@ -844,9 +1015,13 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
     else:
         col = _multiparam_column(c, index)
         if isinstance(stmt, dml.Insert):
-            return _create_insert_prefetch_bind_param(compiler, col, **kw)
+            return _create_insert_prefetch_bind_param(
+                compiler, col, process=True, **kw
+            )
         else:
-            return _create_update_prefetch_bind_param(compiler, col, **kw)
+            return _create_update_prefetch_bind_param(
+                compiler, col, process=True, **kw
+            )
 
 
 def _get_update_multitable_params(
@@ -926,18 +1101,26 @@ def _get_update_multitable_params(
 
 
 def _extend_values_for_multiparams(
-    compiler,
-    stmt,
-    compile_state,
-    values,
-    _column_as_key,
-    kw,
-):
-    values_0 = values
-    values = [values]
-
-    for i, row in enumerate(compile_state._multi_parameters[1:]):
-        extension = []
+    compiler: SQLCompiler,
+    stmt: ValuesBase,
+    compile_state: DMLState,
+    initial_values: List[Tuple[ColumnClause[Any], str, str]],
+    _column_as_key: Callable[..., str],
+    kw: Dict[str, Any],
+) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+    values_0 = initial_values
+    values = [initial_values]
+
+    mp = compile_state._multi_parameters
+    assert mp is not None
+    for i, row in enumerate(mp[1:]):
+        extension: List[
+            Tuple[
+                ColumnClause[Any],
+                str,
+                str,
+            ]
+        ] = []
 
         row = {_column_as_key(key): v for key, v in row.items()}
 
index 91bb0a5c58bdd868653c9e023996e1558f95822f..944a0a5ce6270479ab952be051c04c1efdfb40ce 100644 (file)
@@ -26,6 +26,7 @@ from . import roles
 from . import type_api
 from .elements import and_
 from .elements import BinaryExpression
+from .elements import ClauseElement
 from .elements import ClauseList
 from .elements import CollationClause
 from .elements import CollectionAggregate
@@ -43,7 +44,7 @@ _T = typing.TypeVar("_T", bound=Any)
 if typing.TYPE_CHECKING:
     from .elements import ColumnElement
     from .operators import custom_op
-    from .sqltypes import TypeEngine
+    from .type_api import TypeEngine
 
 
 def _boolean_compare(
@@ -53,10 +54,10 @@ def _boolean_compare(
     *,
     negate_op: Optional[OperatorType] = None,
     reverse: bool = False,
-    _python_is_types=(util.NoneType, bool),
-    _any_all_expr=False,
+    _python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
+    _any_all_expr: bool = False,
     result_type: Optional[
-        Union[Type["TypeEngine[bool]"], "TypeEngine[bool]"]
+        Union[Type[TypeEngine[bool]], TypeEngine[bool]]
     ] = None,
     **kwargs: Any,
 ) -> BinaryExpression[bool]:
@@ -165,7 +166,7 @@ def _custom_op_operate(
 def _binary_operate(
     expr: ColumnElement[Any],
     op: OperatorType,
-    obj: roles.BinaryElementRole,
+    obj: roles.BinaryElementRole[Any],
     *,
     reverse: bool = False,
     result_type: Optional[
@@ -192,7 +193,7 @@ def _binary_operate(
 
 
 def _conjunction_operate(
-    expr: ColumnElement[Any], op: OperatorType, other, **kw
+    expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
 ) -> ColumnElement[Any]:
     if op is operators.and_:
         return and_(expr, other)
@@ -203,7 +204,10 @@ def _conjunction_operate(
 
 
 def _scalar(
-    expr: ColumnElement[Any], op: OperatorType, fn, **kw
+    expr: ColumnElement[Any],
+    op: OperatorType,
+    fn: Callable[[ColumnElement[Any]], ColumnElement[Any]],
+    **kw: Any,
 ) -> ColumnElement[Any]:
     return fn(expr)
 
@@ -211,9 +215,9 @@ def _scalar(
 def _in_impl(
     expr: ColumnElement[Any],
     op: OperatorType,
-    seq_or_selectable,
+    seq_or_selectable: ClauseElement,
     negate_op: OperatorType,
-    **kw,
+    **kw: Any,
 ) -> ColumnElement[Any]:
     seq_or_selectable = coercions.expect(
         roles.InElementRole, seq_or_selectable, expr=expr, operator=op
@@ -227,7 +231,7 @@ def _in_impl(
 
 
 def _getitem_impl(
-    expr: ColumnElement[Any], op: OperatorType, other, **kw
+    expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
 ) -> ColumnElement[Any]:
     if isinstance(expr.type, type_api.INDEXABLE):
         other = coercions.expect(
@@ -239,7 +243,7 @@ def _getitem_impl(
 
 
 def _unsupported_impl(
-    expr: ColumnElement[Any], op: OperatorType, *arg, **kw
+    expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any
 ) -> NoReturn:
     raise NotImplementedError(
         "Operator '%s' is not supported on " "this expression" % op.__name__
@@ -247,7 +251,7 @@ def _unsupported_impl(
 
 
 def _inv_impl(
-    expr: ColumnElement[Any], op: OperatorType, **kw
+    expr: ColumnElement[Any], op: OperatorType, **kw: Any
 ) -> ColumnElement[Any]:
     """See :meth:`.ColumnOperators.__inv__`."""
 
@@ -260,14 +264,14 @@ def _inv_impl(
 
 
 def _neg_impl(
-    expr: ColumnElement[Any], op: OperatorType, **kw
+    expr: ColumnElement[Any], op: OperatorType, **kw: Any
 ) -> ColumnElement[Any]:
     """See :meth:`.ColumnOperators.__neg__`."""
     return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
 
 
 def _match_impl(
-    expr: ColumnElement[Any], op: OperatorType, other, **kw
+    expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
 ) -> ColumnElement[Any]:
     """See :meth:`.ColumnOperators.match`."""
 
@@ -289,7 +293,7 @@ def _match_impl(
 
 
 def _distinct_impl(
-    expr: ColumnElement[Any], op: OperatorType, **kw
+    expr: ColumnElement[Any], op: OperatorType, **kw: Any
 ) -> ColumnElement[Any]:
     """See :meth:`.ColumnOperators.distinct`."""
     return UnaryExpression(
@@ -298,7 +302,11 @@ def _distinct_impl(
 
 
 def _between_impl(
-    expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw
+    expr: ColumnElement[Any],
+    op: OperatorType,
+    cleft: Any,
+    cright: Any,
+    **kw: Any,
 ) -> ColumnElement[Any]:
     """See :meth:`.ColumnOperators.between`."""
     return BinaryExpression(
@@ -329,26 +337,32 @@ def _between_impl(
 
 
 def _collate_impl(
-    expr: ColumnElement[Any], op: OperatorType, collation, **kw
-) -> ColumnElement[Any]:
+    expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any
+) -> ColumnElement[str]:
     return CollationClause._create_collation_expression(expr, collation)
 
 
 def _regexp_match_impl(
-    expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw
+    expr: ColumnElement[str],
+    op: OperatorType,
+    pattern: Any,
+    flags: Optional[str],
+    **kw: Any,
 ) -> ColumnElement[Any]:
     if flags is not None:
-        flags = coercions.expect(
+        flags_expr = coercions.expect(
             roles.BinaryElementRole,
             flags,
             expr=expr,
             operator=operators.regexp_replace_op,
         )
+    else:
+        flags_expr = None
     return _boolean_compare(
         expr,
         op,
         pattern,
-        flags=flags,
+        flags=flags_expr,
         negate_op=operators.not_regexp_match_op
         if op is operators.regexp_match_op
         else operators.regexp_match_op,
@@ -359,10 +373,10 @@ def _regexp_match_impl(
 def _regexp_replace_impl(
     expr: ColumnElement[Any],
     op: OperatorType,
-    pattern,
-    replacement,
-    flags,
-    **kw,
+    pattern: Any,
+    replacement: Any,
+    flags: Optional[str],
+    **kw: Any,
 ) -> ColumnElement[Any]:
     replacement = coercions.expect(
         roles.BinaryElementRole,
@@ -371,21 +385,29 @@ def _regexp_replace_impl(
         operator=operators.regexp_replace_op,
     )
     if flags is not None:
-        flags = coercions.expect(
+        flags_expr = coercions.expect(
             roles.BinaryElementRole,
             flags,
             expr=expr,
             operator=operators.regexp_replace_op,
         )
+    else:
+        flags_expr = None
     return _binary_operate(
-        expr, op, pattern, replacement=replacement, flags=flags, **kw
+        expr, op, pattern, replacement=replacement, flags=flags_expr, **kw
     )
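
Introducing ``flags_expr`` instead of re-assigning ``flags`` avoids binding an
expression object to a parameter annotated ``Optional[str]``, which mypy
rejects; roughly, with an illustrative function that is not part of the
library:

from typing import Optional, Tuple


def split_flags(flags: Optional[str]) -> Optional[Tuple[str, ...]]:
    # assigning the tuple back to ``flags`` (declared Optional[str]) would be
    # an "incompatible types in assignment" error, so a separately named
    # variable carries the new type
    if flags is not None:
        flags_seq: Optional[Tuple[str, ...]] = tuple(flags)
    else:
        flags_seq = None
    return flags_seq
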
 
 
 # a mapping of operators with the method they use, along with
 # additional keyword arguments to be passed
 operator_lookup: Dict[
-    str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict]
+    str,
+    Tuple[
+        Callable[..., ColumnElement[Any]],
+        util.immutabledict[
+            str, Union[OperatorType, Callable[..., ColumnElement[Any]]]
+        ],
+    ],
 ] = {
     "and_": (_conjunction_operate, util.EMPTY_DICT),
     "or_": (_conjunction_operate, util.EMPTY_DICT),
index 1271c5977c20f9e95b53e3d1100c85c5a6b98ce3..10316dd2bbc82393a515072a9791f0e3a4e3c616 100644 (file)
@@ -12,11 +12,13 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
 from __future__ import annotations
 
 import collections.abc as collections_abc
+import operator
 import typing
 from typing import Any
 from typing import List
 from typing import MutableMapping
 from typing import Optional
+from typing import TYPE_CHECKING
 
 from . import coercions
 from . import roles
@@ -36,10 +38,29 @@ from .elements import Null
 from .selectable import HasCTE
 from .selectable import HasPrefixes
 from .selectable import ReturnsRows
+from .selectable import TableClause
 from .sqltypes import NullType
 from .visitors import InternalTraversal
 from .. import exc
 from .. import util
+from ..util.typing import TypeGuard
+
+
+if TYPE_CHECKING:
+
+    def isupdate(dml) -> TypeGuard[UpdateDMLState]:
+        ...
+
+    def isdelete(dml) -> TypeGuard[DeleteDMLState]:
+        ...
+
+    def isinsert(dml) -> TypeGuard[InsertDMLState]:
+        ...
+
+else:
+    isupdate = operator.attrgetter("isupdate")
+    isdelete = operator.attrgetter("isdelete")
+    isinsert = operator.attrgetter("isinsert")
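
The TYPE_CHECKING branch hands mypy a TypeGuard so the checked object is
narrowed after the call, while at runtime the same names resolve to plain
attrgetter lookups; a minimal sketch of the pattern, using made-up state
classes rather than the DML classes above:

import operator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import TypeGuard  # typing_extensions on Python < 3.10


class BaseState:
    is_special = False


class SpecialState(BaseState):
    is_special = True
    extra = "only available on SpecialState"


if TYPE_CHECKING:

    def is_special_state(state: BaseState) -> "TypeGuard[SpecialState]":
        ...

else:
    is_special_state = operator.attrgetter("is_special")


def describe(state: BaseState) -> str:
    if is_special_state(state):
        # mypy narrows ``state`` to SpecialState here
        return state.extra
    return "plain state"
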
 
 
 class DMLState(CompileState):
@@ -49,6 +70,7 @@ class DMLState(CompileState):
     _ordered_values = None
     _parameter_ordering = None
     _has_multi_parameters = False
+
     isupdate = False
     isdelete = False
     isinsert = False
@@ -237,6 +259,8 @@ class UpdateBase(
     _hints = util.immutabledict()
     named_with_column = False
 
+    table: TableClause
+
     _return_defaults = False
     _return_defaults_columns = None
     _returning = ()
index 48c3c3be665628ee656bed990ac0bc143fc922c9..691eb10ec4acf5c4ec9f41dce919c083ec4f1e1e 100644 (file)
@@ -18,6 +18,7 @@ import itertools
 import operator
 import re
 import typing
+from typing import AbstractSet
 from typing import Any
 from typing import Callable
 from typing import cast
@@ -83,6 +84,7 @@ if typing.TYPE_CHECKING:
     from .operators import OperatorType
     from .schema import Column
     from .schema import DefaultGenerator
+    from .schema import FetchedValue
     from .schema import ForeignKey
     from .selectable import FromClause
     from .selectable import NamedFromClause
@@ -290,7 +292,7 @@ class ClauseElement(
 
     """
 
-    @util.memoized_property
+    @util.ro_memoized_property
     def description(self) -> Optional[str]:
         return None
 
@@ -319,7 +321,7 @@ class ClauseElement(
 
     _cache_key_traversal = None
 
-    negation_clause: ClauseElement
+    negation_clause: ColumnElement[bool]
 
     if typing.TYPE_CHECKING:
 
@@ -1153,9 +1155,7 @@ class ColumnElement(
     primary_key: bool = False
     _is_clone_of: Optional[ColumnElement[_T]]
 
-    @util.memoized_property
-    def foreign_keys(self) -> Iterable[ForeignKey]:
-        return []
+    foreign_keys: AbstractSet[ForeignKey] = frozenset()
 
     @util.memoized_property
     def _proxies(self) -> List[ColumnElement[Any]]:
@@ -1494,6 +1494,8 @@ class ColumnElement(
         else:
             key = name
 
+        assert key is not None
+
         co: ColumnClause[_T] = ColumnClause(
             coercions.expect(roles.TruncatedLabelRole, name)
             if name_is_truncatable
@@ -1506,7 +1508,6 @@ class ColumnElement(
         co._proxies = [self]
         if selectable._is_clone_of is not None:
             co._is_clone_of = selectable._is_clone_of.columns.get(key)
-        assert key is not None
         return key, co
 
     def cast(self, type_: TypeEngine[_T]) -> Cast[_T]:
@@ -4050,13 +4051,14 @@ class NamedColumn(ColumnElement[_T]):
     is_literal = False
     table: Optional[FromClause] = None
     name: str
+    key: str
 
     def _compare_name_for_result(self, other):
         return (hasattr(other, "name") and self.name == other.name) or (
             hasattr(other, "_label") and self._label == other._label
         )
 
-    @util.memoized_property
+    @util.ro_memoized_property
     def description(self) -> str:
         return self.name
 
@@ -4125,6 +4127,7 @@ class NamedColumn(ColumnElement[_T]):
             _selectable=selectable,
             is_literal=False,
         )
+
         c._propagate_attrs = selectable._propagate_attrs
         if name is None:
             c.key = self.key
@@ -4192,8 +4195,8 @@ class ColumnClause(
 
     onupdate: Optional[DefaultGenerator] = None
     default: Optional[DefaultGenerator] = None
-    server_default: Optional[DefaultGenerator] = None
-    server_onupdate: Optional[DefaultGenerator] = None
+    server_default: Optional[FetchedValue] = None
+    server_onupdate: Optional[FetchedValue] = None
 
     _is_multiparam_column = False
 
index 1a1fc4c4178c362208e51f11f035f32483e040a4..0d74e2e4c1b62029f690db539d32a978b52fc7e3 100644 (file)
@@ -7,11 +7,23 @@
 
 from __future__ import annotations
 
+from typing import Any
+from typing import TYPE_CHECKING
+
 from .base import SchemaEventTarget
 from .. import event
 
+if TYPE_CHECKING:
+    from .schema import Column
+    from .schema import Constraint
+    from .schema import SchemaItem
+    from .schema import Table
+    from ..engine.base import Connection
+    from ..engine.interfaces import ReflectedColumn
+    from ..engine.reflection import Inspector
+
 
-class DDLEvents(event.Events):
+class DDLEvents(event.Events[SchemaEventTarget]):
     """
     Define event listeners for schema objects,
     that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget`
@@ -93,7 +105,9 @@ class DDLEvents(event.Events):
     _target_class_doc = "SomeSchemaClassOrObject"
     _dispatch_target = SchemaEventTarget
 
-    def before_create(self, target, connection, **kw):
+    def before_create(
+        self, target: SchemaEventTarget, connection: Connection, **kw: Any
+    ) -> None:
         r"""Called before CREATE statements are emitted.
 
         :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -120,7 +134,9 @@ class DDLEvents(event.Events):
 
         """
 
-    def after_create(self, target, connection, **kw):
+    def after_create(
+        self, target: SchemaEventTarget, connection: Connection, **kw: Any
+    ) -> None:
         r"""Called after CREATE statements are emitted.
 
         :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -142,7 +158,9 @@ class DDLEvents(event.Events):
 
         """
 
-    def before_drop(self, target, connection, **kw):
+    def before_drop(
+        self, target: SchemaEventTarget, connection: Connection, **kw: Any
+    ) -> None:
         r"""Called before DROP statements are emitted.
 
         :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -164,7 +182,9 @@ class DDLEvents(event.Events):
 
         """
 
-    def after_drop(self, target, connection, **kw):
+    def after_drop(
+        self, target: SchemaEventTarget, connection: Connection, **kw: Any
+    ) -> None:
         r"""Called after DROP statements are emitted.
 
         :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -186,7 +206,9 @@ class DDLEvents(event.Events):
 
         """
 
-    def before_parent_attach(self, target, parent):
+    def before_parent_attach(
+        self, target: SchemaEventTarget, parent: SchemaItem
+    ) -> None:
         """Called before a :class:`.SchemaItem` is associated with
         a parent :class:`.SchemaItem`.
 
@@ -201,7 +223,9 @@ class DDLEvents(event.Events):
 
         """
 
-    def after_parent_attach(self, target, parent):
+    def after_parent_attach(
+        self, target: SchemaEventTarget, parent: SchemaItem
+    ) -> None:
         """Called after a :class:`.SchemaItem` is associated with
         a parent :class:`.SchemaItem`.
 
@@ -216,13 +240,17 @@ class DDLEvents(event.Events):
 
         """
 
-    def _sa_event_column_added_to_pk_constraint(self, const, col):
+    def _sa_event_column_added_to_pk_constraint(
+        self, const: Constraint, col: Column[Any]
+    ) -> None:
         """internal event hook used for primary key naming convention
         updates.
 
         """
 
-    def column_reflect(self, inspector, table, column_info):
+    def column_reflect(
+        self, inspector: Inspector, table: Table, column_info: ReflectedColumn
+    ) -> None:
         """Called for each unit of 'column info' retrieved when
         a :class:`_schema.Table` is being reflected.
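
With the hooks annotated as above, listeners can be written against the same
signatures; a brief usage sketch (the listener bodies are illustrative only):

from sqlalchemy import Column, Integer, MetaData, Table, event

metadata = MetaData()
accounts = Table("accounts", metadata, Column("id", Integer, primary_key=True))


@event.listens_for(metadata, "before_create")
def log_create(target, connection, **kw):
    # target is the MetaData (or Table) whose CREATE is about to be emitted
    print(f"emitting CREATE for {target}")


@event.listens_for(Table, "column_reflect")
def lowercase_keys(inspector, table, column_info):
    # adjust each reflected column's .key before the Column is constructed
    column_info["key"] = column_info["name"].lower()
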
 
index 36ddbf309b047c4f8c2abb25c0d414c887ac6228..455e74f7b0e3ea13773dd8d57a0b6e9200364d67 100644 (file)
@@ -43,7 +43,6 @@ from ._elements_constructors import text as text
 from ._elements_constructors import true as true
 from ._elements_constructors import tuple_ as tuple_
 from ._elements_constructors import type_coerce as type_coerce
-from ._elements_constructors import typing as typing
 from ._elements_constructors import within_group as within_group
 from ._selectable_constructors import alias as alias
 from ._selectable_constructors import cte as cte
index 4c4f49aa841ab6b4426b9318c4a02a886000ee89..beb73c1b5061613f278fe270daa8a9429fe349fe 100644 (file)
@@ -211,9 +211,11 @@ class StrictFromClauseRole(FromClauseRole):
     __slots__ = ()
     # does not allow text() or select() objects
 
-    c: ColumnCollection
+    c: ColumnCollection[Any]
 
-    @property
+    # this should be "-> str"; however, this works around:
+    # https://github.com/python/mypy/issues/12440
+    @util.ro_non_memoized_property
     def description(self) -> str:
         raise NotImplementedError()
 
index 5cfb55603feea6d66f0d0bc132c595a43082bb06..540b62e8aa469e37db1589f8a625efa59bdbb43a 100644 (file)
@@ -30,16 +30,22 @@ as components in SQL expressions.
 """
 from __future__ import annotations
 
+from abc import ABC
 import collections
+import operator
 import typing
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import List
 from typing import MutableMapping
 from typing import Optional
 from typing import overload
 from typing import Sequence as _typing_Sequence
+from typing import Set
+from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -48,6 +54,7 @@ from . import ddl
 from . import roles
 from . import type_api
 from . import visitors
+from .base import ColumnCollection
 from .base import DedupeColumnCollection
 from .base import DialectKWArgs
 from .base import Executable
@@ -67,12 +74,15 @@ from .. import exc
 from .. import inspection
 from .. import util
 from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypeGuard
 
 if typing.TYPE_CHECKING:
     from .type_api import TypeEngine
     from ..engine import Connection
     from ..engine import Engine
-
+    from ..engine.interfaces import ExecutionContext
+    from ..engine.mock import MockConnection
 _T = TypeVar("_T", bound="Any")
 _ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement]
 _TAB = TypeVar("_TAB", bound="Table")
@@ -102,7 +112,7 @@ NULL_UNSPECIFIED = util.symbol(
 )
 
 
-def _get_table_key(name, schema):
+def _get_table_key(name: str, schema: Optional[str]) -> str:
     if schema is None:
         return name
     else:
@@ -207,7 +217,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
 
     __visit_name__ = "table"
 
-    constraints = None
+    constraints: Set[Constraint]
     """A collection of all :class:`_schema.Constraint` objects associated with
       this :class:`_schema.Table`.
 
@@ -235,7 +245,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
 
     """
 
-    indexes = None
+    indexes: Set[Index]
     """A collection of all :class:`_schema.Index` objects associated with this
       :class:`_schema.Table`.
 
@@ -249,6 +259,14 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
         ("schema", InternalTraversal.dp_string)
     ]
 
+    if TYPE_CHECKING:
+
+        @util.non_memoized_property
+        def columns(self) -> ColumnCollection[Column[Any]]:
+            ...
+
+        c: ColumnCollection[Column[Any]]
+
     def _gen_cache_key(self, anon_map, bindparams):
         if self._annotations:
             return (self,) + self._annotations_cache_key
@@ -736,11 +754,12 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
             )
 
     @property
-    def _sorted_constraints(self):
+    def _sorted_constraints(self) -> List[Constraint]:
         """Return the set of constraints as a list, sorted by creation
         order.
 
         """
+
         return sorted(self.constraints, key=lambda c: c._creation_order)
 
     @property
@@ -801,6 +820,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
         )
         self.info = kwargs.pop("info", self.info)
 
+        exclude_columns: _typing_Sequence[str]
+
         if autoload:
             if not autoload_replace:
                 # don't replace columns already present.
@@ -1074,8 +1095,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
             return metadata.tables[key]
 
         args = []
-        for c in self.columns:
-            args.append(c._copy(schema=schema))
+        for col in self.columns:
+            args.append(col._copy(schema=schema))
         table = Table(
             name,
             metadata,
@@ -1084,28 +1105,30 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
             *args,
             **self.kwargs,
         )
-        for c in self.constraints:
-            if isinstance(c, ForeignKeyConstraint):
-                referred_schema = c._referred_schema
+        for const in self.constraints:
+            if isinstance(const, ForeignKeyConstraint):
+                referred_schema = const._referred_schema
                 if referred_schema_fn:
                     fk_constraint_schema = referred_schema_fn(
-                        self, schema, c, referred_schema
+                        self, schema, const, referred_schema
                     )
                 else:
                     fk_constraint_schema = (
                         schema if referred_schema == self.schema else None
                     )
                 table.append_constraint(
-                    c._copy(schema=fk_constraint_schema, target_table=table)
+                    const._copy(
+                        schema=fk_constraint_schema, target_table=table
+                    )
                 )
-            elif not c._type_bound:
+            elif not const._type_bound:
                 # skip unique constraints that would be generated
                 # by the 'unique' flag on Column
-                if c._column_flag:
+                if const._column_flag:
                     continue
 
                 table.append_constraint(
-                    c._copy(schema=schema, target_table=table)
+                    const._copy(schema=schema, target_table=table)
                 )
         for index in self.indexes:
             # skip indexes that would be generated
@@ -1734,23 +1757,25 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
 
         name = kwargs.pop("name", None)
         type_ = kwargs.pop("type_", None)
-        args = list(args)
-        if args:
-            if isinstance(args[0], str):
+        l_args = list(args)
+        del args
+
+        if l_args:
+            if isinstance(l_args[0], str):
                 if name is not None:
                     raise exc.ArgumentError(
                         "May not pass name positionally and as a keyword."
                     )
-                name = args.pop(0)
-        if args:
-            coltype = args[0]
+                name = l_args.pop(0)
+        if l_args:
+            coltype = l_args[0]
 
             if hasattr(coltype, "_sqla_type"):
                 if type_ is not None:
                     raise exc.ArgumentError(
                         "May not pass type_ positionally and as a keyword."
                     )
-                type_ = args.pop(0)
+                type_ = l_args.pop(0)
 
         if name is not None:
             name = quoted_name(name, kwargs.pop("quote", None))
@@ -1772,7 +1797,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         else:
             self.nullable = not primary_key
 
-        self.default = kwargs.pop("default", None)
+        default = kwargs.pop("default", None)
+        onupdate = kwargs.pop("onupdate", None)
+
         self.server_default = kwargs.pop("server_default", None)
         self.server_onupdate = kwargs.pop("server_onupdate", None)
 
@@ -1784,7 +1811,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
 
         self.system = kwargs.pop("system", False)
         self.doc = kwargs.pop("doc", None)
-        self.onupdate = kwargs.pop("onupdate", None)
         self.autoincrement = kwargs.pop("autoincrement", "auto")
         self.constraints = set()
         self.foreign_keys = set()
@@ -1803,32 +1829,38 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
                 if isinstance(impl, SchemaEventTarget):
                     impl._set_parent_with_dispatch(self)
 
-        if self.default is not None:
-            if isinstance(self.default, (ColumnDefault, Sequence)):
-                args.append(self.default)
-            else:
-                args.append(ColumnDefault(self.default))
+        if default is not None:
+            if not isinstance(default, (ColumnDefault, Sequence)):
+                default = ColumnDefault(default)
+
+            self.default = default
+            l_args.append(default)
+        else:
+            self.default = None
+
+        if onupdate is not None:
+            if not isinstance(onupdate, (ColumnDefault, Sequence)):
+                onupdate = ColumnDefault(onupdate, for_update=True)
+
+            self.onupdate = onupdate
+            l_args.append(onupdate)
+        else:
+            self.onupdate = None
 
         if self.server_default is not None:
             if isinstance(self.server_default, FetchedValue):
-                args.append(self.server_default._as_for_update(False))
+                l_args.append(self.server_default._as_for_update(False))
             else:
-                args.append(DefaultClause(self.server_default))
-
-        if self.onupdate is not None:
-            if isinstance(self.onupdate, (ColumnDefault, Sequence)):
-                args.append(self.onupdate)
-            else:
-                args.append(ColumnDefault(self.onupdate, for_update=True))
+                l_args.append(DefaultClause(self.server_default))
 
         if self.server_onupdate is not None:
             if isinstance(self.server_onupdate, FetchedValue):
-                args.append(self.server_onupdate._as_for_update(True))
+                l_args.append(self.server_onupdate._as_for_update(True))
             else:
-                args.append(
+                l_args.append(
                     DefaultClause(self.server_onupdate, for_update=True)
                 )
-        self._init_items(*args)
+        self._init_items(*l_args)
 
         util.set_creation_order(self)
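
The reworked constructor wraps plain ``default`` / ``onupdate`` values in the
appropriate ColumnDefault subclass before appending them to the init items; in
user code that corresponds to the usual idiom (table and column names below
are illustrative):

import datetime

from sqlalchemy import Column, DateTime, Integer, MetaData, Table

metadata = MetaData()
log_table = Table(
    "log",
    metadata,
    Column("id", Integer, primary_key=True),
    # both callables end up wrapped as CallableColumnDefault instances
    Column(
        "updated_at",
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    ),
)
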
 
@@ -1837,7 +1869,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
 
         self._extra_kwargs(**kwargs)
 
-    foreign_keys = None
+    table: Table
+
+    constraints: Set[Constraint]
+
+    foreign_keys: Set[ForeignKey]
     """A collection of all :class:`_schema.ForeignKey` marker objects
        associated with this :class:`_schema.Column`.
 
@@ -1850,7 +1886,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
 
     """
 
-    index = None
+    index: bool
     """The value of the :paramref:`_schema.Column.index` parameter.
 
        Does not indicate if this :class:`_schema.Column` is actually indexed
@@ -1861,7 +1897,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
            :attr:`_schema.Table.indexes`
     """
 
-    unique = None
+    unique: bool
     """The value of the :paramref:`_schema.Column.unique` parameter.
 
        Does not indicate if this :class:`_schema.Column` is actually subject to
@@ -2074,8 +2110,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         server_default = self.server_default
         server_onupdate = self.server_onupdate
         if isinstance(server_default, (Computed, Identity)):
+            args.append(server_default._copy(**kw))
             server_default = server_onupdate = None
-            args.append(self.server_default._copy(**kw))
 
         type_ = self.type
         if isinstance(type_, SchemaEventTarget):
@@ -2203,9 +2239,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
 
     __visit_name__ = "foreign_key"
 
+    parent: Column[Any]
+
     def __init__(
         self,
-        column: Union[str, Column, SQLCoreOperations],
+        column: Union[str, Column[Any], SQLCoreOperations[Any]],
         _constraint: Optional["ForeignKeyConstraint"] = None,
         use_alter: bool = False,
         name: Optional[str] = None,
@@ -2296,7 +2334,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
             self._table_column = self._colspec
 
             if not isinstance(
-                self._table_column.table, (util.NoneType, TableClause)
+                self._table_column.table, (type(None), TableClause)
             ):
                 raise exc.ArgumentError(
                     "ForeignKey received Column not bound "
@@ -2309,7 +2347,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         # object passes itself in when creating ForeignKey
         # markers.
         self.constraint = _constraint
-        self.parent = None
+
+        # .parent is not Optional under normal use
+        self.parent = None  # type: ignore
+
         self.use_alter = use_alter
         self.name = name
         self.onupdate = onupdate
@@ -2501,19 +2542,18 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         return parenttable, tablekey, colname
 
     def _link_to_col_by_colstring(self, parenttable, table, colname):
-        if not hasattr(self.constraint, "_referred_table"):
-            self.constraint._referred_table = table
-        else:
-            assert self.constraint._referred_table is table
-
         _column = None
         if colname is None:
             # colname is None in the case that ForeignKey argument
             # was specified as table name only, in which case we
             # match the column name to the same column on the
             # parent.
-            key = self.parent
-            _column = table.c.get(self.parent.key, None)
+            # this use case wasn't working in later 1.x series
+            # as it had no test coverage; fixed in 2.0
+            parent = self.parent
+            assert parent is not None
+            key = parent.key
+            _column = table.c.get(key, None)
         elif self.link_to_name:
             key = colname
             for c in table.c:
@@ -2533,10 +2573,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
                 key,
             )
 
-        self._set_target_column(_column)
+        return _column
 
     def _set_target_column(self, column):
-        assert isinstance(self.parent.table, Table)
+        assert self.parent is not None
 
         # propagate TypeEngine to parent if it didn't have one
         if self.parent.type._isnull:
@@ -2561,11 +2601,6 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         If no target column has been established, an exception
         is raised.
 
-        .. versionchanged:: 0.9.0
-            Foreign key target column resolution now occurs as soon as both
-            the ForeignKey object and the remote Column to which it refers
-            are both associated with the same MetaData object.
-
         """
 
         if isinstance(self._colspec, str):
@@ -2586,14 +2621,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
                     "parent MetaData" % parenttable
                 )
             else:
-                raise exc.NoReferencedColumnError(
-                    "Could not initialize target column for "
-                    "ForeignKey '%s' on table '%s': "
-                    "table '%s' has no column named '%s'"
-                    % (self._colspec, parenttable.name, tablekey, colname),
-                    tablekey,
-                    colname,
+                table = parenttable.metadata.tables[tablekey]
+                return self._link_to_col_by_colstring(
+                    parenttable, table, colname
                 )
+
         elif hasattr(self._colspec, "__clause_element__"):
             _column = self._colspec.__clause_element__()
             return _column
@@ -2601,18 +2633,22 @@ class ForeignKey(DialectKWArgs, SchemaItem):
             _column = self._colspec
             return _column
 
-    def _set_parent(self, column, **kw):
-        if self.parent is not None and self.parent is not column:
+    def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+        assert isinstance(parent, Column)
+
+        if self.parent is not None and self.parent is not parent:
             raise exc.InvalidRequestError(
                 "This ForeignKey already has a parent !"
             )
-        self.parent = column
+        self.parent = parent
         self.parent.foreign_keys.add(self)
         self.parent._on_table_attach(self._set_table)
 
     def _set_remote_table(self, table):
-        parenttable, tablekey, colname = self._resolve_col_tokens()
-        self._link_to_col_by_colstring(parenttable, table, colname)
+        parenttable, _, colname = self._resolve_col_tokens()
+        _column = self._link_to_col_by_colstring(parenttable, table, colname)
+        self._set_target_column(_column)
+        assert self.constraint is not None
         self.constraint._validate_dest_table(table)
 
     def _remove_from_metadata(self, metadata):
@@ -2651,10 +2687,15 @@ class ForeignKey(DialectKWArgs, SchemaItem):
             if table_key in parenttable.metadata.tables:
                 table = parenttable.metadata.tables[table_key]
                 try:
-                    self._link_to_col_by_colstring(parenttable, table, colname)
+                    _column = self._link_to_col_by_colstring(
+                        parenttable, table, colname
+                    )
                 except exc.NoReferencedColumnError:
                     # this is OK, we'll try later
                     pass
+                else:
+                    self._set_target_column(_column)
+
             parenttable.metadata._fk_memos[fk_key].append(self)
         elif hasattr(self._colspec, "__clause_element__"):
             _column = self._colspec.__clause_element__()
@@ -2664,6 +2705,31 @@ class ForeignKey(DialectKWArgs, SchemaItem):
             self._set_target_column(_column)
 
 
+if TYPE_CHECKING:
+
+    def default_is_sequence(
+        obj: Optional[DefaultGenerator],
+    ) -> TypeGuard[Sequence]:
+        ...
+
+    def default_is_clause_element(
+        obj: Optional[DefaultGenerator],
+    ) -> TypeGuard[ColumnElementColumnDefault]:
+        ...
+
+    def default_is_scalar(
+        obj: Optional[DefaultGenerator],
+    ) -> TypeGuard[ScalarElementColumnDefault]:
+        ...
+
+else:
+    default_is_sequence = operator.attrgetter("is_sequence")
+
+    default_is_clause_element = operator.attrgetter("is_clause_element")
+
+    default_is_scalar = operator.attrgetter("is_scalar")
+
+
 class DefaultGenerator(Executable, SchemaItem):
     """Base class for column *default* values."""
 
@@ -2671,18 +2737,18 @@ class DefaultGenerator(Executable, SchemaItem):
 
     is_sequence = False
     is_server_default = False
+    is_clause_element = False
+    is_callable = False
     is_scalar = False
-    column = None
+    column: Optional[Column[Any]]
 
     def __init__(self, for_update=False):
         self.for_update = for_update
 
-    @util.memoized_property
-    def is_callable(self):
-        raise NotImplementedError()
-
-    def _set_parent(self, column, **kw):
-        self.column = column
+    def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+        if TYPE_CHECKING:
+            assert isinstance(parent, Column)
+        self.column = parent
         if self.for_update:
             self.column.onupdate = self
         else:
@@ -2696,7 +2762,7 @@ class DefaultGenerator(Executable, SchemaItem):
         )
 
 
-class ColumnDefault(DefaultGenerator):
+class ColumnDefault(DefaultGenerator, ABC):
     """A plain default value on a column.
 
     This could correspond to a constant, a callable function,
@@ -2718,7 +2784,30 @@ class ColumnDefault(DefaultGenerator):
 
     """
 
-    def __init__(self, arg, **kwargs):
+    arg: Any
+
+    @overload
+    def __new__(
+        cls, arg: Callable[..., Any], for_update: bool = ...
+    ) -> CallableColumnDefault:
+        ...
+
+    @overload
+    def __new__(
+        cls, arg: ColumnElement[Any], for_update: bool = ...
+    ) -> ColumnElementColumnDefault:
+        ...
+
+    # returning ScalarElementColumnDefault here, which is what's actually
+    # returned, makes mypy complain that the overloads overlap with
+    # incompatible return types.
+    @overload
+    def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault:
+        ...
+
+    def __new__(
+        cls, arg: Any = None, for_update: bool = False
+    ) -> ColumnDefault:
         """Construct a new :class:`.ColumnDefault`.
 
 
@@ -2744,70 +2833,121 @@ class ColumnDefault(DefaultGenerator):
            statement and parameters.
 
         """
-        super(ColumnDefault, self).__init__(**kwargs)
+
         if isinstance(arg, FetchedValue):
             raise exc.ArgumentError(
                 "ColumnDefault may not be a server-side default type."
             )
-        if callable(arg):
-            arg = self._maybe_wrap_callable(arg)
+        elif callable(arg):
+            cls = CallableColumnDefault
+        elif isinstance(arg, ClauseElement):
+            cls = ColumnElementColumnDefault
+        elif arg is not None:
+            cls = ScalarElementColumnDefault
+
+        return object.__new__(cls)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.arg!r})"
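
Because ``__new__`` dispatches on the argument type, constructing a
ColumnDefault directly yields one of the subclasses introduced below; a small
sketch of that behavior:

from sqlalchemy import func
from sqlalchemy.schema import ColumnDefault

print(type(ColumnDefault(0)).__name__)           # ScalarElementColumnDefault
print(type(ColumnDefault(lambda: 42)).__name__)  # CallableColumnDefault
print(type(ColumnDefault(func.now())).__name__)  # ColumnElementColumnDefault
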
+
+
+class ScalarElementColumnDefault(ColumnDefault):
+    """default generator for a fixed scalar Python value
+
+    .. versionadded:: 2.0
+
+    """
+
+    is_scalar = True
+
+    def __init__(self, arg: Any, for_update: bool = False):
+        self.for_update = for_update
         self.arg = arg
 
-    @util.memoized_property
-    def is_callable(self):
-        return callable(self.arg)
 
-    @util.memoized_property
-    def is_clause_element(self):
-        return isinstance(self.arg, ClauseElement)
+# _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"]
+_SQLExprDefault = Union["ColumnElement[Any]", "TextClause"]
 
-    @util.memoized_property
-    def is_scalar(self):
-        return (
-            not self.is_callable
-            and not self.is_clause_element
-            and not self.is_sequence
-        )
+
+class ColumnElementColumnDefault(ColumnDefault):
+    """default generator for a SQL expression
+
+    .. versionadded:: 2.0
+
+    """
+
+    is_clause_element = True
+
+    arg: _SQLExprDefault
+
+    def __init__(
+        self,
+        arg: _SQLExprDefault,
+        for_update: bool = False,
+    ):
+        self.for_update = for_update
+        self.arg = arg
 
     @util.memoized_property
     @util.preload_module("sqlalchemy.sql.sqltypes")
     def _arg_is_typed(self):
         sqltypes = util.preloaded.sql_sqltypes
 
-        if self.is_clause_element:
-            return not isinstance(self.arg.type, sqltypes.NullType)
-        else:
-            return False
+        return not isinstance(self.arg.type, sqltypes.NullType)
+
+
+class _CallableColumnDefaultProtocol(Protocol):
+    def __call__(self, context: ExecutionContext) -> Any:
+        ...
 
-    def _maybe_wrap_callable(self, fn):
+
+class CallableColumnDefault(ColumnDefault):
+    """default generator for a callable Python function
+
+    .. versionadded:: 2.0
+
+    """
+
+    is_callable = True
+    arg: _CallableColumnDefaultProtocol
+
+    def __init__(
+        self,
+        arg: Union[_CallableColumnDefaultProtocol, Callable[[], Any]],
+        for_update: bool = False,
+    ):
+        self.for_update = for_update
+        self.arg = self._maybe_wrap_callable(arg)
+
+    def _maybe_wrap_callable(
+        self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]]
+    ) -> _CallableColumnDefaultProtocol:
         """Wrap callables that don't accept a context.
 
         This is to allow easy compatibility with default callables
         that aren't specific to accepting of a context.
 
         """
+
         try:
             argspec = util.get_callable_argspec(fn, no_self=True)
         except TypeError:
-            return util.wrap_callable(lambda ctx: fn(), fn)
+            return util.wrap_callable(lambda ctx: fn(), fn)  # type: ignore
 
         defaulted = argspec[3] is not None and len(argspec[3]) or 0
         positionals = len(argspec[0]) - defaulted
 
         if positionals == 0:
-            return util.wrap_callable(lambda ctx: fn(), fn)
+            return util.wrap_callable(lambda ctx: fn(), fn)  # type: ignore
 
         elif positionals == 1:
-            return fn
+            return fn  # type: ignore
         else:
             raise exc.ArgumentError(
                 "ColumnDefault Python function takes zero or one "
                 "positional arguments"
             )
 
-    def __repr__(self):
-        return "ColumnDefault(%r)" % (self.arg,)
-
 
 class IdentityOptions:
     """Defines options for a named database sequence or an identity column.
@@ -2899,6 +3039,8 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
 
     is_sequence = True
 
+    column: Optional[Column[Any]] = None
+
     def __init__(
         self,
         name,
@@ -3087,14 +3229,6 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
         else:
             self.data_type = None
 
-    @util.memoized_property
-    def is_callable(self):
-        return False
-
-    @util.memoized_property
-    def is_clause_element(self):
-        return False
-
     @util.preload_module("sqlalchemy.sql.functions")
     def next_value(self):
         """Return a :class:`.next_value` function element
@@ -3235,6 +3369,9 @@ class Constraint(DialectKWArgs, SchemaItem):
 
     __visit_name__ = "constraint"
 
+    _creation_order: int
+    _column_flag: bool
+
     def __init__(
         self,
         name=None,
@@ -3316,8 +3453,6 @@ class Constraint(DialectKWArgs, SchemaItem):
 
 
 class ColumnCollectionMixin:
-
-    columns = None
     """A :class:`_expression.ColumnCollection` of :class:`_schema.Column`
     objects.
 
@@ -3326,8 +3461,17 @@ class ColumnCollectionMixin:
 
     """
 
+    columns: ColumnCollection[Column[Any]]
+
     _allow_multiple_tables = False
 
+    if TYPE_CHECKING:
+
+        def _set_parent_with_dispatch(
+            self, parent: SchemaEventTarget, **kw: Any
+        ) -> None:
+            ...
+
     def __init__(self, *columns, **kw):
         _autoattach = kw.pop("_autoattach", True)
         self._column_flag = kw.pop("_column_flag", False)
@@ -3404,14 +3548,16 @@ class ColumnCollectionMixin:
                     )
                 )
 
-    def _col_expressions(self, table):
+    def _col_expressions(self, table: Table) -> List[Column[Any]]:
         return [
             table.c[col] if isinstance(col, str) else col
             for col in self._pending_colargs
         ]
 
-    def _set_parent(self, table, **kw):
-        for col in self._col_expressions(table):
+    def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+        if TYPE_CHECKING:
+            assert isinstance(parent, Table)
+        for col in self._col_expressions(parent):
             if col is not None:
                 self.columns.add(col)
 
@@ -3446,7 +3592,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
             self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
         )
 
-    columns = None
+    columns: DedupeColumnCollection[Column[Any]]
     """A :class:`_expression.ColumnCollection` representing the set of columns
     for this constraint.
 
@@ -3568,7 +3714,7 @@ class CheckConstraint(ColumnCollectionConstraint):
         """
 
         self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
-        columns = []
+        columns: List[Column[Any]] = []
         visitors.traverse(self.sqltext, {}, {"column": columns.append})
 
         super(CheckConstraint, self).__init__(
@@ -3779,17 +3925,17 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
                 assert table is self.parent
             self._set_parent_with_dispatch(table)
 
-    def _append_element(self, column, fk):
+    def _append_element(self, column: Column[Any], fk: ForeignKey) -> None:
         self.columns.add(column)
         self.elements.append(fk)
 
-    columns = None
+    columns: DedupeColumnCollection[Column[Any]]
     """A :class:`_expression.ColumnCollection` representing the set of columns
     for this constraint.
 
     """
 
-    elements = None
+    elements: List[ForeignKey]
     """A sequence of :class:`_schema.ForeignKey` objects.
 
     Each :class:`_schema.ForeignKey`
@@ -4271,7 +4417,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
 
         self._validate_dialect_kwargs(kw)
 
-        self.expressions = []
+        self.expressions: List[ColumnElement[Any]] = []
         # will call _set_parent() if table-bound column
         # objects are present
         ColumnCollectionMixin.__init__(
@@ -4501,11 +4647,13 @@ class MetaData(HasSchemaAttr):
         )
         if info:
             self.info = info
-        self._schemas = set()
-        self._sequences = {}
-        self._fk_memos = collections.defaultdict(list)
+        self._schemas: Set[str] = set()
+        self._sequences: Dict[str, Sequence] = {}
+        self._fk_memos: Dict[
+            Tuple[str, str], List[ForeignKey]
+        ] = collections.defaultdict(list)
 
-    tables: Dict[str, Table]
+    tables: util.FacadeDict[str, Table]
     """A dictionary of :class:`_schema.Table`
     objects keyed to their name or "table key".
 
@@ -4539,7 +4687,7 @@ class MetaData(HasSchemaAttr):
 
     def _remove_table(self, name, schema):
         key = _get_table_key(name, schema)
-        removed = dict.pop(self.tables, key, None)
+        removed = dict.pop(self.tables, key, None)  # type: ignore
         if removed is not None:
             for fk in removed.foreign_keys:
                 fk._remove_from_metadata(self)
@@ -4634,12 +4782,12 @@ class MetaData(HasSchemaAttr):
 
         """
         return ddl.sort_tables(
-            sorted(self.tables.values(), key=lambda t: t.key)
+            sorted(self.tables.values(), key=lambda t: t.key)  # type: ignore
         )
 
     def reflect(
         self,
-        bind: Union["Engine", "Connection"],
+        bind: Union[Engine, Connection],
         schema: Optional[str] = None,
         views: bool = False,
         only: Optional[_typing_Sequence[str]] = None,
@@ -4647,7 +4795,7 @@ class MetaData(HasSchemaAttr):
         autoload_replace: bool = True,
         resolve_fks: bool = True,
         **dialect_kwargs: Any,
-    ):
+    ) -> None:
         r"""Load all available table definitions from the database.
 
         Automatically creates ``Table`` entries in this ``MetaData`` for any
@@ -4748,12 +4896,14 @@ class MetaData(HasSchemaAttr):
             if schema is not None:
                 reflect_opts["schema"] = schema
 
-            available = util.OrderedSet(insp.get_table_names(schema))
+            available: util.OrderedSet[str] = util.OrderedSet(
+                insp.get_table_names(schema)
+            )
             if views:
                 available.update(insp.get_view_names(schema))
 
             if schema is not None:
-                available_w_schema = util.OrderedSet(
+                available_w_schema: util.OrderedSet[str] = util.OrderedSet(
                     ["%s.%s" % (schema, name) for name in available]
                 )
             else:
@@ -4796,10 +4946,10 @@ class MetaData(HasSchemaAttr):
 
     def create_all(
         self,
-        bind: Union["Engine", "Connection"],
+        bind: Union[Engine, Connection, MockConnection],
         tables: Optional[_typing_Sequence[Table]] = None,
         checkfirst: bool = True,
-    ):
+    ) -> None:
         """Create all tables stored in this metadata.
 
         Conditional by default, will not attempt to recreate tables already
@@ -4824,10 +4974,10 @@ class MetaData(HasSchemaAttr):
 
     def drop_all(
         self,
-        bind: Union["Engine", "Connection"],
+        bind: Union[Engine, Connection, MockConnection],
         tables: Optional[_typing_Sequence[Table]] = None,
         checkfirst: bool = True,
-    ):
+    ) -> None:
         """Drop all tables stored in this metadata.
 
         Conditional by default, will not attempt to drop tables not present in
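
Together with the typed ``reflect()`` above, these now explicitly
``None``-returning entry points are used as before; a minimal round trip,
assuming an in-memory SQLite engine purely for illustration:

from sqlalchemy import Column, Integer, MetaData, Table, create_engine

engine = create_engine("sqlite://")

metadata = MetaData()
Table("users", metadata, Column("id", Integer, primary_key=True))
metadata.create_all(engine, checkfirst=True)   # emits CREATE TABLE users

reflected = MetaData()
reflected.reflect(bind=engine)                 # loads table definitions back
assert "users" in reflected.tables

metadata.drop_all(engine)                      # emits DROP TABLE users
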
index e143d1476182318b94065b463815b1a1e137e2b4..8665a74db6bb838b7f08ef67f0a1d85f4d154883 100644 (file)
@@ -463,7 +463,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
 
     _is_clone_of: Optional[FromClause]
 
-    schema = None
+    schema: Optional[str] = None
     """Define the 'schema' attribute for this :class:`_expression.FromClause`.
 
     This is typically ``None`` for most objects except that of
@@ -673,7 +673,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         """
         return self._cloned_set.intersection(other._cloned_set)
 
-    @property
+    @util.non_memoized_property
     def description(self) -> str:
         """A brief description of this :class:`_expression.FromClause`.
 
@@ -710,7 +710,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         return self.columns
 
     @util.memoized_property
-    def columns(self) -> ColumnCollection:
+    def columns(self) -> ColumnCollection[Any]:
         """A named-based collection of :class:`_expression.ColumnElement`
         objects maintained by this :class:`_expression.FromClause`.
 
@@ -796,7 +796,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
 
     # this is awkward.   maybe there's a better way
     if TYPE_CHECKING:
-        c: ColumnCollection
+        c: ColumnCollection[Any]
     else:
         c = property(
             attrgetter("columns"),
@@ -2399,6 +2399,8 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
 
     _is_table = True
 
+    fullname: str
+
     implicit_returning = False
     """:class:`_expression.TableClause`
     doesn't support having a primary key or column
index 829c1b72e7c1690c3e82d7aac59731ea4c60f276..1a6de34b03a429b38921e3599f3da76fb70507f0 100644 (file)
@@ -345,6 +345,12 @@ class Integer(HasExpressionLookup, TypeEngine[int]):
 
     __visit_name__ = "integer"
 
+    if TYPE_CHECKING:
+
+        @util.ro_memoized_property
+        def _type_affinity(self) -> Type[Integer]:
+            ...
+
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
 
@@ -1892,8 +1898,8 @@ class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]):
             operators.truediv: {Numeric: self.__class__},
         }
 
-    @util.non_memoized_property
-    def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
+    @util.ro_non_memoized_property
+    def _type_affinity(self) -> Type[Interval]:
         return Interval
 
 
index 5a0aba694f9552edc88cb71214b7d7386375c637..9a934a50bc75095ea2ec3a640afa96a067fdbff3 100644 (file)
@@ -705,7 +705,7 @@ class TypeEngine(Visitable, Generic[_T]):
         """
         return self
 
-    @util.memoized_property
+    @util.ro_memoized_property
     def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
         """Return a rudimental 'affinity' value expressing the general class
         of type."""
@@ -719,7 +719,7 @@ class TypeEngine(Visitable, Generic[_T]):
         else:
             return self.__class__
 
-    @util.memoized_property
+    @util.ro_memoized_property
     def _generic_type_affinity(
         self,
     ) -> Type[TypeEngine[_T]]:
@@ -1694,7 +1694,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
         tt.impl = tt.impl_instance = typedesc
         return tt
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
         return self.impl_instance._type_affinity
 
index b40d22f8edcb0d534b378a7226d123b6201d5df4..7e616cd74f068778b79acfef2caf915a006be9ce 100644 (file)
@@ -130,6 +130,8 @@ from .langhelpers import (
 from .langhelpers import PluginLoader as PluginLoader
 from .langhelpers import portable_instancemethod as portable_instancemethod
 from .langhelpers import quoted_token_parser as quoted_token_parser
+from .langhelpers import ro_memoized_property as ro_memoized_property
+from .langhelpers import ro_non_memoized_property as ro_non_memoized_property
 from .langhelpers import safe_reraise as safe_reraise
 from .langhelpers import set_creation_order as set_creation_order
 from .langhelpers import string_or_unprintable as string_or_unprintable
index 85ae1b65dfe287882fa49af3af6250f18f80de4a..2d974b7372b398bad072a958faf9dfca19120747 100644 (file)
@@ -135,7 +135,7 @@ def coerce_to_immutabledict(d):
 EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
 
 
-class FacadeDict(ImmutableDictBase[Any, Any]):
+class FacadeDict(ImmutableDictBase[_KT, _VT]):
     """A dictionary that is not publicly mutable."""
 
     def __new__(cls, *args):
index 35294715cbbf411467d93a5fe90d17567b6242a0..8cf50c724cbd24f654db7ff7e559870b8f0283d4 100644 (file)
@@ -55,8 +55,8 @@ _T_co = TypeVar("_T_co", covariant=True)
 _F = TypeVar("_F", bound=Callable[..., Any])
 _MP = TypeVar("_MP", bound="memoized_property[Any]")
 _MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
-_HP = TypeVar("_HP", bound="hybridproperty")
-_HM = TypeVar("_HM", bound="hybridmethod")
+_HP = TypeVar("_HP", bound="hybridproperty[Any]")
+_HM = TypeVar("_HM", bound="hybridmethod[Any]")
 
 
 if compat.py310:
@@ -1234,12 +1234,23 @@ class _memoized_property(generic_fn_descriptor[_T_co]):
 # superclass has non-memoized, the class hierarchy of the descriptors
 # would need to be reversed; "class non_memoized(memoized)".  so there's no
 # way to achieve this.
+# an additional issue affects read-only (RO) properties:
+# https://github.com/python/mypy/issues/12440
 if TYPE_CHECKING:
+
+    # allow memoized and non-memoized to be freely mixed by having them
+    # be the same class
     memoized_property = generic_fn_descriptor
     non_memoized_property = generic_fn_descriptor
+
+    # for read-only situations, mypy treats only a plain @property as
+    # read-only.  read-only behavior is needed when a subtype specializes
+    # the return type of a property, since assignment must then be disallowed
+    ro_memoized_property = property
+    ro_non_memoized_property = property
 else:
-    memoized_property = _memoized_property
-    non_memoized_property = _non_memoized_property
+    memoized_property = ro_memoized_property = _memoized_property
+    non_memoized_property = ro_non_memoized_property = _non_memoized_property
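
The read-only aliases lean on the fact that mypy treats a plain @property as
read-only, which allows a subclass to specialize (narrow) the property's
return type; a minimal sketch of the accepted shape, with made-up classes:

from typing import Optional


class HasDescription:
    @property
    def description(self) -> Optional[str]:
        return None


class NamedThing(HasDescription):
    # narrowing the return type is accepted for a read-only @property;
    # a writable descriptor would not permit the same specialization
    @property
    def description(self) -> str:
        return "named"
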
 
 
 def memoized_instancemethod(fn: _F) -> _F:
@@ -1515,7 +1526,9 @@ def duck_type_collection(
         return default
 
 
-def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any:
+def assert_arg_type(
+    arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
+) -> Any:
     if isinstance(arg, argtype):
         return arg
     else:
@@ -1576,37 +1589,37 @@ class classproperty(property):
         return self.fget(cls)  # type: ignore
 
 
-class hybridproperty:
-    def __init__(self, func):
+class hybridproperty(Generic[_T]):
+    def __init__(self, func: Callable[..., _T]):
         self.func = func
         self.clslevel = func
 
-    def __get__(self, instance, owner):
+    def __get__(self, instance: Any, owner: Any) -> _T:
         if instance is None:
             clsval = self.clslevel(owner)
             return clsval
         else:
             return self.func(instance)
 
-    def classlevel(self, func):
+    def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
         self.clslevel = func
         return self
 
 
-class hybridmethod:
+class hybridmethod(Generic[_T]):
     """Decorate a function as cls- or instance- level."""
 
-    def __init__(self, func):
+    def __init__(self, func: Callable[..., _T]):
         self.func = self.__func__ = func
         self.clslevel = func
 
-    def __get__(self, instance, owner):
+    def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
         if instance is None:
-            return self.clslevel.__get__(owner, owner.__class__)
+            return self.clslevel.__get__(owner, owner.__class__)  # type:ignore
         else:
-            return self.func.__get__(instance, owner)
+            return self.func.__get__(instance, owner)  # type:ignore
 
-    def classlevel(self, func):
+    def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
         self.clslevel = func
         return self
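
With the Generic[_T] parameter added above, hybridproperty and hybridmethod now carry the decorated function's return type through __get__. A small usage sketch of the dispatch these internal descriptors perform (the Widget class and its attribute are illustrative only; type checkers may flag the intentional redefinition of the name):

    from sqlalchemy.util.langhelpers import hybridproperty

    class Widget:
        @hybridproperty
        def label(self) -> str:
            return "instance label"

        @label.classlevel
        def label(cls) -> str:
            return "class label"

    assert Widget.label == "class label"        # instance is None -> clslevel
    assert Widget().label == "instance label"   # instance present -> func
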
 
index 03621a601807b854b334c93a48ba0c87356cba67..0b9d8c62c765ef417a1e2c36b05305872859e416 100644 (file)
@@ -34,8 +34,10 @@ else:
 
 if compat.py310:
     from typing import TypeGuard as TypeGuard
+    from typing import TypeAlias as TypeAlias
 else:
     from typing_extensions import TypeGuard as TypeGuard
+    from typing_extensions import TypeAlias as TypeAlias
 
 if typing.TYPE_CHECKING or compat.py38:
     from typing import SupportsIndex as SupportsIndex
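
These version-conditional imports give the rest of the codebase one place to get TypeGuard and TypeAlias (both added to the standard typing module in Python 3.10), falling back to typing_extensions on older interpreters. A hedged example of a consumer using the re-exported names from this internal module (the narrowing helper below is hypothetical):

    from typing import List

    from sqlalchemy.util.typing import TypeAlias, TypeGuard

    StrList: TypeAlias = List[str]

    def is_str_list(values: List[object]) -> TypeGuard[StrList]:
        # in branches where this returns True, mypy narrows values to List[str]
        return all(isinstance(v, str) for v in values)
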
index 7117c2689f38b459750d3cc0e96b7a30e3df3236..acbc69537adc19c56a845d5c7c7939d588dfc1f6 100644 (file)
@@ -51,57 +51,73 @@ reportTypedDictNotRequiredAccess = "warning"
 [tool.mypy]
 mypy_path = "./lib/"
 show_error_codes = true
-strict = false
+strict = true
 incremental = true
 
-# disabled checking
 [[tool.mypy.overrides]]
-module="sqlalchemy.*"
-ignore_errors = true
-warn_unused_ignores = false
-
-strict = true
-
-# some debate at
-# https://github.com/python/mypy/issues/8754.
-# implicit_reexport = true
 
-# individual packages or even modules should be listed here
-# with strictness-specificity set up.  there's no way we are going to get
-# the whole library 100% strictly typed, so we have to tune this based on
-# the type of module or package we are dealing with
-
-[[tool.mypy.overrides]]
 # ad-hoc ignores
 module = [
     "sqlalchemy.engine.reflection",  # interim, should be strict
+
+    # TODO for strict:
+    "sqlalchemy.ext.asyncio.*",
+    "sqlalchemy.ext.automap",
+    "sqlalchemy.ext.compiler",
+    "sqlalchemy.ext.declarative.*",
+    "sqlalchemy.ext.mutable",
+    "sqlalchemy.ext.horizontal_shard",
+
+    "sqlalchemy.sql._selectable_constructors",
+    "sqlalchemy.sql._dml_constructors",
+
+    # TODO for non-strict:
+    "sqlalchemy.ext.baked",
+    "sqlalchemy.ext.instrumentation",
+    "sqlalchemy.ext.indexable",
+    "sqlalchemy.ext.orderinglist",
+    "sqlalchemy.ext.serializer",
+
+    "sqlalchemy.sql.selectable",   # would be nice as strict
+    "sqlalchemy.sql.ddl",
+    "sqlalchemy.sql.functions",  # would be nice as strict
+    "sqlalchemy.sql.lambdas",
+    "sqlalchemy.sql.dml",   # would be nice as strict
+    "sqlalchemy.sql.util",
+
+    # not yet classified:
+    "sqlalchemy.orm.*",
+    "sqlalchemy.dialects.*",
+    "sqlalchemy.cyextension.*",
+    "sqlalchemy.future.*",
+    "sqlalchemy.testing.*",
+
 ]
 
+warn_unused_ignores = false
 ignore_errors = true
 
 # strict checking
 [[tool.mypy.overrides]]
+
 module = [
-    "sqlalchemy.sql.annotation",
-    "sqlalchemy.sql.cache_key",
-    "sqlalchemy.sql._elements_constructors",
-    "sqlalchemy.sql.operators",
-    "sqlalchemy.sql.type_api",
-    "sqlalchemy.sql.roles",
-    "sqlalchemy.sql.visitors",
-    "sqlalchemy.sql._py_util",
+    # packages
     "sqlalchemy.connectors.*",
+    "sqlalchemy.event.*",
+    "sqlalchemy.ext.*",
+    "sqlalchemy.sql.*",
     "sqlalchemy.engine.*",
-    "sqlalchemy.ext.hybrid",
-    "sqlalchemy.ext.associationproxy",
     "sqlalchemy.pool.*",
-    "sqlalchemy.event.*",
+
+    # modules
     "sqlalchemy.events",
     "sqlalchemy.exc",
     "sqlalchemy.inspection",
     "sqlalchemy.schema",
     "sqlalchemy.types",
 ]
+
+warn_unused_ignores = false
 ignore_errors = false
 strict = true
 
@@ -109,20 +125,24 @@ strict = true
 [[tool.mypy.overrides]]
 
 module = [
-    #"sqlalchemy.sql.*",
-    "sqlalchemy.sql.sqltypes",
-    "sqlalchemy.sql.elements",
+    "sqlalchemy.engine.cursor",
+    "sqlalchemy.engine.default",
+
+    "sqlalchemy.sql.base",
     "sqlalchemy.sql.coercions",
     "sqlalchemy.sql.compiler",
-    #"sqlalchemy.sql.default_comparator",
+    "sqlalchemy.sql.crud",
+    "sqlalchemy.sql.elements",    # would be nice as strict
     "sqlalchemy.sql.naming",
+    "sqlalchemy.sql.schema",      # would be nice as strict
+    "sqlalchemy.sql.sqltypes",   # would be nice as strict
     "sqlalchemy.sql.traversals",
+
     "sqlalchemy.util.*",
-    "sqlalchemy.engine.cursor",
-    "sqlalchemy.engine.default",
 ]
 
 
+warn_unused_ignores = false
 ignore_errors = false
 
 # mostly strict without requiring totally untyped things to be
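
Taken together, the overrides amount to strict = true globally with per-module tuning: the first block silences the listed modules outright (ignore_errors = true), the second checks whole packages strictly, the third enables error reporting without the full strict flag set, and the trailing comment introduces a further "mostly strict" tier not shown in this hunk. As a rough illustration of what the tiers mean in practice, mypy's strict mode enables flags such as disallow_untyped_defs, so only the second function below would pass a strict tier (the example is illustrative, not from the library):

    from typing import List

    # accepted by the ignored and errors-only tiers, but rejected under
    # strict = true, which enables disallow_untyped_defs among other flags
    def lower_all(values):
        return [v.lower() for v in values]

    # fully annotated, so it also passes the strict tier
    def lower_all_typed(values: List[str]) -> List[str]:
        return [v.lower() for v in values]
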
index ec1cba724120c6d7214828af6cfd283dd25dac7d..074b649f2e90bf0b941656886d70e97f95b981ab 100644 (file)
@@ -1,29 +1,29 @@
 # /home/classic/dev/sqlalchemy/test/profiles.txt
 # This file is written out on a per-environment basis.
-# For each test in aaa_profiling, the corresponding function and 
+# For each test in aaa_profiling, the corresponding function and
 # environment is located within this file.  If it doesn't exist,
 # the test is skipped.
-# If a callcount does exist, it is compared to what we received. 
+# If a callcount does exist, it is compared to what we received.
 # assertions are raised if the counts do not match.
-# 
-# To add a new callcount test, apply the function_call_count 
-# decorator and re-run the tests using the --write-profiles 
+#
+# To add a new callcount test, apply the function_call_count
+# decorator and re-run the tests using the --write-profiles
 # option - this file will be rewritten including the new count.
-# 
+#
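
Following the header's instructions, a new callcount entry is produced by decorating a test with function_call_count and rerunning with --write-profiles. A hedged sketch of such a test; fixtures.TestBase appears in the test diffs below, but the profiling import path and the decorator's (empty) argument list are assumptions:

    from sqlalchemy import Column, Integer, MetaData, Table, insert
    from sqlalchemy.testing import fixtures, profiling  # assumed import path

    class MyCompileTest(fixtures.TestBase):
        @profiling.function_call_count()  # arguments assumed; may differ
        def test_insert_compile(self):
            t = Table("t", MetaData(), Column("x", Integer))
            insert(t).values(x=5).compile()
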
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
 
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 70
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 70
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 75
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_select
 
index f3260f2721d4c42db3b82657da2c5dc6039ae103..52c77996956da9741a0b3bbf5c819b04b81b7568 100644 (file)
@@ -354,7 +354,8 @@ class DefaultObjectTest(fixtures.TestBase):
             assert_raises_message(
                 sa.exc.ArgumentError,
                 r"SQL expression for WHERE/HAVING role expected, "
-                r"got (?:Sequence|ColumnDefault|DefaultClause)\('y'.*\)",
+                r"got (?:Sequence|(?:ScalarElement)?ColumnDefault|"
+                r"DefaultClause)\('y'.*\)",
                 t.select().where,
                 const,
             )
index 11c3e83b7d44b5b132cd6b198ea7c9a908e52d7b..21fc0a6272820785a3086f24683cf1df7309528a 100644 (file)
@@ -760,7 +760,10 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
                 "%s"
                 ", name='someconstraint')" % repr(ck.sqltext),
             ),
-            (ColumnDefault(("foo", "bar")), "ColumnDefault(('foo', 'bar'))"),
+            (
+                ColumnDefault(("foo", "bar")),
+                "ScalarElementColumnDefault(('foo', 'bar'))",
+            ),
         ):
             eq_(repr(const), exp)
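
The new expected repr reflects that ColumnDefault() now acts as a factory, returning a more specific subclass for a plain Python value; the same rename is what the WHERE/HAVING error-message regex in test_defaults.py above accommodates. A short sketch of the behavior the assertion relies on (the output shown is the repr the test itself asserts):

    from sqlalchemy.schema import ColumnDefault

    # a plain (non-callable, non-SQL) value yields the scalar subclass
    d = ColumnDefault(("foo", "bar"))
    print(repr(d))  # ScalarElementColumnDefault(('foo', 'bar'))
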
 
@@ -916,6 +919,46 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
         a2 = a.to_metadata(m2)
         assert b2.c.y.references(a2.c.x)
 
+    def test_fk_w_no_colname(self):
+        """test a ForeignKey that refers to table name only.  the column
+        name is assumed to be the same col name on parent table.
+
+        this is a little used feature from long ago that nonetheless is
+        still in the code.
+
+        The feature was found to be not working but is repaired for
+        SQLAlchemy 2.0.
+
+        """
+        m1 = MetaData()
+        a = Table("a", m1, Column("x", Integer))
+        b = Table("b", m1, Column("x", Integer, ForeignKey("a")))
+        assert b.c.x.references(a.c.x)
+
+        m2 = MetaData()
+        b2 = b.to_metadata(m2)
+        a2 = a.to_metadata(m2)
+        assert b2.c.x.references(a2.c.x)
+
+    def test_fk_w_no_colname_name_missing(self):
+        """test a ForeignKey that refers to table name only.  the column
+        name is assumed to be the same col name on parent table.
+
+        this is a little used feature from long ago that nonetheless is
+        still in the code.
+
+        """
+        m1 = MetaData()
+        a = Table("a", m1, Column("x", Integer))
+        b = Table("b", m1, Column("y", Integer, ForeignKey("a")))
+
+        with expect_raises_message(
+            exc.NoReferencedColumnError,
+            "Could not initialize target column for ForeignKey 'a' on "
+            "table 'b': table 'a' has no column named 'y'",
+        ):
+            assert b.c.y.references(a.c.x)
+
     def test_column_collection_constraint_w_ad_hoc_columns(self):
         """Test ColumnCollectionConstraint that has columns that aren't
         part of the Table.