]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
revenge of pep 484
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 6 May 2022 20:09:52 +0000 (16:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 May 2022 01:57:01 +0000 (21:57 -0400)
trying to get remaining must-haves for ORM

Change-Id: I66a3ecbbb8e5ba37c818c8a92737b576ecf012f7

61 files changed:
doc/build/changelog/unreleased_20/map_decl.rst [new file with mode: 0644]
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/declarative/extensions.py
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/_typing.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/clsregistry.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/exc.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapped_collection.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/state_changes.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/preloaded.py
lib/sqlalchemy/util/topological.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/ext/mypy/plain_files/composite.py [new file with mode: 0644]
test/ext/mypy/test_mypy_plugin_py3k.py
test/orm/declarative/test_basic.py
test/orm/declarative/test_inheritance.py
test/orm/declarative/test_typed_mapping.py
test/orm/inheritance/test_concrete.py
test/orm/test_query.py
test/orm/test_relationships.py

diff --git a/doc/build/changelog/unreleased_20/map_decl.rst b/doc/build/changelog/unreleased_20/map_decl.rst
new file mode 100644 (file)
index 0000000..9e27f5d
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, orm
+
+    Fixed issue where the :meth:`_orm.registry.map_declaratively` method
+    would return an internal "mapper config" object and not the
+    :class:`.Mapper` object as stated in the API documentation.
index f4e22df2dbb94ddc7e16d85c05e02a4d1b4b717f..d5f0d8126318e82bf877ec40c29a03fad076270a 100644 (file)
@@ -21,6 +21,7 @@ from typing import ClassVar
 from typing import Dict
 from typing import Iterator
 from typing import List
+from typing import NoReturn
 from typing import Optional
 from typing import Sequence
 from typing import Tuple
@@ -53,7 +54,11 @@ _UNPICKLED = util.symbol("unpickled")
 
 
 if typing.TYPE_CHECKING:
+    from .base import Connection
+    from .default import DefaultExecutionContext
     from .interfaces import _DBAPICursorDescription
+    from .interfaces import DBAPICursor
+    from .interfaces import Dialect
     from .interfaces import ExecutionContext
     from .result import _KeyIndexType
     from .result import _KeyMapRecType
@@ -61,6 +66,7 @@ if typing.TYPE_CHECKING:
     from .result import _ProcessorsType
     from ..sql.type_api import _ResultProcessorType
 
+
 _T = TypeVar("_T", bound=Any)
 
 # metadata entry tuple indexes.
@@ -235,7 +241,7 @@ class CursorResultMetaData(ResultMetaData):
             ) = context.result_column_struct
             num_ctx_cols = len(result_columns)
         else:
-            result_columns = (
+            result_columns = (  # type: ignore
                 cols_are_ordered
             ) = (
                 num_ctx_cols
@@ -776,25 +782,53 @@ class ResultFetchStrategy:
 
     alternate_cursor_description: Optional[_DBAPICursorDescription] = None
 
-    def soft_close(self, result, dbapi_cursor):
+    def soft_close(
+        self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+    ) -> None:
         raise NotImplementedError()
 
-    def hard_close(self, result, dbapi_cursor):
+    def hard_close(
+        self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+    ) -> None:
         raise NotImplementedError()
 
-    def yield_per(self, result, dbapi_cursor, num):
+    def yield_per(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: Optional[DBAPICursor],
+        num: int,
+    ) -> None:
         return
 
-    def fetchone(self, result, dbapi_cursor, hard_close=False):
+    def fetchone(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: DBAPICursor,
+        hard_close: bool = False,
+    ) -> Any:
         raise NotImplementedError()
 
-    def fetchmany(self, result, dbapi_cursor, size=None):
+    def fetchmany(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: DBAPICursor,
+        size: Optional[int] = None,
+    ) -> Any:
         raise NotImplementedError()
 
-    def fetchall(self, result):
+    def fetchall(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: DBAPICursor,
+    ) -> Any:
         raise NotImplementedError()
 
-    def handle_exception(self, result, dbapi_cursor, err):
+    def handle_exception(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: Optional[DBAPICursor],
+        err: BaseException,
+    ) -> NoReturn:
         raise err
 
 
@@ -882,18 +916,32 @@ class CursorFetchStrategy(ResultFetchStrategy):
 
     __slots__ = ()
 
-    def soft_close(self, result, dbapi_cursor):
+    def soft_close(
+        self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+    ) -> None:
         result.cursor_strategy = _NO_CURSOR_DQL
 
-    def hard_close(self, result, dbapi_cursor):
+    def hard_close(
+        self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+    ) -> None:
         result.cursor_strategy = _NO_CURSOR_DQL
 
-    def handle_exception(self, result, dbapi_cursor, err):
+    def handle_exception(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: Optional[DBAPICursor],
+        err: BaseException,
+    ) -> NoReturn:
         result.connection._handle_dbapi_exception(
             err, None, None, dbapi_cursor, result.context
         )
 
-    def yield_per(self, result, dbapi_cursor, num):
+    def yield_per(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: Optional[DBAPICursor],
+        num: int,
+    ) -> None:
         result.cursor_strategy = BufferedRowCursorFetchStrategy(
             dbapi_cursor,
             {"max_row_buffer": num},
@@ -901,7 +949,12 @@ class CursorFetchStrategy(ResultFetchStrategy):
             growth_factor=0,
         )
 
-    def fetchone(self, result, dbapi_cursor, hard_close=False):
+    def fetchone(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: DBAPICursor,
+        hard_close: bool = False,
+    ) -> Any:
         try:
             row = dbapi_cursor.fetchone()
             if row is None:
@@ -910,7 +963,12 @@ class CursorFetchStrategy(ResultFetchStrategy):
         except BaseException as e:
             self.handle_exception(result, dbapi_cursor, e)
 
-    def fetchmany(self, result, dbapi_cursor, size=None):
+    def fetchmany(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: DBAPICursor,
+        size: Optional[int] = None,
+    ) -> Any:
         try:
             if size is None:
                 l = dbapi_cursor.fetchmany()
@@ -923,7 +981,11 @@ class CursorFetchStrategy(ResultFetchStrategy):
         except BaseException as e:
             self.handle_exception(result, dbapi_cursor, e)
 
-    def fetchall(self, result, dbapi_cursor):
+    def fetchall(
+        self,
+        result: CursorResult[Any],
+        dbapi_cursor: DBAPICursor,
+    ) -> Any:
         try:
             rows = dbapi_cursor.fetchall()
             result._soft_close()
@@ -1163,6 +1225,9 @@ class _NoResultMetaData(ResultMetaData):
 _NO_RESULT_METADATA = _NoResultMetaData()
 
 
+SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
+
+
 class CursorResult(Result[_T]):
     """A Result that is representing state from a DBAPI cursor.
 
@@ -1199,7 +1264,17 @@ class CursorResult(Result[_T]):
     closed: bool = False
     _is_cursor = True
 
-    def __init__(self, context, cursor_strategy, cursor_description):
+    context: DefaultExecutionContext
+    dialect: Dialect
+    cursor_strategy: ResultFetchStrategy
+    connection: Connection
+
+    def __init__(
+        self,
+        context: DefaultExecutionContext,
+        cursor_strategy: ResultFetchStrategy,
+        cursor_description: Optional[_DBAPICursorDescription],
+    ):
         self.context = context
         self.dialect = context.dialect
         self.cursor = context.cursor
@@ -1333,7 +1408,7 @@ class CursorResult(Result[_T]):
 
         if not self._soft_closed:
             cursor = self.cursor
-            self.cursor = None
+            self.cursor = None  # type: ignore
             self.connection._safe_close_cursor(cursor)
             self._soft_closed = True
 
@@ -1605,7 +1680,7 @@ class CursorResult(Result[_T]):
         return self.dialect.supports_sane_multi_rowcount
 
     @util.memoized_property
-    def rowcount(self):
+    def rowcount(self) -> int:
         """Return the 'rowcount' for this result.
 
         The 'rowcount' reports the number of rows *matched*
@@ -1655,6 +1730,7 @@ class CursorResult(Result[_T]):
             return self.context.rowcount
         except BaseException as e:
             self.cursor_strategy.handle_exception(self, self.cursor, e)
+            raise  # not called
 
     @property
     def lastrowid(self):
@@ -1749,7 +1825,7 @@ class CursorResult(Result[_T]):
             )
         return merged_result
 
-    def close(self):
+    def close(self) -> Any:
         """Close this :class:`_engine.CursorResult`.
 
         This closes out the underlying DBAPI cursor corresponding to the
@@ -1772,7 +1848,7 @@ class CursorResult(Result[_T]):
         self._soft_close(hard=True)
 
     @_generative
-    def yield_per(self, num):
+    def yield_per(self: SelfCursorResult, num: int) -> SelfCursorResult:
         self._yield_per = num
         self.cursor_strategy.yield_per(self, self.cursor, num)
         return self
index 6094ad0fbb6f2f9d0fc1ec13ef7f10e5f69e4b24..fc114efa3ae8e984924856dd0ea38300a3d42904 100644 (file)
@@ -64,6 +64,7 @@ if typing.TYPE_CHECKING:
     from .base import Engine
     from .interfaces import _CoreMultiExecuteParams
     from .interfaces import _CoreSingleExecuteParams
+    from .interfaces import _DBAPICursorDescription
     from .interfaces import _DBAPIMultiExecuteParams
     from .interfaces import _ExecuteOptions
     from .interfaces import _IsolationLevel
@@ -1285,8 +1286,8 @@ class DefaultExecutionContext(ExecutionContext):
     def handle_dbapi_exception(self, e):
         pass
 
-    @property
-    def rowcount(self):
+    @util.non_memoized_property
+    def rowcount(self) -> int:
         return self.cursor.rowcount
 
     def supports_sane_rowcount(self):
@@ -1304,7 +1305,7 @@ class DefaultExecutionContext(ExecutionContext):
                 strategy = _cursor.BufferedRowCursorFetchStrategy(
                     self.cursor, self.execution_options
                 )
-            cursor_description = (
+            cursor_description: _DBAPICursorDescription = (
                 strategy.alternate_cursor_description
                 or self.cursor.description
             )
index 6410246039fd157f074cf8b1ee1188621c6ea9bc..e5414b70f3c62fb25e221de7cfe1b8a482081f72 100644 (file)
@@ -133,17 +133,7 @@ class DBAPICursor(Protocol):
     @property
     def description(
         self,
-    ) -> Sequence[
-        Tuple[
-            str,
-            "DBAPIType",
-            Optional[int],
-            Optional[int],
-            Optional[int],
-            Optional[int],
-            Optional[bool],
-        ]
-    ]:
+    ) -> _DBAPICursorDescription:
         """The description attribute of the Cursor.
 
         .. seealso::
@@ -217,7 +207,15 @@ _DBAPIMultiExecuteParams = Union[
 _DBAPIAnyExecuteParams = Union[
     _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams
 ]
-_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any]
+_DBAPICursorDescription = Tuple[
+    str,
+    "DBAPIType",
+    Optional[int],
+    Optional[int],
+    Optional[int],
+    Optional[int],
+    Optional[bool],
+]
 
 _AnySingleExecuteParams = _DBAPISingleExecuteParams
 _AnyMultiExecuteParams = _DBAPIMultiExecuteParams
@@ -2297,6 +2295,9 @@ class ExecutionContext:
 
     """
 
+    engine: Engine
+    """engine which the Connection is associated with"""
+
     connection: Connection
     """Connection object which can be freely used by default value
       generators to execute SQL.  This Connection should reference the
index 420ba5c8c354d6c20f2b0325bd458368d0ccb578..7db95eac9b385d35bb5db104c9de2077eeec10bd 100644 (file)
@@ -53,7 +53,6 @@ from ..orm import ORMDescriptor
 from ..orm.base import SQLORMOperations
 from ..sql import operators
 from ..sql import or_
-from ..sql.elements import SQLCoreOperations
 from ..util.typing import Literal
 from ..util.typing import Protocol
 from ..util.typing import Self
@@ -64,8 +63,10 @@ if typing.TYPE_CHECKING:
     from ..orm.interfaces import MapperProperty
     from ..orm.interfaces import PropComparator
     from ..orm.mapper import Mapper
+    from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _InfoType
 
+
 _T = TypeVar("_T", bound=Any)
 _T_co = TypeVar("_T_co", bound=Any, covariant=True)
 _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
@@ -631,7 +632,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
 
     @property
     def _comparator(self) -> PropComparator[Any]:
-        return self._get_property().comparator
+        return getattr(  # type: ignore
+            self.owning_class, self.target_collection
+        ).comparator
 
     def __clause_element__(self) -> NoReturn:
         raise NotImplementedError(
@@ -957,7 +960,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
         proxy.setter = setter
 
     def _criterion_exists(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
     ) -> ColumnElement[bool]:
         is_has = kwargs.pop("is_has", None)
 
@@ -969,8 +974,8 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
             return self._comparator._criterion_exists(inner)
 
         if self._target_is_object:
-            prop = getattr(self.target_class, self.value_attr)
-            value_expr = prop._criterion_exists(criterion, **kwargs)
+            attr = getattr(self.target_class, self.value_attr)
+            value_expr = attr.comparator._criterion_exists(criterion, **kwargs)
         else:
             if kwargs:
                 raise exc.ArgumentError(
@@ -988,8 +993,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
         return self._comparator._criterion_exists(value_expr)
 
     def any(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
-    ) -> SQLCoreOperations[Any]:
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
+    ) -> ColumnElement[bool]:
         """Produce a proxied 'any' expression using EXISTS.
 
         This expression will be a composed product
@@ -1010,8 +1017,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
         )
 
     def has(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
-    ) -> SQLCoreOperations[Any]:
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
+    ) -> ColumnElement[bool]:
         """Produce a proxied 'has' expression using EXISTS.
 
         This expression will be a composed product
@@ -1069,12 +1078,16 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]):
         self._ambiguous()
 
     def any(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
     ) -> NoReturn:
         self._ambiguous()
 
     def has(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
     ) -> NoReturn:
         self._ambiguous()
 
index 9faf2ed51fa169f6c6a8826d318b9351d3c5ba79..22fa83c58f37330804c2237ab40b47c042f96e04 100644 (file)
@@ -8,7 +8,10 @@
 
 
 """Public API functions and helpers for declarative."""
+from __future__ import annotations
 
+from typing import Callable
+from typing import TYPE_CHECKING
 
 from ... import inspection
 from ...orm import exc as orm_exc
@@ -20,6 +23,10 @@ from ...orm.util import polymorphic_union
 from ...schema import Table
 from ...util import OrderedDict
 
+if TYPE_CHECKING:
+    from ...engine.reflection import Inspector
+    from ...sql.schema import MetaData
+
 
 class ConcreteBase:
     """A helper class for 'concrete' declarative mappings.
@@ -380,31 +387,36 @@ class DeferredReflection:
                 mapper = thingy.cls.__mapper__
                 metadata = mapper.class_.metadata
                 for rel in mapper._props.values():
+
                     if (
                         isinstance(rel, relationships.Relationship)
-                        and rel.secondary is not None
+                        and rel._init_args.secondary._is_populated()
                     ):
-                        if isinstance(rel.secondary, Table):
-                            cls._reflect_table(rel.secondary, insp)
-                        elif isinstance(rel.secondary, str):
+
+                        secondary_arg = rel._init_args.secondary
+
+                        if isinstance(secondary_arg.argument, Table):
+                            cls._reflect_table(secondary_arg.argument, insp)
+                        elif isinstance(secondary_arg.argument, str):
 
                             _, resolve_arg = _resolver(rel.parent.class_, rel)
 
-                            rel.secondary = resolve_arg(rel.secondary)
-                            rel.secondary._resolvers += (
+                            resolver = resolve_arg(
+                                secondary_arg.argument, True
+                            )
+                            resolver._resolvers += (
                                 cls._sa_deferred_table_resolver(
                                     insp, metadata
                                 ),
                             )
 
-                            # controversy!  do we resolve it here? or leave
-                            # it deferred?   I think doing it here is necessary
-                            # so the connection does not leak.
-                            rel.secondary = rel.secondary()
+                            secondary_arg.argument = resolver()
 
     @classmethod
-    def _sa_deferred_table_resolver(cls, inspector, metadata):
-        def _resolve(key):
+    def _sa_deferred_table_resolver(
+        cls, inspector: Inspector, metadata: MetaData
+    ) -> Callable[[str], Table]:
+        def _resolve(key: str) -> Table:
             t1 = Table(key, metadata)
             cls._reflect_table(t1, inspector)
             return t1
index ea558495b41b8a769fa345aad31c6541597d1fac..accfa8949ce5bc7909ec87826794d0df8834e0bf 100644 (file)
@@ -1305,8 +1305,8 @@ class Comparator(interfaces.PropComparator[_T]):
         return ret_expr
 
     @util.non_memoized_property
-    def property(self) -> Optional[interfaces.MapperProperty[_T]]:
-        return None
+    def property(self) -> interfaces.MapperProperty[_T]:
+        raise NotImplementedError()
 
     def adapt_to_entity(
         self, adapt_to_entity: AliasedInsp[Any]
@@ -1344,7 +1344,7 @@ class ExprComparator(Comparator[_T]):
             return [(self.expression, value)]
 
     @util.non_memoized_property
-    def property(self) -> Optional[MapperProperty[_T]]:
+    def property(self) -> MapperProperty[_T]:
         # this accessor is not normally used, however is accessed by things
         # like ORM synonyms if the hybrid is used in this context; the
         # .property attribute is not necessarily accessible
index 560db9817b064753c4256464514db689484fcbce..18a18bd80075f9735b719b49e2d01ce0fdda94d2 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
 
 from __future__ import annotations
 
@@ -12,6 +11,8 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import Collection
+from typing import Iterable
+from typing import NoReturn
 from typing import Optional
 from typing import overload
 from typing import Type
@@ -45,6 +46,7 @@ from ..util.typing import Literal
 if TYPE_CHECKING:
     from ._typing import _EntityType
     from ._typing import _ORMColumnExprArgument
+    from .descriptor_props import _CC
     from .descriptor_props import _CompositeAttrType
     from .interfaces import PropComparator
     from .mapper import Mapper
@@ -54,14 +56,19 @@ if TYPE_CHECKING:
     from .relationships import _ORMColCollectionArgument
     from .relationships import _ORMOrderByArgument
     from .relationships import _RelationshipJoinConditionArgument
+    from .session import _SessionBind
     from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _FromClauseArgument
     from ..sql._typing import _InfoType
+    from ..sql._typing import _OnClauseArgument
     from ..sql._typing import _TypeEngineArgument
+    from ..sql.elements import ColumnElement
     from ..sql.schema import _ServerDefaultType
     from ..sql.schema import FetchedValue
     from ..sql.selectable import Alias
     from ..sql.selectable import Subquery
 
+
 _T = typing.TypeVar("_T")
 
 
@@ -424,10 +431,10 @@ def column_property(
 
 @overload
 def composite(
-    class_: Type[_T],
+    class_: Type[_CC],
     *attrs: _CompositeAttrType[Any],
     **kwargs: Any,
-) -> Composite[_T]:
+) -> Composite[_CC]:
     ...
 
 
@@ -680,7 +687,7 @@ def with_loader_criteria(
 
 def relationship(
     argument: Optional[_RelationshipArgumentType[Any]] = None,
-    secondary: Optional[FromClause] = None,
+    secondary: Optional[Union[FromClause, str]] = None,
     *,
     uselist: Optional[bool] = None,
     collection_class: Optional[
@@ -696,14 +703,14 @@ def relationship(
     cascade: str = "save-update, merge",
     viewonly: bool = False,
     lazy: _LazyLoadArgumentType = "select",
-    passive_deletes: bool = False,
+    passive_deletes: Union[Literal["all"], bool] = False,
     passive_updates: bool = True,
     active_history: bool = False,
     enable_typechecks: bool = True,
     foreign_keys: Optional[_ORMColCollectionArgument] = None,
     remote_side: Optional[_ORMColCollectionArgument] = None,
     join_depth: Optional[int] = None,
-    comparator_factory: Optional[Type[PropComparator[Any]]] = None,
+    comparator_factory: Optional[Type[Relationship.Comparator[Any]]] = None,
     single_parent: bool = False,
     innerjoin: bool = False,
     distinct_target_key: Optional[bool] = None,
@@ -1660,10 +1667,19 @@ def synonym(
         than can be achieved with synonyms.
 
     """
-    return Synonym(name, map_column, descriptor, comparator_factory, doc, info)
+    return Synonym(
+        name,
+        map_column=map_column,
+        descriptor=descriptor,
+        comparator_factory=comparator_factory,
+        doc=doc,
+        info=info,
+    )
 
 
-def create_session(bind=None, **kwargs):
+def create_session(
+    bind: Optional[_SessionBind] = None, **kwargs: Any
+) -> Session:
     r"""Create a new :class:`.Session`
     with no automation enabled by default.
 
@@ -1699,7 +1715,7 @@ def create_session(bind=None, **kwargs):
     return Session(bind=bind, **kwargs)
 
 
-def _mapper_fn(*arg, **kw):
+def _mapper_fn(*arg: Any, **kw: Any) -> NoReturn:
     """Placeholder for the now-removed ``mapper()`` function.
 
     Classical mappings should be performed using the
@@ -1726,7 +1742,9 @@ def _mapper_fn(*arg, **kw):
     )
 
 
-def dynamic_loader(argument, **kw):
+def dynamic_loader(
+    argument: Optional[_RelationshipArgumentType[Any]] = None, **kw: Any
+) -> Relationship[Any]:
     """Construct a dynamically-loading mapper property.
 
     This is essentially the same as
@@ -1746,7 +1764,7 @@ def dynamic_loader(argument, **kw):
     return relationship(argument, **kw)
 
 
-def backref(name, **kwargs):
+def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
     """Create a back reference with explicit keyword arguments, which are the
     same arguments one can send to :func:`relationship`.
 
@@ -1765,7 +1783,11 @@ def backref(name, **kwargs):
     return (name, kwargs)
 
 
-def deferred(*columns, **kw):
+def deferred(
+    column: _ORMColumnExprArgument[_T],
+    *additional_columns: _ORMColumnExprArgument[Any],
+    **kw: Any,
+) -> ColumnProperty[_T]:
     r"""Indicate a column-based mapped attribute that by default will
     not load unless accessed.
 
@@ -1791,7 +1813,8 @@ def deferred(*columns, **kw):
         :ref:`deferred`
 
     """
-    return ColumnProperty(deferred=True, *columns, **kw)
+    kw["deferred"] = True
+    return ColumnProperty(column, *additional_columns, **kw)
 
 
 def query_expression(
@@ -1824,7 +1847,7 @@ def query_expression(
     return prop
 
 
-def clear_mappers():
+def clear_mappers() -> None:
     """Remove all mappers from all classes.
 
     .. versionchanged:: 1.4  This function now locates all
@@ -2003,16 +2026,16 @@ def aliased(
 
 
 def with_polymorphic(
-    base,
-    classes,
-    selectable=False,
-    flat=False,
-    polymorphic_on=None,
-    aliased=False,
-    adapt_on_names=False,
-    innerjoin=False,
-    _use_mapper_path=False,
-):
+    base: Union[_O, Mapper[_O]],
+    classes: Iterable[Type[Any]],
+    selectable: Union[Literal[False, None], FromClause] = False,
+    flat: bool = False,
+    polymorphic_on: Optional[ColumnElement[Any]] = None,
+    aliased: bool = False,
+    innerjoin: bool = False,
+    adapt_on_names: bool = False,
+    _use_mapper_path: bool = False,
+) -> AliasedClass[_O]:
     """Produce an :class:`.AliasedClass` construct which specifies
     columns for descendant mappers of the given base.
 
@@ -2096,7 +2119,13 @@ def with_polymorphic(
     )
 
 
-def join(left, right, onclause=None, isouter=False, full=False):
+def join(
+    left: _FromClauseArgument,
+    right: _FromClauseArgument,
+    onclause: Optional[_OnClauseArgument] = None,
+    isouter: bool = False,
+    full: bool = False,
+) -> _ORMJoin:
     r"""Produce an inner join between left and right clauses.
 
     :func:`_orm.join` is an extension to the core join interface
@@ -2135,7 +2164,12 @@ def join(left, right, onclause=None, isouter=False, full=False):
     return _ORMJoin(left, right, onclause, isouter, full)
 
 
-def outerjoin(left, right, onclause=None, full=False):
+def outerjoin(
+    left: _FromClauseArgument,
+    right: _FromClauseArgument,
+    onclause: Optional[_OnClauseArgument] = None,
+    full: bool = False,
+) -> _ORMJoin:
     """Produce a left outer join between left and right clauses.
 
     This is the "outer join" version of the :func:`_orm.join` function,
index 29d82340aba3fe0225712a4f3096fcd4e4f6f750..0e624afe2a84a26b1afc9aa5c3a6145989f1e329 100644 (file)
@@ -2,8 +2,8 @@ from __future__ import annotations
 
 import operator
 from typing import Any
-from typing import Callable
 from typing import Dict
+from typing import Mapping
 from typing import Optional
 from typing import Tuple
 from typing import Type
@@ -20,9 +20,12 @@ from ..util.typing import TypeGuard
 if TYPE_CHECKING:
     from .attributes import AttributeImpl
     from .attributes import CollectionAttributeImpl
+    from .attributes import HasCollectionAdapter
+    from .attributes import QueryableAttribute
     from .base import PassiveFlag
     from .decl_api import registry as _registry_type
     from .descriptor_props import _CompositeClassProto
+    from .interfaces import InspectionAttr
     from .interfaces import MapperProperty
     from .interfaces import UserDefinedOption
     from .mapper import Mapper
@@ -30,11 +33,14 @@ if TYPE_CHECKING:
     from .state import InstanceState
     from .util import AliasedClass
     from .util import AliasedInsp
+    from ..sql._typing import _CE
     from ..sql.base import ExecutableOption
 
 _T = TypeVar("_T", bound=Any)
 
 
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+
 # I would have preferred this were bound=object however it seems
 # to not travel in all situations when defined in that way.
 _O = TypeVar("_O", bound=Any)
@@ -42,6 +48,12 @@ _O = TypeVar("_O", bound=Any)
 
 """
 
+_OO = TypeVar("_OO", bound=object)
+"""The 'ORM mapped object, that's definitely object' type.
+
+"""
+
+
 if TYPE_CHECKING:
     _RegistryType = _registry_type
 
@@ -54,6 +66,7 @@ _EntityType = Union[
 ]
 
 
+_ClassDict = Mapping[str, Any]
 _InstanceDict = Dict[str, Any]
 
 _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
@@ -64,10 +77,19 @@ _ORMColumnExprArgument = Union[
     roles.ExpressionElementRole[_T],
 ]
 
-# somehow Protocol didn't want to work for this one
-_ORMAdapterProto = Callable[
-    [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T]
-]
+
+_ORMCOLEXPR = TypeVar("_ORMCOLEXPR", bound=ColumnElement[Any])
+
+
+class _ORMAdapterProto(Protocol):
+    """protocol for the :class:`.AliasedInsp._orm_adapt_element` method
+    which is a synonym for :class:`.AliasedInsp._adapt_element`.
+
+
+    """
+
+    def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE:
+        ...
 
 
 class _LoaderCallable(Protocol):
@@ -96,6 +118,16 @@ if TYPE_CHECKING:
     def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]:
         ...
 
+    def insp_is_attribute(
+        obj: InspectionAttr,
+    ) -> TypeGuard[QueryableAttribute[Any]]:
+        ...
+
+    def attr_is_internal_proxy(
+        obj: InspectionAttr,
+    ) -> TypeGuard[QueryableAttribute[Any]]:
+        ...
+
     def prop_is_relationship(
         prop: MapperProperty[Any],
     ) -> TypeGuard[Relationship[Any]]:
@@ -106,9 +138,19 @@ if TYPE_CHECKING:
     ) -> TypeGuard[CollectionAttributeImpl]:
         ...
 
+    def is_has_collection_adapter(
+        impl: AttributeImpl,
+    ) -> TypeGuard[HasCollectionAdapter]:
+        ...
+
 else:
     insp_is_mapper_property = operator.attrgetter("is_property")
     insp_is_mapper = operator.attrgetter("is_mapper")
     insp_is_aliased_class = operator.attrgetter("is_aliased_class")
+    insp_is_attribute = operator.attrgetter("is_attribute")
+    attr_is_internal_proxy = operator.attrgetter("_is_internal_proxy")
     is_collection_impl = operator.attrgetter("collection")
     prop_is_relationship = operator.attrgetter("_is_relationship")
+    is_has_collection_adapter = operator.attrgetter(
+        "_is_has_collection_adapter"
+    )
index 9aeaeaa2726d4987d7d394eeb28297b930ebf2fa..b5faa7cbf13cddb4f7142adf2b89a735ac3eedcb 100644 (file)
@@ -117,7 +117,9 @@ class NoKey(str):
     pass
 
 
-_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]]
+_AllPendingType = Sequence[
+    Tuple[Optional["InstanceState[Any]"], Optional[object]]
+]
 
 NO_KEY = NoKey("no name")
 
@@ -798,6 +800,8 @@ class AttributeImpl:
     supports_population: bool
     dynamic: bool
 
+    _is_has_collection_adapter = False
+
     _replace_token: AttributeEventToken
     _remove_token: AttributeEventToken
     _append_token: AttributeEventToken
@@ -1140,7 +1144,7 @@ class AttributeImpl:
         state: InstanceState[Any],
         dict_: _InstanceDict,
         value: Any,
-        initiator: Optional[AttributeEventToken],
+        initiator: Optional[AttributeEventToken] = None,
         passive: PassiveFlag = PASSIVE_OFF,
         check_old: Any = None,
         pop: bool = False,
@@ -1236,7 +1240,7 @@ class ScalarAttributeImpl(AttributeImpl):
         state: InstanceState[Any],
         dict_: Dict[str, Any],
         value: Any,
-        initiator: Optional[AttributeEventToken],
+        initiator: Optional[AttributeEventToken] = None,
         passive: PassiveFlag = PASSIVE_OFF,
         check_old: Optional[object] = None,
         pop: bool = False,
@@ -1402,7 +1406,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         state: InstanceState[Any],
         dict_: _InstanceDict,
         value: Any,
-        initiator: Optional[AttributeEventToken],
+        initiator: Optional[AttributeEventToken] = None,
         passive: PassiveFlag = PASSIVE_OFF,
         check_old: Any = None,
         pop: bool = False,
@@ -1494,6 +1498,9 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 class HasCollectionAdapter:
     __slots__ = ()
 
+    collection: bool
+    _is_has_collection_adapter = True
+
     def _dispose_previous_collection(
         self,
         state: InstanceState[Any],
@@ -1508,7 +1515,7 @@ class HasCollectionAdapter:
         self,
         state: InstanceState[Any],
         dict_: _InstanceDict,
-        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        user_data: Literal[None] = ...,
         passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
     ) -> CollectionAdapter:
         ...
@@ -1518,8 +1525,18 @@ class HasCollectionAdapter:
         self,
         state: InstanceState[Any],
         dict_: _InstanceDict,
-        user_data: Optional[_AdaptedCollectionProtocol] = None,
-        passive: PassiveFlag = PASSIVE_OFF,
+        user_data: _AdaptedCollectionProtocol = ...,
+        passive: PassiveFlag = ...,
+    ) -> CollectionAdapter:
+        ...
+
+    @overload
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = ...,
+        passive: PassiveFlag = ...,
     ) -> Union[
         Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
     ]:
@@ -1530,12 +1547,25 @@ class HasCollectionAdapter:
         state: InstanceState[Any],
         dict_: _InstanceDict,
         user_data: Optional[_AdaptedCollectionProtocol] = None,
-        passive: PassiveFlag = PASSIVE_OFF,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
     ) -> Union[
         Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
     ]:
         raise NotImplementedError()
 
+    def set(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken] = None,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+        check_old: Any = None,
+        pop: bool = False,
+        _adapt: bool = True,
+    ) -> None:
+        raise NotImplementedError()
+
 
 if TYPE_CHECKING:
 
@@ -1790,7 +1820,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         initiator: Optional[AttributeEventToken],
         passive: PassiveFlag = PASSIVE_OFF,
     ) -> None:
-        collection = self.get_collection(state, dict_, passive=passive)
+        collection = self.get_collection(
+            state, dict_, user_data=None, passive=passive
+        )
         if collection is PASSIVE_NO_RESULT:
             value = self.fire_append_event(state, dict_, value, initiator)
             assert (
@@ -1810,7 +1842,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         initiator: Optional[AttributeEventToken],
         passive: PassiveFlag = PASSIVE_OFF,
     ) -> None:
-        collection = self.get_collection(state, state.dict, passive=passive)
+        collection = self.get_collection(
+            state, state.dict, user_data=None, passive=passive
+        )
         if collection is PASSIVE_NO_RESULT:
             self.fire_remove_event(state, dict_, value, initiator)
             assert (
@@ -1844,7 +1878,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         dict_: _InstanceDict,
         value: Any,
         initiator: Optional[AttributeEventToken] = None,
-        passive: PassiveFlag = PASSIVE_OFF,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
         check_old: Any = None,
         pop: bool = False,
         _adapt: bool = True,
@@ -1963,7 +1997,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         self,
         state: InstanceState[Any],
         dict_: _InstanceDict,
-        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        user_data: Literal[None] = ...,
         passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
     ) -> CollectionAdapter:
         ...
@@ -1973,7 +2007,17 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
         self,
         state: InstanceState[Any],
         dict_: _InstanceDict,
-        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        user_data: _AdaptedCollectionProtocol = ...,
+        passive: PassiveFlag = ...,
+    ) -> CollectionAdapter:
+        ...
+
+    @overload
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = ...,
         passive: PassiveFlag = PASSIVE_OFF,
     ) -> Union[
         Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
@@ -2490,7 +2534,7 @@ def register_attribute_impl(
     impl_class: Optional[Type[AttributeImpl]] = None,
     backref: Optional[str] = None,
     **kw: Any,
-) -> InstrumentedAttribute[Any]:
+) -> QueryableAttribute[Any]:
 
     manager = manager_of_class(class_)
     if uselist:
@@ -2599,7 +2643,7 @@ def init_state_collection(
         attr._dispose_previous_collection(state, old, old_collection, False)
 
     user_data = attr._default_value(state, dict_)
-    adapter = attr.get_collection(state, dict_, user_data)
+    adapter: CollectionAdapter = attr.get_collection(state, dict_, user_data)
     adapter._reset_empty()
 
     return adapter
index 0ace9b1cb625ac933085f9903f50d53c571ad2c9..63f873fd0edd1e3d91c98082702f034205c68182 100644 (file)
@@ -18,6 +18,7 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Generic
+from typing import no_type_check
 from typing import Optional
 from typing import overload
 from typing import Type
@@ -35,17 +36,20 @@ from ..sql.elements import SQLCoreOperations
 from ..util import FastIntFlag
 from ..util.langhelpers import TypingOnly
 from ..util.typing import Literal
-from ..util.typing import Self
 
 if typing.TYPE_CHECKING:
+    from ._typing import _EntityType
     from ._typing import _ExternalEntityType
     from ._typing import _InternalEntityType
     from .attributes import InstrumentedAttribute
     from .instrumentation import ClassManager
+    from .interfaces import PropComparator
     from .mapper import Mapper
     from .state import InstanceState
     from .util import AliasedClass
+    from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _InfoType
+    from ..sql.elements import ColumnElement
 
 _T = TypeVar("_T", bound=Any)
 
@@ -191,35 +195,34 @@ EXT_CONTINUE = util.symbol("EXT_CONTINUE")
 EXT_STOP = util.symbol("EXT_STOP")
 EXT_SKIP = util.symbol("EXT_SKIP")
 
-ONETOMANY = util.symbol(
-    "ONETOMANY",
+
+class RelationshipDirection(Enum):
+    ONETOMANY = 1
     """Indicates the one-to-many direction for a :func:`_orm.relationship`.
 
     This symbol is typically used by the internals but may be exposed within
     certain API features.
 
-    """,
-)
+    """
 
-MANYTOONE = util.symbol(
-    "MANYTOONE",
+    MANYTOONE = 2
     """Indicates the many-to-one direction for a :func:`_orm.relationship`.
 
     This symbol is typically used by the internals but may be exposed within
     certain API features.
 
-    """,
-)
+    """
 
-MANYTOMANY = util.symbol(
-    "MANYTOMANY",
+    MANYTOMANY = 3
     """Indicates the many-to-many direction for a :func:`_orm.relationship`.
 
     This symbol is typically used by the internals but may be exposed within
     certain API features.
 
-    """,
-)
+    """
+
+
+ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection)
 
 
 class InspectionAttrExtensionType(Enum):
@@ -249,7 +252,7 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
 _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
 
 
-_F = TypeVar("_F", bound=Callable)
+_F = TypeVar("_F", bound=Callable[..., Any])
 _Self = TypeVar("_Self")
 
 
@@ -397,29 +400,34 @@ def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]:
         return None
 
 
-def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]:
+def _class_to_mapper(
+    class_or_mapper: Union[Mapper[_T], Type[_T]]
+) -> Mapper[_T]:
+    # can't get mypy to see an overload for this
     insp = inspection.inspect(class_or_mapper, False)
     if insp is not None:
-        return insp.mapper
+        return insp.mapper  # type: ignore
     else:
+        assert isinstance(class_or_mapper, type)
         raise exc.UnmappedClassError(class_or_mapper)
 
 
 def _mapper_or_none(
-    entity: Union[_T, _InternalEntityType[_T]]
+    entity: Union[Type[_T], _InternalEntityType[_T]]
 ) -> Optional[Mapper[_T]]:
     """Return the :class:`_orm.Mapper` for the given class or None if the
     class is not mapped.
     """
 
+    # can't get mypy to see an overload for this
     insp = inspection.inspect(entity, False)
     if insp is not None:
-        return insp.mapper
+        return insp.mapper  # type: ignore
     else:
         return None
 
 
-def _is_mapped_class(entity):
+def _is_mapped_class(entity: Any) -> bool:
     """Return True if the given object is a mapped class,
     :class:`_orm.Mapper`, or :class:`.AliasedClass`.
     """
@@ -432,20 +440,13 @@ def _is_mapped_class(entity):
     )
 
 
-def _orm_columns(entity):
-    insp = inspection.inspect(entity, False)
-    if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
-        return [c for c in insp.selectable.c]
-    else:
-        return [entity]
-
-
-def _is_aliased_class(entity):
+def _is_aliased_class(entity: Any) -> bool:
     insp = inspection.inspect(entity, False)
     return insp is not None and getattr(insp, "is_aliased_class", False)
 
 
-def _entity_descriptor(entity, key):
+@no_type_check
+def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any:
     """Return a class attribute given an entity and string name.
 
     May return :class:`.InstrumentedAttribute` or user-defined
@@ -651,16 +652,26 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
 
     if typing.TYPE_CHECKING:
 
-        def of_type(self, class_):
+        def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
             ...
 
-        def and_(self, *criteria):
+        def and_(
+            self, *criteria: _ColumnExpressionArgument[bool]
+        ) -> PropComparator[bool]:
             ...
 
-        def any(self, criterion=None, **kwargs):  # noqa: A001
+        def any(  # noqa: A001
+            self,
+            criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+            **kwargs: Any,
+        ) -> ColumnElement[bool]:
             ...
 
-        def has(self, criterion=None, **kwargs):
+        def has(
+            self,
+            criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+            **kwargs: Any,
+        ) -> ColumnElement[bool]:
             ...
 
 
@@ -673,7 +684,9 @@ class ORMDescriptor(Generic[_T], TypingOnly):
     if typing.TYPE_CHECKING:
 
         @overload
-        def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self:
+        def __get__(
+            self, instance: Any, owner: Literal[None]
+        ) -> ORMDescriptor[_T]:
             ...
 
         @overload
index 473468c6cde98e030280ad34f7e7ffdf0dada4f8..b3fcd29ea38268a1b49dfc53a582edc44811c98f 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """Routines to handle the string class registry used by declarative.
 
@@ -16,7 +15,22 @@ This system allows specification of classes and expressions used in
 from __future__ import annotations
 
 import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Mapping
 from typing import MutableMapping
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 import weakref
 
@@ -29,6 +43,14 @@ from .. import exc
 from .. import inspection
 from .. import util
 from ..sql.schema import _get_table_key
+from ..util.typing import CallableReference
+
+if TYPE_CHECKING:
+    from .relationships import Relationship
+    from ..sql.schema import MetaData
+    from ..sql.schema import Table
+
+_T = TypeVar("_T", bound=Any)
 
 _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
 
@@ -36,10 +58,12 @@ _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
 # the _decl_class_registry, which is usually weak referencing.
 # the internal registries here link to classes with weakrefs and remove
 # themselves when all references to contained classes are removed.
-_registries = set()
+_registries: Set[ClsRegistryToken] = set()
 
 
-def add_class(classname, cls, decl_class_registry):
+def add_class(
+    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
+) -> None:
     """Add a class to the _decl_class_registry associated with the
     given declarative class.
 
@@ -49,13 +73,15 @@ def add_class(classname, cls, decl_class_registry):
         existing = decl_class_registry[classname]
         if not isinstance(existing, _MultipleClassMarker):
             existing = decl_class_registry[classname] = _MultipleClassMarker(
-                [cls, existing]
+                [cls, cast("Type[Any]", existing)]
             )
     else:
         decl_class_registry[classname] = cls
 
     try:
-        root_module = decl_class_registry["_sa_module_registry"]
+        root_module = cast(
+            _ModuleMarker, decl_class_registry["_sa_module_registry"]
+        )
     except KeyError:
         decl_class_registry[
             "_sa_module_registry"
@@ -79,7 +105,9 @@ def add_class(classname, cls, decl_class_registry):
         module.add_class(classname, cls)
 
 
-def remove_class(classname, cls, decl_class_registry):
+def remove_class(
+    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
+) -> None:
     if classname in decl_class_registry:
         existing = decl_class_registry[classname]
         if isinstance(existing, _MultipleClassMarker):
@@ -88,7 +116,9 @@ def remove_class(classname, cls, decl_class_registry):
             del decl_class_registry[classname]
 
     try:
-        root_module = decl_class_registry["_sa_module_registry"]
+        root_module = cast(
+            _ModuleMarker, decl_class_registry["_sa_module_registry"]
+        )
     except KeyError:
         return
 
@@ -102,7 +132,11 @@ def remove_class(classname, cls, decl_class_registry):
         module.remove_class(classname, cls)
 
 
-def _key_is_empty(key, decl_class_registry, test):
+def _key_is_empty(
+    key: str,
+    decl_class_registry: _ClsRegistryType,
+    test: Callable[[Any], bool],
+) -> bool:
     """test if a key is empty of a certain object.
 
     used for unit tests against the registry to see if garbage collection
@@ -124,6 +158,8 @@ def _key_is_empty(key, decl_class_registry, test):
         for sub_thing in thing.contents:
             if test(sub_thing):
                 return False
+        else:
+            raise NotImplementedError("unknown codepath")
     else:
         return not test(thing)
 
@@ -142,20 +178,27 @@ class _MultipleClassMarker(ClsRegistryToken):
 
     __slots__ = "on_remove", "contents", "__weakref__"
 
-    def __init__(self, classes, on_remove=None):
+    contents: Set[weakref.ref[Type[Any]]]
+    on_remove: CallableReference[Optional[Callable[[], None]]]
+
+    def __init__(
+        self,
+        classes: Iterable[Type[Any]],
+        on_remove: Optional[Callable[[], None]] = None,
+    ):
         self.on_remove = on_remove
         self.contents = set(
             [weakref.ref(item, self._remove_item) for item in classes]
         )
         _registries.add(self)
 
-    def remove_item(self, cls):
+    def remove_item(self, cls: Type[Any]) -> None:
         self._remove_item(weakref.ref(cls))
 
-    def __iter__(self):
+    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
         return (ref() for ref in self.contents)
 
-    def attempt_get(self, path, key):
+    def attempt_get(self, path: List[str], key: str) -> Type[Any]:
         if len(self.contents) > 1:
             raise exc.InvalidRequestError(
                 'Multiple classes found for path "%s" '
@@ -170,14 +213,14 @@ class _MultipleClassMarker(ClsRegistryToken):
                 raise NameError(key)
             return cls
 
-    def _remove_item(self, ref):
+    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
         self.contents.discard(ref)
         if not self.contents:
             _registries.discard(self)
             if self.on_remove:
                 self.on_remove()
 
-    def add_item(self, item):
+    def add_item(self, item: Type[Any]) -> None:
         # protect against class registration race condition against
         # asynchronous garbage collection calling _remove_item,
         # [ticket:3208]
@@ -206,7 +249,12 @@ class _ModuleMarker(ClsRegistryToken):
 
     __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
 
-    def __init__(self, name, parent):
+    parent: Optional[_ModuleMarker]
+    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
+    mod_ns: _ModNS
+    path: List[str]
+
+    def __init__(self, name: str, parent: Optional[_ModuleMarker]):
         self.parent = parent
         self.name = name
         self.contents = {}
@@ -217,51 +265,53 @@ class _ModuleMarker(ClsRegistryToken):
             self.path = []
         _registries.add(self)
 
-    def __contains__(self, name):
+    def __contains__(self, name: str) -> bool:
         return name in self.contents
 
-    def __getitem__(self, name):
+    def __getitem__(self, name: str) -> ClsRegistryToken:
         return self.contents[name]
 
-    def _remove_item(self, name):
+    def _remove_item(self, name: str) -> None:
         self.contents.pop(name, None)
         if not self.contents and self.parent is not None:
             self.parent._remove_item(self.name)
             _registries.discard(self)
 
-    def resolve_attr(self, key):
-        return getattr(self.mod_ns, key)
+    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
+        return self.mod_ns.__getattr__(key)
 
-    def get_module(self, name):
+    def get_module(self, name: str) -> _ModuleMarker:
         if name not in self.contents:
             marker = _ModuleMarker(name, self)
             self.contents[name] = marker
         else:
-            marker = self.contents[name]
+            marker = cast(_ModuleMarker, self.contents[name])
         return marker
 
-    def add_class(self, name, cls):
+    def add_class(self, name: str, cls: Type[Any]) -> None:
         if name in self.contents:
-            existing = self.contents[name]
+            existing = cast(_MultipleClassMarker, self.contents[name])
             existing.add_item(cls)
         else:
             existing = self.contents[name] = _MultipleClassMarker(
                 [cls], on_remove=lambda: self._remove_item(name)
             )
 
-    def remove_class(self, name, cls):
+    def remove_class(self, name: str, cls: Type[Any]) -> None:
         if name in self.contents:
-            existing = self.contents[name]
+            existing = cast(_MultipleClassMarker, self.contents[name])
             existing.remove_item(cls)
 
 
 class _ModNS:
     __slots__ = ("__parent",)
 
-    def __init__(self, parent):
+    __parent: _ModuleMarker
+
+    def __init__(self, parent: _ModuleMarker):
         self.__parent = parent
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
         try:
             value = self.__parent.contents[key]
         except KeyError:
@@ -282,10 +332,12 @@ class _ModNS:
 class _GetColumns:
     __slots__ = ("cls",)
 
-    def __init__(self, cls):
+    cls: Type[Any]
+
+    def __init__(self, cls: Type[Any]):
         self.cls = cls
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         mp = class_mapper(self.cls, configure=False)
         if mp:
             if key not in mp.all_orm_descriptors:
@@ -296,6 +348,7 @@ class _GetColumns:
 
             desc = mp.all_orm_descriptors[key]
             if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
+                assert isinstance(desc, attributes.QueryableAttribute)
                 prop = desc.property
                 if isinstance(prop, Synonym):
                     key = prop.name
@@ -316,15 +369,18 @@ inspection._inspects(_GetColumns)(
 class _GetTable:
     __slots__ = "key", "metadata"
 
-    def __init__(self, key, metadata):
+    key: str
+    metadata: MetaData
+
+    def __init__(self, key: str, metadata: MetaData):
         self.key = key
         self.metadata = metadata
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Table:
         return self.metadata.tables[_get_table_key(key, self.key)]
 
 
-def _determine_container(key, value):
+def _determine_container(key: str, value: Any) -> _GetColumns:
     if isinstance(value, _MultipleClassMarker):
         value = value.attempt_get([], key)
     return _GetColumns(value)
@@ -341,7 +397,21 @@ class _class_resolver:
         "favor_tables",
     )
 
-    def __init__(self, cls, prop, fallback, arg, favor_tables=False):
+    cls: Type[Any]
+    prop: Relationship[Any]
+    fallback: Mapping[str, Any]
+    arg: str
+    favor_tables: bool
+    _resolvers: Tuple[Callable[[str], Any], ...]
+
+    def __init__(
+        self,
+        cls: Type[Any],
+        prop: Relationship[Any],
+        fallback: Mapping[str, Any],
+        arg: str,
+        favor_tables: bool = False,
+    ):
         self.cls = cls
         self.prop = prop
         self.arg = arg
@@ -350,11 +420,12 @@ class _class_resolver:
         self._resolvers = ()
         self.favor_tables = favor_tables
 
-    def _access_cls(self, key):
+    def _access_cls(self, key: str) -> Any:
         cls = self.cls
 
         manager = attributes.manager_of_class(cls)
         decl_base = manager.registry
+        assert decl_base is not None
         decl_class_registry = decl_base._class_registry
         metadata = decl_base.metadata
 
@@ -362,7 +433,7 @@ class _class_resolver:
             if key in metadata.tables:
                 return metadata.tables[key]
             elif key in metadata._schemas:
-                return _GetTable(key, cls.metadata)
+                return _GetTable(key, getattr(cls, "metadata", metadata))
 
         if key in decl_class_registry:
             return _determine_container(key, decl_class_registry[key])
@@ -371,13 +442,14 @@ class _class_resolver:
             if key in metadata.tables:
                 return metadata.tables[key]
             elif key in metadata._schemas:
-                return _GetTable(key, cls.metadata)
+                return _GetTable(key, getattr(cls, "metadata", metadata))
 
-        if (
-            "_sa_module_registry" in decl_class_registry
-            and key in decl_class_registry["_sa_module_registry"]
+        if "_sa_module_registry" in decl_class_registry and key in cast(
+            _ModuleMarker, decl_class_registry["_sa_module_registry"]
         ):
-            registry = decl_class_registry["_sa_module_registry"]
+            registry = cast(
+                _ModuleMarker, decl_class_registry["_sa_module_registry"]
+            )
             return registry.resolve_attr(key)
         elif self._resolvers:
             for resolv in self._resolvers:
@@ -387,7 +459,7 @@ class _class_resolver:
 
         return self.fallback[key]
 
-    def _raise_for_name(self, name, err):
+    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
         generic_match = re.match(r"(.+)\[(.+)\]", name)
 
         if generic_match:
@@ -409,7 +481,7 @@ class _class_resolver:
                 % (self.prop.parent, self.arg, name, self.cls)
             ) from err
 
-    def _resolve_name(self):
+    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
         name = self.arg
         d = self._dict
         rval = None
@@ -427,9 +499,11 @@ class _class_resolver:
             if isinstance(rval, _GetColumns):
                 return rval.cls
             else:
+                if TYPE_CHECKING:
+                    assert isinstance(rval, (type, Table, _ModNS))
                 return rval
 
-    def __call__(self):
+    def __call__(self) -> Any:
         try:
             x = eval(self.arg, globals(), self._dict)
 
@@ -441,10 +515,15 @@ class _class_resolver:
             self._raise_for_name(n.args[0], n)
 
 
-_fallback_dict = None
+_fallback_dict: Mapping[str, Any] = None  # type: ignore
 
 
-def _resolver(cls, prop):
+def _resolver(
+    cls: Type[Any], prop: Relationship[Any]
+) -> Tuple[
+    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+    Callable[[str, bool], _class_resolver],
+]:
 
     global _fallback_dict
 
@@ -456,12 +535,14 @@ def _resolver(cls, prop):
             {"foreign": foreign, "remote": remote}
         )
 
-    def resolve_arg(arg, favor_tables=False):
+    def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
         return _class_resolver(
             cls, prop, _fallback_dict, arg, favor_tables=favor_tables
         )
 
-    def resolve_name(arg):
+    def resolve_name(
+        arg: str,
+    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
         return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
 
     return resolve_name, resolve_arg
index da0da0fcfc011cc5293477941e15f175ea96b0aa..78fe89d05fbad6055ff3e11178fff5cf9d0cc976 100644 (file)
@@ -115,6 +115,7 @@ from typing import Collection
 from typing import Dict
 from typing import Iterable
 from typing import List
+from typing import NoReturn
 from typing import Optional
 from typing import Set
 from typing import Tuple
@@ -130,6 +131,7 @@ from ..util.compat import inspect_getfullargspec
 from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from .attributes import AttributeEventToken
     from .attributes import CollectionAttributeImpl
     from .mapped_collection import attribute_mapped_collection
     from .mapped_collection import column_mapped_collection
@@ -500,7 +502,7 @@ class CollectionAdapter:
         self.invalidated = False
         self.empty = False
 
-    def _warn_invalidated(self):
+    def _warn_invalidated(self) -> None:
         util.warn("This collection has been invalidated.")
 
     @property
@@ -509,7 +511,7 @@ class CollectionAdapter:
         return self._data()
 
     @property
-    def _referenced_by_owner(self):
+    def _referenced_by_owner(self) -> bool:
         """return True if the owner state still refers to this collection.
 
         This will return False within a bulk replace operation,
@@ -521,7 +523,9 @@ class CollectionAdapter:
     def bulk_appender(self):
         return self._data()._sa_appender
 
-    def append_with_event(self, item, initiator=None):
+    def append_with_event(
+        self, item: Any, initiator: Optional[AttributeEventToken] = None
+    ) -> None:
         """Add an entity to the collection, firing mutation events."""
 
         self._data()._sa_appender(item, _sa_initiator=initiator)
@@ -533,7 +537,7 @@ class CollectionAdapter:
         self.empty = True
         self.owner_state._empty_collections[self._key] = user_data
 
-    def _reset_empty(self):
+    def _reset_empty(self) -> None:
         assert (
             self.empty
         ), "This collection adapter is not in the 'empty' state"
@@ -542,20 +546,20 @@ class CollectionAdapter:
             self._key
         ] = self.owner_state._empty_collections.pop(self._key)
 
-    def _refuse_empty(self):
+    def _refuse_empty(self) -> NoReturn:
         raise sa_exc.InvalidRequestError(
             "This is a special 'empty' collection which cannot accommodate "
             "internal mutation operations"
         )
 
-    def append_without_event(self, item):
+    def append_without_event(self, item: Any) -> None:
         """Add or restore an entity to the collection, firing no events."""
 
         if self.empty:
             self._refuse_empty()
         self._data()._sa_appender(item, _sa_initiator=False)
 
-    def append_multiple_without_event(self, items):
+    def append_multiple_without_event(self, items: Iterable[Any]) -> None:
         """Add or restore an entity to the collection, firing no events."""
         if self.empty:
             self._refuse_empty()
@@ -566,17 +570,21 @@ class CollectionAdapter:
     def bulk_remover(self):
         return self._data()._sa_remover
 
-    def remove_with_event(self, item, initiator=None):
+    def remove_with_event(
+        self, item: Any, initiator: Optional[AttributeEventToken] = None
+    ) -> None:
         """Remove an entity from the collection, firing mutation events."""
         self._data()._sa_remover(item, _sa_initiator=initiator)
 
-    def remove_without_event(self, item):
+    def remove_without_event(self, item: Any) -> None:
         """Remove an entity from the collection, firing no events."""
         if self.empty:
             self._refuse_empty()
         self._data()._sa_remover(item, _sa_initiator=False)
 
-    def clear_with_event(self, initiator=None):
+    def clear_with_event(
+        self, initiator: Optional[AttributeEventToken] = None
+    ) -> None:
         """Empty the collection, firing a mutation event for each entity."""
 
         if self.empty:
@@ -585,7 +593,7 @@ class CollectionAdapter:
         for item in list(self):
             remover(item, _sa_initiator=initiator)
 
-    def clear_without_event(self):
+    def clear_without_event(self) -> None:
         """Empty the collection, firing no events."""
 
         if self.empty:
index 28fea2f9b3bb83907400cd6193601e14e7d01f12..58556bb58055b8e41bb3c37e9df05eab8b9494ec 100644 (file)
@@ -12,6 +12,7 @@ import itertools
 from typing import Any
 from typing import cast
 from typing import Dict
+from typing import Iterable
 from typing import List
 from typing import Optional
 from typing import Set
@@ -43,6 +44,7 @@ from ..sql import expression
 from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
+from ..sql._typing import _TP
 from ..sql._typing import is_dml
 from ..sql._typing import is_insert_update
 from ..sql._typing import is_select_base
@@ -55,22 +57,32 @@ from ..sql.base import Options
 from ..sql.dml import UpdateBase
 from ..sql.elements import GroupedElement
 from ..sql.elements import TextClause
-from ..sql.selectable import ExecutableReturnsRows
 from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
 from ..sql.selectable import LABEL_STYLE_NONE
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..sql.selectable import Select
 from ..sql.selectable import SelectLabelStyle
 from ..sql.selectable import SelectState
+from ..sql.selectable import TypedReturnsRows
 from ..sql.visitors import InternalTraversal
 
 if TYPE_CHECKING:
     from ._typing import _InternalEntityType
+    from .loading import PostLoad
     from .mapper import Mapper
     from .query import Query
+    from .session import _BindArguments
+    from .session import Session
+    from ..engine.interfaces import _CoreSingleExecuteParams
+    from ..engine.interfaces import _ExecuteOptionsParameter
+    from ..sql._typing import _ColumnsClauseArgument
+    from ..sql.compiler import SQLCompiler
     from ..sql.dml import _DMLTableElement
     from ..sql.elements import ColumnElement
+    from ..sql.selectable import _JoinTargetElement
     from ..sql.selectable import _LabelConventionCallable
+    from ..sql.selectable import _SetupJoinsElement
+    from ..sql.selectable import ExecutableReturnsRows
     from ..sql.selectable import SelectBase
     from ..sql.type_api import TypeEngine
 
@@ -80,7 +92,7 @@ _path_registry = PathRegistry.root
 _EMPTY_DICT = util.immutabledict()
 
 
-LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM")
+LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
 
 
 class QueryContext:
@@ -109,6 +121,10 @@ class QueryContext:
         "loaders_require_uniquing",
     )
 
+    runid: int
+    post_load_paths: Dict[PathRegistry, PostLoad]
+    compile_state: ORMCompileState
+
     class default_load_options(Options):
         _only_return_tuples = False
         _populate_existing = False
@@ -123,13 +139,16 @@ class QueryContext:
 
     def __init__(
         self,
-        compile_state,
-        statement,
-        params,
-        session,
-        load_options,
-        execution_options=None,
-        bind_arguments=None,
+        compile_state: CompileState,
+        statement: Union[Select[Any], FromStatement[Any]],
+        params: _CoreSingleExecuteParams,
+        session: Session,
+        load_options: Union[
+            Type[QueryContext.default_load_options],
+            QueryContext.default_load_options,
+        ],
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+        bind_arguments: Optional[_BindArguments] = None,
     ):
         self.load_options = load_options
         self.execution_options = execution_options or _EMPTY_DICT
@@ -220,8 +239,8 @@ class ORMCompileState(CompileState):
     attributes: Dict[Any, Any]
     global_attributes: Dict[Any, Any]
 
-    statement: Union[Select, FromStatement]
-    select_statement: Union[Select, FromStatement]
+    statement: Union[Select[Any], FromStatement[Any]]
+    select_statement: Union[Select[Any], FromStatement[Any]]
     _entities: List[_QueryEntity]
     _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
     compile_options: Union[
@@ -238,6 +257,7 @@ class ORMCompileState(CompileState):
         Tuple[Any, ...]
     ]
     current_path: PathRegistry = _path_registry
+    _has_mapper_entities = False
 
     def __init__(self, *arg, **kw):
         raise NotImplementedError()
@@ -266,7 +286,12 @@ class ORMCompileState(CompileState):
             return SelectState._column_naming_convention(label_style)
 
     @classmethod
-    def create_for_statement(cls, statement_container, compiler, **kw):
+    def create_for_statement(
+        cls,
+        statement: Union[Select, FromStatement],
+        compiler: Optional[SQLCompiler],
+        **kw: Any,
+    ) -> ORMCompileState:
         """Create a context for a statement given a :class:`.Compiler`.
 
         This method is always invoked in the context of SQLCompiler.process().
@@ -443,7 +468,12 @@ class ORMFromStatementCompileState(ORMCompileState):
     eager_joins = _EMPTY_DICT
 
     @classmethod
-    def create_for_statement(cls, statement_container, compiler, **kw):
+    def create_for_statement(
+        cls,
+        statement_container: Union[Select, FromStatement],
+        compiler: Optional[SQLCompiler],
+        **kw: Any,
+    ) -> ORMCompileState:
 
         if compiler is not None:
             toplevel = not compiler.stack
@@ -577,7 +607,7 @@ class ORMFromStatementCompileState(ORMCompileState):
         return None
 
 
-class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
+class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
     """Core construct that represents a load of ORM objects from various
     :class:`.ReturnsRows` and other classes including:
 
@@ -595,7 +625,7 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
 
     _for_update_arg = None
 
-    element: Union[SelectBase, TextClause, UpdateBase]
+    element: Union[ExecutableReturnsRows, TextClause]
 
     _traverse_internals = [
         ("_raw_columns", InternalTraversal.dp_clauseelement_list),
@@ -606,7 +636,11 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
         ("_compile_options", InternalTraversal.dp_has_cache_key)
     ]
 
-    def __init__(self, entities, element):
+    def __init__(
+        self,
+        entities: Iterable[_ColumnsClauseArgument[Any]],
+        element: Union[ExecutableReturnsRows, TextClause],
+    ):
         self._raw_columns = [
             coercions.expect(
                 roles.ColumnsClauseRole,
@@ -701,7 +735,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
     _having_criteria = ()
 
     @classmethod
-    def create_for_statement(cls, statement, compiler, **kw):
+    def create_for_statement(
+        cls,
+        statement: Union[Select, FromStatement],
+        compiler: Optional[SQLCompiler],
+        **kw: Any,
+    ) -> ORMCompileState:
         """compiler hook, we arrive here from compiler.visit_select() only."""
 
         self = cls.__new__(cls)
@@ -1073,9 +1112,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         )
 
     @classmethod
-    @util.preload_module("sqlalchemy.orm.query")
     def from_statement(cls, statement, from_statement):
-        query = util.preloaded.orm_query
 
         from_statement = coercions.expect(
             roles.ReturnsRowsRole,
@@ -1083,7 +1120,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             apply_propagate_attrs=statement,
         )
 
-        stmt = query.FromStatement(statement._raw_columns, from_statement)
+        stmt = FromStatement(statement._raw_columns, from_statement)
 
         stmt.__dict__.update(
             _with_options=statement._with_options,
@@ -2114,7 +2151,9 @@ def _column_descriptions(
     return d
 
 
-def _legacy_filter_by_entity_zero(query_or_augmented_select):
+def _legacy_filter_by_entity_zero(
+    query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
     self = query_or_augmented_select
     if self._setup_joins:
         _last_joined_entity = self._last_joined_entity
@@ -2127,7 +2166,9 @@ def _legacy_filter_by_entity_zero(query_or_augmented_select):
     return _entity_from_pre_ent_zero(self)
 
 
-def _entity_from_pre_ent_zero(query_or_augmented_select):
+def _entity_from_pre_ent_zero(
+    query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
     self = query_or_augmented_select
     if not self._raw_columns:
         return None
@@ -2144,13 +2185,19 @@ def _entity_from_pre_ent_zero(query_or_augmented_select):
         return ent
 
 
-def _determine_last_joined_entity(setup_joins, entity_zero=None):
+def _determine_last_joined_entity(
+    setup_joins: Tuple[_SetupJoinsElement, ...],
+    entity_zero: Optional[_InternalEntityType[Any]] = None,
+) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]:
     if not setup_joins:
         return None
 
     (target, onclause, from_, flags) = setup_joins[-1]
 
-    if isinstance(target, interfaces.PropComparator):
+    if isinstance(
+        target,
+        attributes.QueryableAttribute,
+    ):
         return target.entity
     else:
         return target
@@ -2161,6 +2208,8 @@ class _QueryEntity:
 
     __slots__ = ()
 
+    supports_single_entity: bool
+
     _non_hashable_value = False
     _null_column_type = False
     use_id_for_hash = False
@@ -2173,6 +2222,9 @@ class _QueryEntity:
     def setup_compile_state(self, compile_state: ORMCompileState) -> None:
         raise NotImplementedError()
 
+    def row_processor(self, context, result):
+        raise NotImplementedError()
+
     @classmethod
     def to_compile_state(
         cls, compile_state, entities, entities_collection, is_current_entities
index fbe35f92ab4db375f1d03785c44afd9f1d696d73..1c343b04ce794777006cd08b0bf25a4d87e075c3 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+
 """Public API functions and helpers for declarative."""
 
 from __future__ import annotations
@@ -14,16 +15,21 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import ClassVar
+from typing import Dict
+from typing import FrozenSet
+from typing import Iterator
 from typing import Mapping
 from typing import Optional
+from typing import overload
+from typing import Set
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 import weakref
 
 from . import attributes
 from . import clsregistry
-from . import exc as orm_exc
 from . import instrumentation
 from . import interfaces
 from . import mapperlib
@@ -38,24 +44,40 @@ from .decl_base import _del_attribute
 from .decl_base import _mapper
 from .descriptor_props import Synonym as _orm_synonym
 from .mapper import Mapper
+from .state import InstanceState
 from .. import exc
 from .. import inspection
 from .. import util
 from ..sql.elements import SQLCoreOperations
 from ..sql.schema import MetaData
 from ..sql.selectable import FromClause
-from ..sql.type_api import TypeEngine
 from ..util import hybridmethod
 from ..util import hybridproperty
 from ..util import typing as compat_typing
+from ..util.typing import CallableReference
+from ..util.typing import Literal
 
+if TYPE_CHECKING:
+    from ._typing import _O
+    from ._typing import _RegistryType
+    from .descriptor_props import Synonym
+    from .instrumentation import ClassManager
+    from .interfaces import MapperProperty
+    from ..sql._typing import _TypeEngineArgument
 
 _T = TypeVar("_T", bound=Any)
 
-_TypeAnnotationMapType = Mapping[Type, Union[Type[TypeEngine], TypeEngine]]
+# it's not clear how to have Annotated, Union objects etc. as keys here
+# from a typing perspective so just leave it open ended for now
+_TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"]
+_MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"]
+
+_DeclaredAttrDecorated = Callable[
+    ..., Union[Mapped[_T], SQLCoreOperations[_T]]
+]
 
 
-def has_inherited_table(cls):
+def has_inherited_table(cls: Type[_O]) -> bool:
     """Given a class, return True if any of the classes it inherits from has a
     mapped table, otherwise return False.
 
@@ -75,13 +97,13 @@ def has_inherited_table(cls):
 
 
 class _DynamicAttributesType(type):
-    def __setattr__(cls, key, value):
+    def __setattr__(cls, key: str, value: Any) -> None:
         if "__mapper__" in cls.__dict__:
             _add_attribute(cls, key, value)
         else:
             type.__setattr__(cls, key, value)
 
-    def __delattr__(cls, key):
+    def __delattr__(cls, key: str) -> None:
         if "__mapper__" in cls.__dict__:
             _del_attribute(cls, key)
         else:
@@ -89,7 +111,7 @@ class _DynamicAttributesType(type):
 
 
 class DeclarativeAttributeIntercept(
-    _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"]
+    _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
 ):
     """Metaclass that may be used in conjunction with the
     :class:`_orm.DeclarativeBase` class to support addition of class
@@ -99,10 +121,10 @@ class DeclarativeAttributeIntercept(
 
 
 class DeclarativeMeta(
-    _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"]
+    _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
 ):
     metadata: MetaData
-    registry: "RegistryType"
+    registry: RegistryType
 
     def __init__(
         cls, classname: Any, bases: Any, dict_: Any, **kw: Any
@@ -130,7 +152,9 @@ class DeclarativeMeta(
         type.__init__(cls, classname, bases, dict_)
 
 
-def synonym_for(name, map_column=False):
+def synonym_for(
+    name: str, map_column: bool = False
+) -> Callable[[Callable[..., Any]], Synonym[Any]]:
     """Decorator that produces an :func:`_orm.synonym`
     attribute in conjunction with a Python descriptor.
 
@@ -164,7 +188,7 @@ def synonym_for(name, map_column=False):
 
     """
 
-    def decorate(fn):
+    def decorate(fn: Callable[..., Any]) -> Synonym[Any]:
         return _orm_synonym(name, map_column=map_column, descriptor=fn)
 
     return decorate
@@ -255,16 +279,16 @@ class declared_attr(interfaces._MappedAttribute[_T]):
 
     if typing.TYPE_CHECKING:
 
-        def __set__(self, instance, value):
+        def __set__(self, instance: Any, value: Any) -> None:
             ...
 
-        def __delete__(self, instance: Any):
+        def __delete__(self, instance: Any) -> None:
             ...
 
     def __init__(
         self,
-        fn: Callable[..., Union[Mapped[_T], SQLCoreOperations[_T]]],
-        cascading=False,
+        fn: _DeclaredAttrDecorated[_T],
+        cascading: bool = False,
     ):
         self.fget = fn
         self._cascading = cascading
@@ -273,10 +297,28 @@ class declared_attr(interfaces._MappedAttribute[_T]):
     def _collect_return_annotation(self) -> Optional[Type[Any]]:
         return util.get_annotations(self.fget).get("return")
 
-    def __get__(self, instance, owner) -> InstrumentedAttribute[_T]:
+    # this is the Mapped[] API where at class descriptor get time we want
+    # the type checker to see InstrumentedAttribute[_T].   However the
+    # callable function prior to mapping in fact calls the given
+    # declarative function that does not return InstrumentedAttribute
+    @overload
+    def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]:
+        ...
+
+    @overload
+    def __get__(self, instance: object, owner: Any) -> _T:
+        ...
+
+    def __get__(
+        self, instance: Optional[object], owner: Any
+    ) -> Union[InstrumentedAttribute[_T], _T]:
         # the declared_attr needs to make use of a cache that exists
         # for the span of the declarative scan_attributes() phase.
         # to achieve this we look at the class manager that's configured.
+
+        # note this method should not be called outside of the declarative
+        # setup phase
+
         cls = owner
         manager = attributes.opt_manager_of_class(cls)
         if manager is None:
@@ -287,30 +329,33 @@ class declared_attr(interfaces._MappedAttribute[_T]):
                     "Unmanaged access of declarative attribute %s from "
                     "non-mapped class %s" % (self.fget.__name__, cls.__name__)
                 )
-            return self.fget(cls)
+            return self.fget(cls)  # type: ignore
         elif manager.is_mapped:
             # the class is mapped, which means we're outside of the declarative
             # scan setup, just run the function.
-            return self.fget(cls)
+            return self.fget(cls)  # type: ignore
 
         # here, we are inside of the declarative scan.  use the registry
         # that is tracking the values of these attributes.
         declarative_scan = manager.declarative_scan()
+
+        # assert that we are in fact in the declarative scan
         assert declarative_scan is not None
+
         reg = declarative_scan.declared_attr_reg
 
         if self in reg:
-            return reg[self]
+            return reg[self]  # type: ignore
         else:
             reg[self] = obj = self.fget(cls)
-            return obj
+            return obj  # type: ignore
 
     @hybridmethod
-    def _stateful(cls, **kw):
+    def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]:
         return _stateful_declared_attr(**kw)
 
     @hybridproperty
-    def cascading(cls):
+    def cascading(cls) -> _stateful_declared_attr[_T]:
         """Mark a :class:`.declared_attr` as cascading.
 
         This is a special-use modifier which indicates that a column
@@ -372,20 +417,23 @@ class declared_attr(interfaces._MappedAttribute[_T]):
         return cls._stateful(cascading=True)
 
 
-class _stateful_declared_attr(declared_attr):
-    def __init__(self, **kw):
+class _stateful_declared_attr(declared_attr[_T]):
+    kw: Dict[str, Any]
+
+    def __init__(self, **kw: Any):
         self.kw = kw
 
-    def _stateful(self, **kw):
+    @hybridmethod
+    def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]:
         new_kw = self.kw.copy()
         new_kw.update(kw)
         return _stateful_declared_attr(**new_kw)
 
-    def __call__(self, fn):
+    def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]:
         return declared_attr(fn, **self.kw)
 
 
-def declarative_mixin(cls):
+def declarative_mixin(cls: Type[_T]) -> Type[_T]:
     """Mark a class as providing the feature of "declarative mixin".
 
     E.g.::
@@ -427,9 +475,9 @@ def declarative_mixin(cls):
     return cls
 
 
-def _setup_declarative_base(cls):
+def _setup_declarative_base(cls: Type[Any]) -> None:
     if "metadata" in cls.__dict__:
-        metadata = cls.metadata
+        metadata = cls.metadata  # type: ignore
     else:
         metadata = None
 
@@ -457,15 +505,15 @@ def _setup_declarative_base(cls):
         reg = registry(
             metadata=metadata, type_annotation_map=type_annotation_map
         )
-        cls.registry = reg
+        cls.registry = reg  # type: ignore
 
-    cls._sa_registry = reg
+    cls._sa_registry = reg  # type: ignore
 
     if "metadata" not in cls.__dict__:
-        cls.metadata = cls.registry.metadata
+        cls.metadata = cls.registry.metadata  # type: ignore
 
 
-class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
+class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]):
     """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass
     to intercept new attributes.
 
@@ -477,10 +525,10 @@ class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
 
     """
 
-    registry: ClassVar["registry"]
-    _sa_registry: ClassVar["registry"]
+    registry: ClassVar[_RegistryType]
+    _sa_registry: ClassVar[_RegistryType]
     metadata: ClassVar[MetaData]
-    __mapper__: ClassVar[Mapper]
+    __mapper__: ClassVar[Mapper[Any]]
     __table__: Optional[FromClause]
 
     if typing.TYPE_CHECKING:
@@ -496,7 +544,7 @@ class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
 
 
 class DeclarativeBase(
-    inspection.Inspectable["InstanceState"],
+    inspection.Inspectable[InstanceState[Any]],
     metaclass=DeclarativeAttributeIntercept,
 ):
     """Base class used for declarative class definitions.
@@ -557,10 +605,10 @@ class DeclarativeBase(
 
     """
 
-    registry: ClassVar["registry"]
-    _sa_registry: ClassVar["registry"]
+    registry: ClassVar[_RegistryType]
+    _sa_registry: ClassVar[_RegistryType]
     metadata: ClassVar[MetaData]
-    __mapper__: ClassVar[Mapper]
+    __mapper__: ClassVar[Mapper[Any]]
     __table__: Optional[FromClause]
 
     if typing.TYPE_CHECKING:
@@ -572,10 +620,12 @@ class DeclarativeBase(
         if DeclarativeBase in cls.__bases__:
             _setup_declarative_base(cls)
         else:
-            cls._sa_registry.map_declaratively(cls)
+            _as_declarative(cls._sa_registry, cls, cls.__dict__)
 
 
-def add_mapped_attribute(target, key, attr):
+def add_mapped_attribute(
+    target: Type[_O], key: str, attr: MapperProperty[Any]
+) -> None:
     """Add a new mapped attribute to an ORM mapped class.
 
     E.g.::
@@ -593,14 +643,15 @@ def add_mapped_attribute(target, key, attr):
 
 
 def declarative_base(
+    *,
     metadata: Optional[MetaData] = None,
-    mapper=None,
-    cls=object,
-    name="Base",
+    mapper: Optional[Callable[..., Mapper[Any]]] = None,
+    cls: Type[Any] = object,
+    name: str = "Base",
     class_registry: Optional[clsregistry._ClsRegistryType] = None,
     type_annotation_map: Optional[_TypeAnnotationMapType] = None,
     constructor: Callable[..., None] = _declarative_constructor,
-    metaclass=DeclarativeMeta,
+    metaclass: Type[Any] = DeclarativeMeta,
 ) -> Any:
     r"""Construct a base class for declarative class definitions.
 
@@ -736,8 +787,19 @@ class registry:
 
     """
 
+    _class_registry: clsregistry._ClsRegistryType
+    _managers: weakref.WeakKeyDictionary[ClassManager[Any], Literal[True]]
+    _non_primary_mappers: weakref.WeakKeyDictionary[Mapper[Any], Literal[True]]
+    metadata: MetaData
+    constructor: CallableReference[Callable[..., None]]
+    type_annotation_map: _MutableTypeAnnotationMapType
+    _dependents: Set[_RegistryType]
+    _dependencies: Set[_RegistryType]
+    _new_mappers: bool
+
     def __init__(
         self,
+        *,
         metadata: Optional[MetaData] = None,
         class_registry: Optional[clsregistry._ClsRegistryType] = None,
         type_annotation_map: Optional[_TypeAnnotationMapType] = None,
@@ -799,9 +861,7 @@ class registry:
 
     def update_type_annotation_map(
         self,
-        type_annotation_map: Mapping[
-            Type, Union[Type[TypeEngine], TypeEngine]
-        ],
+        type_annotation_map: _TypeAnnotationMapType,
     ) -> None:
         """update the :paramref:`_orm.registry.type_annotation_map` with new
         values."""
@@ -817,20 +877,20 @@ class registry:
         )
 
     @property
-    def mappers(self):
+    def mappers(self) -> FrozenSet[Mapper[Any]]:
         """read only collection of all :class:`_orm.Mapper` objects."""
 
         return frozenset(manager.mapper for manager in self._managers).union(
             self._non_primary_mappers
         )
 
-    def _set_depends_on(self, registry):
+    def _set_depends_on(self, registry: RegistryType) -> None:
         if registry is self:
             return
         registry._dependents.add(self)
         self._dependencies.add(registry)
 
-    def _flag_new_mapper(self, mapper):
+    def _flag_new_mapper(self, mapper: Mapper[Any]) -> None:
         mapper._ready_for_configure = True
         if self._new_mappers:
             return
@@ -839,7 +899,9 @@ class registry:
             reg._new_mappers = True
 
     @classmethod
-    def _recurse_with_dependents(cls, registries):
+    def _recurse_with_dependents(
+        cls, registries: Set[RegistryType]
+    ) -> Iterator[RegistryType]:
         todo = registries
         done = set()
         while todo:
@@ -856,7 +918,9 @@ class registry:
             todo.update(reg._dependents.difference(done))
 
     @classmethod
-    def _recurse_with_dependencies(cls, registries):
+    def _recurse_with_dependencies(
+        cls, registries: Set[RegistryType]
+    ) -> Iterator[RegistryType]:
         todo = registries
         done = set()
         while todo:
@@ -873,7 +937,7 @@ class registry:
             # them before
             todo.update(reg._dependencies.difference(done))
 
-    def _mappers_to_configure(self):
+    def _mappers_to_configure(self) -> Iterator[Mapper[Any]]:
         return itertools.chain(
             (
                 manager.mapper
@@ -889,13 +953,13 @@ class registry:
             ),
         )
 
-    def _add_non_primary_mapper(self, np_mapper):
+    def _add_non_primary_mapper(self, np_mapper: Mapper[Any]) -> None:
         self._non_primary_mappers[np_mapper] = True
 
-    def _dispose_cls(self, cls):
+    def _dispose_cls(self, cls: Type[_O]) -> None:
         clsregistry.remove_class(cls.__name__, cls, self._class_registry)
 
-    def _add_manager(self, manager):
+    def _add_manager(self, manager: ClassManager[Any]) -> None:
         self._managers[manager] = True
         if manager.is_mapped:
             raise exc.ArgumentError(
@@ -905,7 +969,7 @@ class registry:
         assert manager.registry is None
         manager.registry = self
 
-    def configure(self, cascade=False):
+    def configure(self, cascade: bool = False) -> None:
         """Configure all as-yet unconfigured mappers in this
         :class:`_orm.registry`.
 
@@ -946,7 +1010,7 @@ class registry:
         """
         mapperlib._configure_registries({self}, cascade=cascade)
 
-    def dispose(self, cascade=False):
+    def dispose(self, cascade: bool = False) -> None:
         """Dispose of all mappers in this :class:`_orm.registry`.
 
         After invocation, all the classes that were mapped within this registry
@@ -972,7 +1036,7 @@ class registry:
 
         mapperlib._dispose_registries({self}, cascade=cascade)
 
-    def _dispose_manager_and_mapper(self, manager):
+    def _dispose_manager_and_mapper(self, manager: ClassManager[Any]) -> None:
         if "mapper" in manager.__dict__:
             mapper = manager.mapper
 
@@ -984,11 +1048,11 @@ class registry:
 
     def generate_base(
         self,
-        mapper=None,
-        cls=object,
-        name="Base",
-        metaclass=DeclarativeMeta,
-    ):
+        mapper: Optional[Callable[..., Mapper[Any]]] = None,
+        cls: Type[Any] = object,
+        name: str = "Base",
+        metaclass: Type[Any] = DeclarativeMeta,
+    ) -> Any:
         """Generate a declarative base class.
 
         Classes that inherit from the returned class object will be
@@ -1070,7 +1134,7 @@ class registry:
 
         if hasattr(cls, "__class_getitem__"):
 
-            def __class_getitem__(cls, key):
+            def __class_getitem__(cls: Type[_T], key: str) -> Type[_T]:
                 # allow generic classes in py3.9+
                 return cls
 
@@ -1078,7 +1142,7 @@ class registry:
 
         return metaclass(name, bases, class_dict)
 
-    def mapped(self, cls):
+    def mapped(self, cls: Type[_O]) -> Type[_O]:
         """Class decorator that will apply the Declarative mapping process
         to a given class.
 
@@ -1114,7 +1178,7 @@ class registry:
         _as_declarative(self, cls, cls.__dict__)
         return cls
 
-    def as_declarative_base(self, **kw):
+    def as_declarative_base(self, **kw: Any) -> Callable[[Type[_T]], Type[_T]]:
         """
         Class decorator which will invoke
         :meth:`_orm.registry.generate_base`
@@ -1142,14 +1206,14 @@ class registry:
 
         """
 
-        def decorate(cls):
+        def decorate(cls: Type[_T]) -> Type[_T]:
             kw["cls"] = cls
             kw["name"] = cls.__name__
-            return self.generate_base(**kw)
+            return self.generate_base(**kw)  # type: ignore
 
         return decorate
 
-    def map_declaratively(self, cls):
+    def map_declaratively(self, cls: Type[_O]) -> Mapper[_O]:
         """Map a class declaratively.
 
         In this form of mapping, the class is scanned for mapping information,
@@ -1194,9 +1258,15 @@ class registry:
             :meth:`_orm.registry.map_imperatively`
 
         """
-        return _as_declarative(self, cls, cls.__dict__)
+        _as_declarative(self, cls, cls.__dict__)
+        return cls.__mapper__  # type: ignore
 
-    def map_imperatively(self, class_, local_table=None, **kw):
+    def map_imperatively(
+        self,
+        class_: Type[_O],
+        local_table: Optional[FromClause] = None,
+        **kw: Any,
+    ) -> Mapper[_O]:
         r"""Map a class imperatively.
 
         In this form of mapping, the class is not scanned for any mapping
@@ -1251,7 +1321,7 @@ class registry:
 RegistryType = registry
 
 
-def as_declarative(**kw):
+def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]:
     """
     Class decorator which will adapt a given class into a
     :func:`_orm.declarative_base`.
@@ -1292,14 +1362,9 @@ def as_declarative(**kw):
 @inspection._inspects(
     DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept
 )
-def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]:
-    mp: Mapper[Any] = _inspect_mapped_class(cls)
+def _inspect_decl_meta(cls: Type[Any]) -> Optional[Mapper[Any]]:
+    mp: Optional[Mapper[Any]] = _inspect_mapped_class(cls)
     if mp is None:
         if _DeferredMapperConfig.has_cls(cls):
             _DeferredMapperConfig.raise_unmapped_for_cls(cls)
-            raise orm_exc.UnmappedClassError(
-                cls,
-                msg="Class %s has a deferred mapping on it.  It is not yet "
-                "usable as a mapped class." % orm_exc._safe_cls_name(cls),
-            )
     return mp
index b1f81cb6b8e3781d60ed4e4f01e28c1ea68429d0..c3faac36cf3070b8692d05c22682817c74665083 100644 (file)
@@ -4,16 +4,26 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+
 """Internal implementation for declarative."""
 
 from __future__ import annotations
 
 import collections
 from typing import Any
+from typing import Callable
+from typing import cast
 from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import NoReturn
+from typing import Optional
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 import weakref
 
 from . import attributes
@@ -21,6 +31,8 @@ from . import clsregistry
 from . import exc as orm_exc
 from . import instrumentation
 from . import mapperlib
+from ._typing import _O
+from ._typing import attr_is_internal_proxy
 from .attributes import InstrumentedAttribute
 from .attributes import QueryableAttribute
 from .base import _is_mapped_class
@@ -32,6 +44,7 @@ from .interfaces import _MappedAttribute
 from .interfaces import _MapsColumns
 from .interfaces import MapperProperty
 from .mapper import Mapper as mapper
+from .mapper import Mapper
 from .properties import ColumnProperty
 from .properties import MappedColumn
 from .util import _is_mapped_annotation
@@ -43,12 +56,41 @@ from ..sql import expression
 from ..sql.schema import Column
 from ..sql.schema import Table
 from ..util import topological
+from ..util.typing import Protocol
 
 if TYPE_CHECKING:
+    from ._typing import _ClassDict
     from ._typing import _RegistryType
+    from .decl_api import declared_attr
+    from .instrumentation import ClassManager
+    from ..sql.schema import MetaData
+    from ..sql.selectable import FromClause
+
+_T = TypeVar("_T", bound=Any)
+
+_MapperKwArgs = Mapping[str, Any]
+
+_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]]
 
 
-def _declared_mapping_info(cls):
+class _DeclMappedClassProtocol(Protocol[_O]):
+    metadata: MetaData
+    __mapper__: Mapper[_O]
+    __table__: Table
+    __tablename__: str
+    __mapper_args__: Mapping[str, Any]
+    __table_args__: Optional[_TableArgsType]
+
+    def __declare_first__(self) -> None:
+        pass
+
+    def __declare_last__(self) -> None:
+        pass
+
+
+def _declared_mapping_info(
+    cls: Type[Any],
+) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]:
     # deferred mapping
     if _DeferredMapperConfig.has_cls(cls):
         return _DeferredMapperConfig.config_for_cls(cls)
@@ -59,13 +101,15 @@ def _declared_mapping_info(cls):
         return None
 
 
-def _resolve_for_abstract_or_classical(cls):
+def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]:
     if cls is object:
         return None
 
+    sup: Optional[Type[Any]]
+
     if cls.__dict__.get("__abstract__", False):
-        for sup in cls.__bases__:
-            sup = _resolve_for_abstract_or_classical(sup)
+        for base_ in cls.__bases__:
+            sup = _resolve_for_abstract_or_classical(base_)
             if sup is not None:
                 return sup
         else:
@@ -79,7 +123,9 @@ def _resolve_for_abstract_or_classical(cls):
             return cls
 
 
-def _get_immediate_cls_attr(cls, attrname, strict=False):
+def _get_immediate_cls_attr(
+    cls: Type[Any], attrname: str, strict: bool = False
+) -> Optional[Any]:
     """return an attribute of the class that is either present directly
     on the class, e.g. not on a superclass, or is from a superclass but
     this superclass is a non-mapped mixin, that is, not a descendant of
@@ -102,7 +148,7 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
         return getattr(cls, attrname)
 
     for base in cls.__mro__[1:]:
-        _is_classicial_inherits = _dive_for_cls_manager(base)
+        _is_classicial_inherits = _dive_for_cls_manager(base) is not None
 
         if attrname in base.__dict__ and (
             base is cls
@@ -116,33 +162,37 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
         return None
 
 
def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]:
    """Search the MRO of *cls* for the first class that has a
    :class:`.ClassManager`, returning ``None`` if none is found.

    Class manager registration is pluggable, so every class in the
    hierarchy must be checked individually rather than relying on a
    single ``cls._sa_class_manager`` attribute lookup.
    """
    for candidate in cls.__mro__:
        found: Optional[ClassManager[_O]] = attributes.opt_manager_of_class(
            candidate
        )
        if found:
            return found
    return None
 
 
def _as_declarative(
    registry: _RegistryType, cls: Type[Any], dict_: _ClassDict
) -> Optional[_MapperConfig]:
    """Run declarative mapping setup for *cls*.

    Declarative scans the class itself for its configuration; there is
    no separately-supplied table or mapper-argument dictionary, so those
    are passed as ``None`` and an empty mapping respectively.
    """
    no_table = None
    no_mapper_kw: Dict[str, Any] = {}
    return _MapperConfig.setup_mapping(
        registry, cls, dict_, no_table, no_mapper_kw
    )
 
 
def _mapper(
    registry: _RegistryType,
    cls: Type[_O],
    table: Optional[FromClause],
    mapper_kw: _MapperKwArgs,
) -> Mapper[_O]:
    """Imperatively map *cls* against *table* and return the new Mapper.

    Constructing the :class:`._ImperativeMapperConfig` installs
    ``__mapper__`` on the class as a side effect; it is then read back
    off the class for the return value.
    """
    _ImperativeMapperConfig(registry, cls, table, mapper_kw)
    mapped = cast("_DeclMappedClassProtocol[_O]", cls)
    return mapped.__mapper__
 
 
 @util.preload_module("sqlalchemy.orm.decl_api")
@@ -152,7 +202,9 @@ def _is_declarative_props(obj: Any) -> bool:
     return isinstance(obj, (declared_attr, util.classproperty))
 
 
-def _check_declared_props_nocascade(obj, name, cls):
+def _check_declared_props_nocascade(
+    obj: Any, name: str, cls: Type[_O]
+) -> bool:
     if _is_declarative_props(obj):
         if getattr(obj, "_cascading", False):
             util.warn(
@@ -174,8 +226,20 @@ class _MapperConfig:
         "__weakref__",
     )
 
+    cls: Type[Any]
+    classname: str
+    properties: util.OrderedDict[str, MapperProperty[Any]]
+    declared_attr_reg: Dict[declared_attr[Any], Any]
+
     @classmethod
-    def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
+    def setup_mapping(
+        cls,
+        registry: _RegistryType,
+        cls_: Type[_O],
+        dict_: _ClassDict,
+        table: Optional[FromClause],
+        mapper_kw: _MapperKwArgs,
+    ) -> Optional[_MapperConfig]:
         manager = attributes.opt_manager_of_class(cls)
         if manager and manager.class_ is cls_:
             raise exc.InvalidRequestError(
@@ -183,24 +247,26 @@ class _MapperConfig:
             )
 
         if cls_.__dict__.get("__abstract__", False):
-            return
+            return None
 
         defer_map = _get_immediate_cls_attr(
             cls_, "_sa_decl_prepare_nocascade", strict=True
         ) or hasattr(cls_, "_sa_decl_prepare")
 
         if defer_map:
-            cfg_cls = _DeferredMapperConfig
+            return _DeferredMapperConfig(
+                registry, cls_, dict_, table, mapper_kw
+            )
         else:
-            cfg_cls = _ClassScanMapperConfig
-
-        return cfg_cls(registry, cls_, dict_, table, mapper_kw)
+            return _ClassScanMapperConfig(
+                registry, cls_, dict_, table, mapper_kw
+            )
 
     def __init__(
         self,
         registry: _RegistryType,
         cls_: Type[Any],
-        mapper_kw: Dict[str, Any],
+        mapper_kw: _MapperKwArgs,
     ):
         self.cls = util.assert_arg_type(cls_, type, "cls_")
         self.classname = cls_.__name__
@@ -224,13 +290,16 @@ class _MapperConfig:
                     "Mapper." % self.cls
                 )
 
-    def set_cls_attribute(self, attrname, value):
+    def set_cls_attribute(self, attrname: str, value: _T) -> _T:
 
         manager = instrumentation.manager_of_class(self.cls)
         manager.install_member(attrname, value)
         return value
 
    def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]:
        # abstract hook: concrete subclasses construct the Mapper here.
        # the Ellipsis default is a typing placeholder only; concrete
        # overrides use util.EMPTY_DICT as their real default.
        raise NotImplementedError()

    def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
        # run the mapping step immediately during construction;
        # _DeferredMapperConfig overrides this as a no-op so that
        # mapping can be deferred.
        self.map(mapper_kw)
 
 
@@ -239,10 +308,10 @@ class _ImperativeMapperConfig(_MapperConfig):
 
     def __init__(
         self,
-        registry,
-        cls_,
-        table,
-        mapper_kw,
+        registry: _RegistryType,
+        cls_: Type[_O],
+        table: Optional[FromClause],
+        mapper_kw: _MapperKwArgs,
     ):
         super(_ImperativeMapperConfig, self).__init__(
             registry, cls_, mapper_kw
@@ -260,7 +329,7 @@ class _ImperativeMapperConfig(_MapperConfig):
 
             self._early_mapping(mapper_kw)
 
-    def map(self, mapper_kw=util.EMPTY_DICT):
+    def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
         mapper_cls = mapper
 
         return self.set_cls_attribute(
@@ -268,7 +337,7 @@ class _ImperativeMapperConfig(_MapperConfig):
             mapper_cls(self.cls, self.local_table, **mapper_kw),
         )
 
-    def _setup_inheritance(self, mapper_kw):
+    def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
         cls = self.cls
 
         inherits = mapper_kw.get("inherits", None)
@@ -277,8 +346,8 @@ class _ImperativeMapperConfig(_MapperConfig):
             # since we search for classical mappings now, search for
             # multiple mapped bases as well and raise an error.
             inherits_search = []
-            for c in cls.__bases__:
-                c = _resolve_for_abstract_or_classical(c)
+            for base_ in cls.__bases__:
+                c = _resolve_for_abstract_or_classical(base_)
                 if c is None:
                     continue
                 if _declared_mapping_info(
@@ -318,13 +387,30 @@ class _ClassScanMapperConfig(_MapperConfig):
         "inherits",
     )
 
+    registry: _RegistryType
+    clsdict_view: _ClassDict
+    collected_annotations: Dict[str, Tuple[Any, bool]]
+    collected_attributes: Dict[str, Any]
+    local_table: Optional[FromClause]
+    persist_selectable: Optional[FromClause]
+    declared_columns: util.OrderedSet[Column[Any]]
+    column_copies: Dict[
+        Union[MappedColumn[Any], Column[Any]],
+        Union[MappedColumn[Any], Column[Any]],
+    ]
+    tablename: Optional[str]
+    mapper_args: Mapping[str, Any]
+    table_args: Optional[_TableArgsType]
+    mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
+    inherits: Optional[Type[Any]]
+
     def __init__(
         self,
-        registry,
-        cls_,
-        dict_,
-        table,
-        mapper_kw,
+        registry: _RegistryType,
+        cls_: Type[_O],
+        dict_: _ClassDict,
+        table: Optional[FromClause],
+        mapper_kw: _MapperKwArgs,
     ):
 
         # grab class dict before the instrumentation manager has been added.
@@ -337,7 +423,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         self.persist_selectable = None
 
         self.collected_attributes = {}
-        self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
+        self.collected_annotations = {}
         self.declared_columns = util.OrderedSet()
         self.column_copies = {}
 
@@ -360,31 +446,37 @@ class _ClassScanMapperConfig(_MapperConfig):
 
             self._early_mapping(mapper_kw)
 
    def _setup_declared_events(self) -> None:
        # wire the class-level ``__declare_last__`` / ``__declare_first__``
        # hooks, when present directly on this class or a non-mapped mixin
        # (per _get_immediate_cls_attr), to the mapper configuration events
        if _get_immediate_cls_attr(self.cls, "__declare_last__"):

            @event.listens_for(mapper, "after_configured")
            def after_configured() -> None:
                # fires on the "after_configured" mapper-level event
                cast(
                    "_DeclMappedClassProtocol[Any]", self.cls
                ).__declare_last__()

        if _get_immediate_cls_attr(self.cls, "__declare_first__"):

            @event.listens_for(mapper, "before_configured")
            def before_configured() -> None:
                # fires on the "before_configured" mapper-level event
                cast(
                    "_DeclMappedClassProtocol[Any]", self.cls
                ).__declare_first__()
+
+    def _cls_attr_override_checker(
+        self, cls: Type[_O]
+    ) -> Callable[[str, Any], bool]:
         """Produce a function that checks if a class has overridden an
         attribute, taking SQLAlchemy-enabled dataclass fields into account.
 
         """
         sa_dataclass_metadata_key = _get_immediate_cls_attr(
-            cls, "__sa_dataclass_metadata_key__", None
+            cls, "__sa_dataclass_metadata_key__"
         )
 
         if sa_dataclass_metadata_key is None:
 
-            def attribute_is_overridden(key, obj):
+            def attribute_is_overridden(key: str, obj: Any) -> bool:
                 return getattr(cls, key) is not obj
 
         else:
@@ -402,7 +494,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
             absent = object()
 
-            def attribute_is_overridden(key, obj):
+            def attribute_is_overridden(key: str, obj: Any) -> bool:
                 if _is_declarative_props(obj):
                     obj = obj.fget
 
@@ -457,13 +549,15 @@ class _ClassScanMapperConfig(_MapperConfig):
         ]
     )
 
-    def _cls_attr_resolver(self, cls):
+    def _cls_attr_resolver(
+        self, cls: Type[Any]
+    ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]:
         """produce a function to iterate the "attributes" of a class,
         adjusting for SQLAlchemy fields embedded in dataclass fields.
 
         """
-        sa_dataclass_metadata_key = _get_immediate_cls_attr(
-            cls, "__sa_dataclass_metadata_key__", None
+        sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+            cls, "__sa_dataclass_metadata_key__"
         )
 
         cls_annotations = util.get_annotations(cls)
@@ -477,7 +571,9 @@ class _ClassScanMapperConfig(_MapperConfig):
         )
         if sa_dataclass_metadata_key is None:
 
-            def local_attributes_for_class():
+            def local_attributes_for_class() -> Iterable[
+                Tuple[str, Any, Any, bool]
+            ]:
                 return (
                     (
                         name,
@@ -493,12 +589,16 @@ class _ClassScanMapperConfig(_MapperConfig):
                 field.name: field for field in util.local_dataclass_fields(cls)
             }
 
-            def local_attributes_for_class():
+            fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key
+
+            def local_attributes_for_class() -> Iterable[
+                Tuple[str, Any, Any, bool]
+            ]:
                 for name in names:
                     field = dataclass_fields.get(name, None)
                     if field and sa_dataclass_metadata_key in field.metadata:
                         yield field.name, _as_dc_declaredattr(
-                            field.metadata, sa_dataclass_metadata_key
+                            field.metadata, fixed_sa_dataclass_metadata_key
                         ), cls_annotations.get(field.name), True
                     else:
                         yield name, cls_vars.get(name), cls_annotations.get(
@@ -507,14 +607,17 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         return local_attributes_for_class
 
-    def _scan_attributes(self):
+    def _scan_attributes(self) -> None:
         cls = self.cls
 
+        cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
         clsdict_view = self.clsdict_view
         collected_attributes = self.collected_attributes
         column_copies = self.column_copies
         mapper_args_fn = None
         table_args = inherited_table_args = None
+
         tablename = None
         fixed_table = "__table__" in clsdict_view
 
@@ -555,21 +658,23 @@ class _ClassScanMapperConfig(_MapperConfig):
                         # make a copy of it so a class-level dictionary
                         # is not overwritten when we update column-based
                         # arguments.
-                        def mapper_args_fn():
-                            return dict(cls.__mapper_args__)
+                        def _mapper_args_fn() -> Dict[str, Any]:
+                            return dict(cls_as_Decl.__mapper_args__)
+
+                        mapper_args_fn = _mapper_args_fn
 
                 elif name == "__tablename__":
                     check_decl = _check_declared_props_nocascade(
                         obj, name, cls
                     )
                     if not tablename and (not class_mapped or check_decl):
-                        tablename = cls.__tablename__
+                        tablename = cls_as_Decl.__tablename__
                 elif name == "__table_args__":
                     check_decl = _check_declared_props_nocascade(
                         obj, name, cls
                     )
                     if not table_args and (not class_mapped or check_decl):
-                        table_args = cls.__table_args__
+                        table_args = cls_as_Decl.__table_args__
                         if not isinstance(
                             table_args, (tuple, dict, type(None))
                         ):
@@ -657,9 +762,10 @@ class _ClassScanMapperConfig(_MapperConfig):
                             # or similar.  note there is no known case that
                             # produces nested proxies, so we are only
                             # looking one level deep right now.
+
                             if (
                                 isinstance(ret, InspectionAttr)
-                                and ret._is_internal_proxy
+                                and attr_is_internal_proxy(ret)
                                 and not isinstance(
                                     ret.original_property, MapperProperty
                                 )
@@ -669,6 +775,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                             collected_attributes[name] = column_copies[
                                 obj
                             ] = ret
+
                         if (
                             isinstance(ret, (Column, MapperProperty))
                             and ret.doc is None
@@ -737,7 +844,9 @@ class _ClassScanMapperConfig(_MapperConfig):
         self.tablename = tablename
         self.mapper_args_fn = mapper_args_fn
 
-    def _warn_for_decl_attributes(self, cls, key, c):
+    def _warn_for_decl_attributes(
+        self, cls: Type[Any], key: str, c: Any
+    ) -> None:
         if isinstance(c, expression.ColumnClause):
             util.warn(
                 f"Attribute '{key}' on class {cls} appears to "
@@ -746,8 +855,12 @@ class _ClassScanMapperConfig(_MapperConfig):
             )
 
     def _produce_column_copies(
-        self, attributes_for_class, attribute_is_overridden
-    ):
+        self,
+        attributes_for_class: Callable[
+            [], Iterable[Tuple[str, Any, Any, bool]]
+        ],
+        attribute_is_overridden: Callable[[str, Any], bool],
+    ) -> None:
         cls = self.cls
         dict_ = self.clsdict_view
         collected_attributes = self.collected_attributes
@@ -763,7 +876,8 @@ class _ClassScanMapperConfig(_MapperConfig):
                     continue
                 elif name not in dict_ and not (
                     "__table__" in dict_
-                    and (obj.name or name) in dict_["__table__"].c
+                    and (getattr(obj, "name", None) or name)
+                    in dict_["__table__"].c
                 ):
                     if obj.foreign_keys:
                         for fk in obj.foreign_keys:
@@ -786,7 +900,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
                     setattr(cls, name, copy_)
 
-    def _extract_mappable_attributes(self):
+    def _extract_mappable_attributes(self) -> None:
         cls = self.cls
         collected_attributes = self.collected_attributes
 
@@ -858,17 +972,19 @@ class _ClassScanMapperConfig(_MapperConfig):
                     "declarative base class."
                 )
             elif isinstance(value, Column):
-                _undefer_column_name(k, self.column_copies.get(value, value))
+                _undefer_column_name(
+                    k, self.column_copies.get(value, value)  # type: ignore
+                )
             elif isinstance(value, _IntrospectsAnnotations):
                 annotation, is_dataclass = self.collected_annotations.get(
-                    k, (None, None)
+                    k, (None, False)
                 )
                 value.declarative_scan(
                     self.registry, cls, k, annotation, is_dataclass
                 )
             our_stuff[k] = value
 
-    def _extract_declared_columns(self):
+    def _extract_declared_columns(self) -> None:
         our_stuff = self.properties
 
         # extract columns from the class dict
@@ -914,8 +1030,10 @@ class _ClassScanMapperConfig(_MapperConfig):
                     % (self.classname, name, (", ".join(sorted(keys))))
                 )
 
-    def _setup_table(self, table=None):
+    def _setup_table(self, table: Optional[FromClause] = None) -> None:
         cls = self.cls
+        cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
         tablename = self.tablename
         table_args = self.table_args
         clsdict_view = self.clsdict_view
@@ -925,13 +1043,18 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         if "__table__" not in clsdict_view and table is None:
             if hasattr(cls, "__table_cls__"):
-                table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+                table_cls = cast(
+                    Type[Table],
+                    util.unbound_method_to_callable(cls.__table_cls__),  # type: ignore  # noqa: E501
+                )
             else:
                 table_cls = Table
 
             if tablename is not None:
 
-                args, table_kw = (), {}
+                args: Tuple[Any, ...] = ()
+                table_kw: Dict[str, Any] = {}
+
                 if table_args:
                     if isinstance(table_args, dict):
                         table_kw = table_args
@@ -960,7 +1083,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                 )
         else:
             if table is None:
-                table = cls.__table__
+                table = cls_as_Decl.__table__
             if declared_columns:
                 for c in declared_columns:
                     if not table.c.contains_column(c):
@@ -968,15 +1091,16 @@ class _ClassScanMapperConfig(_MapperConfig):
                             "Can't add additional column %r when "
                             "specifying __table__" % c.key
                         )
+
         self.local_table = table
 
-    def _metadata_for_cls(self, manager):
+    def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData:
         if hasattr(self.cls, "metadata"):
-            return self.cls.metadata
+            return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata
         else:
             return manager.registry.metadata
 
-    def _setup_inheritance(self, mapper_kw):
+    def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
         table = self.local_table
         cls = self.cls
         table_args = self.table_args
@@ -988,8 +1112,8 @@ class _ClassScanMapperConfig(_MapperConfig):
             # since we search for classical mappings now, search for
             # multiple mapped bases as well and raise an error.
             inherits_search = []
-            for c in cls.__bases__:
-                c = _resolve_for_abstract_or_classical(c)
+            for base_ in cls.__bases__:
+                c = _resolve_for_abstract_or_classical(base_)
                 if c is None:
                     continue
                 if _declared_mapping_info(
@@ -1024,9 +1148,12 @@ class _ClassScanMapperConfig(_MapperConfig):
                 "table-mapped class." % cls
             )
         elif self.inherits:
-            inherited_mapper = _declared_mapping_info(self.inherits)
-            inherited_table = inherited_mapper.local_table
-            inherited_persist_selectable = inherited_mapper.persist_selectable
+            inherited_mapper_or_config = _declared_mapping_info(self.inherits)
+            assert inherited_mapper_or_config is not None
+            inherited_table = inherited_mapper_or_config.local_table
+            inherited_persist_selectable = (
+                inherited_mapper_or_config.persist_selectable
+            )
 
             if table is None:
                 # single table inheritance.
@@ -1036,29 +1163,44 @@ class _ClassScanMapperConfig(_MapperConfig):
                         "Can't place __table_args__ on an inherited class "
                         "with no table."
                     )
+
                 # add any columns declared here to the inherited table.
-                for c in declared_columns:
-                    if c.name in inherited_table.c:
-                        if inherited_table.c[c.name] is c:
+                if declared_columns and not isinstance(inherited_table, Table):
+                    raise exc.ArgumentError(
+                        f"Can't declare columns on single-table-inherited "
+                        f"subclass {self.cls}; superclass {self.inherits} "
+                        "is not mapped to a Table"
+                    )
+
+                for col in declared_columns:
+                    assert inherited_table is not None
+                    if col.name in inherited_table.c:
+                        if inherited_table.c[col.name] is col:
                             continue
                         raise exc.ArgumentError(
                             "Column '%s' on class %s conflicts with "
                             "existing column '%s'"
-                            % (c, cls, inherited_table.c[c.name])
+                            % (col, cls, inherited_table.c[col.name])
                         )
-                    if c.primary_key:
+                    if col.primary_key:
                         raise exc.ArgumentError(
                             "Can't place primary key columns on an inherited "
                             "class with no table."
                         )
-                    inherited_table.append_column(c)
+
+                    if TYPE_CHECKING:
+                        assert isinstance(inherited_table, Table)
+
+                    inherited_table.append_column(col)
                     if (
                         inherited_persist_selectable is not None
                         and inherited_persist_selectable is not inherited_table
                     ):
-                        inherited_persist_selectable._refresh_for_new_column(c)
+                        inherited_persist_selectable._refresh_for_new_column(
+                            col
+                        )
 
-    def _prepare_mapper_arguments(self, mapper_kw):
+    def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None:
         properties = self.properties
 
         if self.mapper_args_fn:
@@ -1100,6 +1242,7 @@ class _ClassScanMapperConfig(_MapperConfig):
             # not mapped on the parent class, to avoid
             # mapping columns specific to sibling/nephew classes
             inherited_mapper = _declared_mapping_info(self.inherits)
+            assert isinstance(inherited_mapper, Mapper)
             inherited_table = inherited_mapper.local_table
 
             if "exclude_properties" not in mapper_args:
@@ -1133,11 +1276,14 @@ class _ClassScanMapperConfig(_MapperConfig):
         result_mapper_args["properties"] = properties
         self.mapper_args = result_mapper_args
 
-    def map(self, mapper_kw=util.EMPTY_DICT):
+    def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
         self._prepare_mapper_arguments(mapper_kw)
         if hasattr(self.cls, "__mapper_cls__"):
-            mapper_cls = util.unbound_method_to_callable(
-                self.cls.__mapper_cls__
+            mapper_cls = cast(
+                "Type[Mapper[Any]]",
+                util.unbound_method_to_callable(
+                    self.cls.__mapper_cls__  # type: ignore
+                ),
             )
         else:
             mapper_cls = mapper
@@ -1149,7 +1295,9 @@ class _ClassScanMapperConfig(_MapperConfig):
 
 
 @util.preload_module("sqlalchemy.orm.decl_api")
-def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
+def _as_dc_declaredattr(
+    field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str
+) -> Any:
     # wrap lambdas inside dataclass fields inside an ad-hoc declared_attr.
     # we can't write it because field.metadata is immutable :( so we have
     # to go through extra trouble to compare these
@@ -1162,46 +1310,55 @@ def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
 
 
 class _DeferredMapperConfig(_ClassScanMapperConfig):
-    _configs = util.OrderedDict()
+    _cls: weakref.ref[Type[Any]]
+
+    _configs: util.OrderedDict[
+        weakref.ref[Type[Any]], _DeferredMapperConfig
+    ] = util.OrderedDict()
 
    def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
        # no-op override: deferred configurations are mapped later via
        # .map(), not at class-construction time.
        pass
 
    # mypy disallows plain property override of variable
    @property  # type: ignore
    def cls(self) -> Type[Any]:  # type: ignore
        # dereference the weakref; yields None once the class has been
        # garbage collected
        return self._cls()  # type: ignore

    @cls.setter
    def cls(self, class_: Type[Any]) -> None:
        # hold the class weakly so an unmapped class can be collected;
        # the weakref callback removes this config from the registry
        self._cls = weakref.ref(class_, self._remove_config_cls)
        self._configs[self._cls] = self
 
    @classmethod
    def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None:
        # weakref finalizer callback: drop the config entry for a
        # garbage-collected class
        cls._configs.pop(ref, None)
 
    @classmethod
    def has_cls(cls, class_: Type[Any]) -> bool:
        """Return True if a deferred config is registered for *class_*."""
        # the isinstance(type) guard protects weakref.ref() from objects
        # that don't support weak references (historically, old-style
        # classes)
        return isinstance(class_, type) and weakref.ref(class_) in cls._configs
 
    @classmethod
    def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn:
        """Raise an informative error for a class whose mapping is still
        deferred."""
        # give the class a chance to raise its own, more specific error
        if hasattr(class_, "_sa_raise_deferred_config"):
            class_._sa_raise_deferred_config()  # type: ignore

        # reached only if the hook above did not raise
        raise orm_exc.UnmappedClassError(
            class_,
            msg=(
                f"Class {orm_exc._safe_cls_name(class_)} has a deferred "
                "mapping on it.  It is not yet usable as a mapped class."
            ),
        )
 
    @classmethod
    def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig:
        # raises KeyError if no deferred config is registered for this
        # class; callers are expected to check has_cls() first
        return cls._configs[weakref.ref(class_)]
 
     @classmethod
-    def classes_for_base(cls, base_cls, sort=True):
+    def classes_for_base(
+        cls, base_cls: Type[Any], sort: bool = True
+    ) -> List[_DeferredMapperConfig]:
         classes_for_base = [
             m
             for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
@@ -1213,7 +1370,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig):
 
         all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
 
-        tuples = []
+        tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = []
         for m_cls in all_m_by_cls:
             tuples.extend(
                 (all_m_by_cls[base_cls], all_m_by_cls[m_cls])
@@ -1222,12 +1379,14 @@ class _DeferredMapperConfig(_ClassScanMapperConfig):
             )
         return list(topological.sort(tuples, classes_for_base))
 
-    def map(self, mapper_kw=util.EMPTY_DICT):
+    def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
         self._configs.pop(self._cls, None)
         return super(_DeferredMapperConfig, self).map(mapper_kw)
 
 
-def _add_attribute(cls, key, value):
+def _add_attribute(
+    cls: Type[Any], key: str, value: MapperProperty[Any]
+) -> None:
     """add an attribute to an existing declarative class.
 
     This runs through the logic to determine MapperProperty,
@@ -1236,39 +1395,44 @@ def _add_attribute(cls, key, value):
     """
 
     if "__mapper__" in cls.__dict__:
+        mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls)
         if isinstance(value, Column):
             _undefer_column_name(key, value)
-            cls.__table__.append_column(value, replace_existing=True)
-            cls.__mapper__.add_property(key, value)
+            # TODO: raise for this is not a Table
+            mapped_cls.__table__.append_column(value, replace_existing=True)
+            mapped_cls.__mapper__.add_property(key, value)
         elif isinstance(value, _MapsColumns):
             mp = value.mapper_property_to_assign
             for col in value.columns_to_assign:
                 _undefer_column_name(key, col)
-                cls.__table__.append_column(col, replace_existing=True)
+                # TODO: raise for this is not a Table
+                mapped_cls.__table__.append_column(col, replace_existing=True)
                 if not mp:
-                    cls.__mapper__.add_property(key, col)
+                    mapped_cls.__mapper__.add_property(key, col)
             if mp:
-                cls.__mapper__.add_property(key, mp)
+                mapped_cls.__mapper__.add_property(key, mp)
         elif isinstance(value, MapperProperty):
-            cls.__mapper__.add_property(key, value)
+            mapped_cls.__mapper__.add_property(key, value)
         elif isinstance(value, QueryableAttribute) and value.key != key:
             # detect a QueryableAttribute that's already mapped being
             # assigned elsewhere in userland, turn into a synonym()
             value = Synonym(value.key)
-            cls.__mapper__.add_property(key, value)
+            mapped_cls.__mapper__.add_property(key, value)
         else:
             type.__setattr__(cls, key, value)
-            cls.__mapper__._expire_memoizations()
+            mapped_cls.__mapper__._expire_memoizations()
     else:
         type.__setattr__(cls, key, value)
 
 
-def _del_attribute(cls, key):
+def _del_attribute(cls: Type[Any], key: str) -> None:
 
     if (
         "__mapper__" in cls.__dict__
         and key in cls.__dict__
-        and not cls.__mapper__._dispose_called
+        and not cast(
+            "_DeclMappedClassProtocol[Any]", cls
+        ).__mapper__._dispose_called
     ):
         value = cls.__dict__[key]
         if isinstance(
@@ -1279,12 +1443,14 @@ def _del_attribute(cls, key):
             )
         else:
             type.__delattr__(cls, key)
-            cls.__mapper__._expire_memoizations()
+            cast(
+                "_DeclMappedClassProtocol[Any]", cls
+            ).__mapper__._expire_memoizations()
     else:
         type.__delattr__(cls, key)
 
 
-def _declarative_constructor(self, **kwargs):
+def _declarative_constructor(self: Any, **kwargs: Any) -> None:
     """A simple constructor that allows initialization from kwargs.
 
     Sets attributes on the constructed instance using the names and
@@ -1306,7 +1472,7 @@ def _declarative_constructor(self, **kwargs):
 _declarative_constructor.__name__ = "__init__"
 
 
-def _undefer_column_name(key, column):
+def _undefer_column_name(key: str, column: Column[Any]) -> None:
     if column.key is None:
         column.key = key
     if column.name is None:
index 5975c30db3f1f2828c9eb7c38d4f30fa89fe49e0..8c89f96aa950b7cde53ef4bdf9a84dc711d55fa3 100644 (file)
@@ -20,15 +20,21 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import List
+from typing import NoReturn
 from typing import Optional
+from typing import Sequence
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
 from . import attributes
 from . import util as orm_util
+from .base import LoaderCallableStatus
 from .base import Mapped
+from .base import PassiveFlag
+from .base import SQLORMOperations
 from .interfaces import _IntrospectsAnnotations
 from .interfaces import _MapsColumns
 from .interfaces import MapperProperty
@@ -41,20 +47,41 @@ from .. import schema
 from .. import sql
 from .. import util
 from ..sql import expression
-from ..sql import operators
+from ..sql.elements import BindParameter
 from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from ._typing import _InstanceDict
+    from ._typing import _RegistryType
+    from .attributes import History
     from .attributes import InstrumentedAttribute
+    from .attributes import QueryableAttribute
+    from .context import ORMCompileState
+    from .mapper import Mapper
+    from .properties import ColumnProperty
     from .properties import MappedColumn
+    from .state import InstanceState
+    from ..engine.base import Connection
+    from ..engine.row import Row
+    from ..sql._typing import _DMLColumnArgument
     from ..sql._typing import _InfoType
+    from ..sql.elements import ClauseList
+    from ..sql.elements import ColumnElement
     from ..sql.schema import Column
+    from ..sql.selectable import Select
+    from ..util.typing import _AnnotationScanType
+    from ..util.typing import CallableReference
+    from ..util.typing import DescriptorReference
+    from ..util.typing import RODescriptorReference
 
 _T = TypeVar("_T", bound=Any)
 _PT = TypeVar("_PT", bound=Any)
 
 
 class _CompositeClassProto(Protocol):
+    def __init__(self, *args: Any):
+        ...
+
     def __composite_values__(self) -> Tuple[Any, ...]:
         ...
 
@@ -63,32 +90,43 @@ class DescriptorProperty(MapperProperty[_T]):
     """:class:`.MapperProperty` which proxies access to a
     user-defined descriptor."""
 
-    doc = None
+    doc: Optional[str] = None
 
     uses_objects = False
     _links_to_entity = False
 
-    def instrument_class(self, mapper):
+    descriptor: DescriptorReference[Any]
+
+    def get_history(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+    ) -> History:
+        raise NotImplementedError()
+
+    def instrument_class(self, mapper: Mapper[Any]) -> None:
         prop = self
 
-        class _ProxyImpl:
+        class _ProxyImpl(attributes.AttributeImpl):
             accepts_scalar_loader = False
             load_on_unexpire = True
             collection = False
 
             @property
-            def uses_objects(self):
+            def uses_objects(self) -> bool:  # type: ignore
                 return prop.uses_objects
 
-            def __init__(self, key):
+            def __init__(self, key: str):
                 self.key = key
 
-            if hasattr(prop, "get_history"):
-
-                def get_history(
-                    self, state, dict_, passive=attributes.PASSIVE_OFF
-                ):
-                    return prop.get_history(state, dict_, passive)
+            def get_history(
+                self,
+                state: InstanceState[Any],
+                dict_: _InstanceDict,
+                passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+            ) -> History:
+                return prop.get_history(state, dict_, passive)
 
         if self.descriptor is None:
             desc = getattr(mapper.class_, self.key, None)
@@ -97,13 +135,13 @@ class DescriptorProperty(MapperProperty[_T]):
 
         if self.descriptor is None:
 
-            def fset(obj, value):
+            def fset(obj: Any, value: Any) -> None:
                 setattr(obj, self.name, value)
 
-            def fdel(obj):
+            def fdel(obj: Any) -> None:
                 delattr(obj, self.name)
 
-            def fget(obj):
+            def fget(obj: Any) -> Any:
                 return getattr(obj, self.name)
 
             self.descriptor = property(fget=fget, fset=fset, fdel=fdel)
@@ -129,8 +167,11 @@ _CompositeAttrType = Union[
 ]
 
 
+_CC = TypeVar("_CC", bound=_CompositeClassProto)
+
+
 class Composite(
-    _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T]
+    _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC]
 ):
     """Defines a "composite" mapped attribute, representing a collection
     of columns as one attribute.
@@ -148,19 +189,25 @@ class Composite(
 
     """
 
-    composite_class: Union[
-        Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]]
+    composite_class: Union[Type[_CC], Callable[..., _CC]]
+    attrs: Tuple[_CompositeAttrType[Any], ...]
+
+    _generated_composite_accessor: CallableReference[
+        Optional[Callable[[_CC], Tuple[Any, ...]]]
     ]
-    attrs: Tuple[_CompositeAttrType, ...]
+
+    comparator_factory: Type[Comparator[_CC]]
 
     def __init__(
         self,
-        class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None,
-        *attrs: _CompositeAttrType,
+        class_: Union[
+            None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+        ] = None,
+        *attrs: _CompositeAttrType[Any],
         active_history: bool = False,
         deferred: bool = False,
         group: Optional[str] = None,
-        comparator_factory: Optional[Type[Comparator]] = None,
+        comparator_factory: Optional[Type[Comparator[_CC]]] = None,
         info: Optional[_InfoType] = None,
     ):
         super().__init__()
@@ -170,7 +217,7 @@ class Composite(
             # will initialize within declarative_scan
             self.composite_class = None  # type: ignore
         else:
-            self.composite_class = class_
+            self.composite_class = class_  # type: ignore
             self.attrs = attrs
 
         self.active_history = active_history
@@ -183,18 +230,16 @@ class Composite(
         )
         self._generated_composite_accessor = None
         if info is not None:
-            self.info = info
+            self.info.update(info)
 
         util.set_creation_order(self)
         self._create_descriptor()
 
-    def instrument_class(self, mapper):
+    def instrument_class(self, mapper: Mapper[Any]) -> None:
         super().instrument_class(mapper)
         self._setup_event_handlers()
 
-    def _composite_values_from_instance(
-        self, value: _CompositeClassProto
-    ) -> Tuple[Any, ...]:
+    def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]:
         if self._generated_composite_accessor:
             return self._generated_composite_accessor(value)
         else:
@@ -209,7 +254,7 @@ class Composite(
             else:
                 return accessor()
 
-    def do_init(self):
+    def do_init(self) -> None:
         """Initialization which occurs after the :class:`.Composite`
         has been associated with its parent mapper.
 
@@ -218,13 +263,13 @@ class Composite(
 
     _COMPOSITE_FGET = object()
 
-    def _create_descriptor(self):
+    def _create_descriptor(self) -> None:
         """Create the Python descriptor that will serve as
         the access point on instances of the mapped class.
 
         """
 
-        def fget(instance):
+        def fget(instance: Any) -> Any:
             dict_ = attributes.instance_dict(instance)
             state = attributes.instance_state(instance)
 
@@ -251,11 +296,11 @@ class Composite(
 
             return dict_.get(self.key, None)
 
-        def fset(instance, value):
+        def fset(instance: Any, value: Any) -> None:
             dict_ = attributes.instance_dict(instance)
             state = attributes.instance_state(instance)
             attr = state.manager[self.key]
-            previous = dict_.get(self.key, attributes.NO_VALUE)
+            previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE)
             for fn in attr.dispatch.set:
                 value = fn(state, value, previous, attr.impl)
             dict_[self.key] = value
@@ -269,10 +314,10 @@ class Composite(
                 ):
                     setattr(instance, key, value)
 
-        def fdel(instance):
+        def fdel(instance: Any) -> None:
             state = attributes.instance_state(instance)
             dict_ = attributes.instance_dict(instance)
-            previous = dict_.pop(self.key, attributes.NO_VALUE)
+            previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE)
             attr = state.manager[self.key]
             attr.dispatch.remove(state, previous, attr.impl)
             for key in self._attribute_keys:
@@ -282,8 +327,13 @@ class Composite(
 
     @util.preload_module("sqlalchemy.orm.properties")
     def declarative_scan(
-        self, registry, cls, key, annotation, is_dataclass_field
-    ):
+        self,
+        registry: _RegistryType,
+        cls: Type[Any],
+        key: str,
+        annotation: Optional[_AnnotationScanType],
+        is_dataclass_field: bool,
+    ) -> None:
         MappedColumn = util.preloaded.orm_properties.MappedColumn
 
         argument = _extract_mapped_subtype(
@@ -310,7 +360,9 @@ class Composite(
 
     @util.preload_module("sqlalchemy.orm.properties")
     @util.preload_module("sqlalchemy.orm.decl_base")
-    def _setup_for_dataclass(self, registry, cls, key):
+    def _setup_for_dataclass(
+        self, registry: _RegistryType, cls: Type[Any], key: str
+    ) -> None:
         MappedColumn = util.preloaded.orm_properties.MappedColumn
 
         decl_base = util.preloaded.orm_decl_base
@@ -341,12 +393,12 @@ class Composite(
                 self._generated_composite_accessor = getter
 
     @util.memoized_property
-    def _comparable_elements(self):
+    def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
         return [getattr(self.parent.class_, prop.key) for prop in self.props]
 
     @util.memoized_property
     @util.preload_module("orm.properties")
-    def props(self):
+    def props(self) -> Sequence[MapperProperty[Any]]:
         props = []
         MappedColumn = util.preloaded.orm_properties.MappedColumn
 
@@ -360,17 +412,20 @@ class Composite(
             elif isinstance(attr, attributes.InstrumentedAttribute):
                 prop = attr.property
             else:
+                prop = None
+
+            if not isinstance(prop, MapperProperty):
                 raise sa_exc.ArgumentError(
                     "Composite expects Column objects or mapped "
-                    "attributes/attribute names as arguments, got: %r"
-                    % (attr,)
+                    f"attributes/attribute names as arguments, got: {attr!r}"
                 )
+
             props.append(prop)
         return props
 
-    @property
+    @util.non_memoized_property
     @util.preload_module("orm.properties")
-    def columns(self):
+    def columns(self) -> Sequence[Column[Any]]:
         MappedColumn = util.preloaded.orm_properties.MappedColumn
         return [
             a.column if isinstance(a, MappedColumn) else a
@@ -379,32 +434,46 @@ class Composite(
         ]
 
     @property
-    def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
+    def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]:
         return self
 
     @property
-    def columns_to_assign(self) -> List[schema.Column]:
+    def columns_to_assign(self) -> List[schema.Column[Any]]:
         return [c for c in self.columns if c.table is None]
 
-    def _setup_arguments_on_columns(self):
+    @util.preload_module("orm.properties")
+    def _setup_arguments_on_columns(self) -> None:
         """Propagate configuration arguments made on this composite
         to the target columns, for those that apply.
 
         """
+        ColumnProperty = util.preloaded.orm_properties.ColumnProperty
+
         for prop in self.props:
-            prop.active_history = self.active_history
+            if not isinstance(prop, ColumnProperty):
+                continue
+            else:
+                cprop = prop
+
+            cprop.active_history = self.active_history
             if self.deferred:
-                prop.deferred = self.deferred
-                prop.strategy_key = (("deferred", True), ("instrument", True))
-            prop.group = self.group
+                cprop.deferred = self.deferred
+                cprop.strategy_key = (("deferred", True), ("instrument", True))
+            cprop.group = self.group
 
-    def _setup_event_handlers(self):
+    def _setup_event_handlers(self) -> None:
         """Establish events that populate/expire the composite attribute."""
 
-        def load_handler(state, context):
+        def load_handler(
+            state: InstanceState[Any], context: ORMCompileState
+        ) -> None:
             _load_refresh_handler(state, context, None, is_refresh=False)
 
-        def refresh_handler(state, context, to_load):
+        def refresh_handler(
+            state: InstanceState[Any],
+            context: ORMCompileState,
+            to_load: Optional[Sequence[str]],
+        ) -> None:
             # note this corresponds to sqlalchemy.ext.mutable load_attrs()
 
             if not to_load or (
@@ -412,7 +481,12 @@ class Composite(
             ).intersection(to_load):
                 _load_refresh_handler(state, context, to_load, is_refresh=True)
 
-        def _load_refresh_handler(state, context, to_load, is_refresh):
+        def _load_refresh_handler(
+            state: InstanceState[Any],
+            context: ORMCompileState,
+            to_load: Optional[Sequence[str]],
+            is_refresh: bool,
+        ) -> None:
             dict_ = state.dict
 
             # if context indicates we are coming from the
@@ -440,11 +514,17 @@ class Composite(
                 *[state.dict[key] for key in self._attribute_keys]
             )
 
-        def expire_handler(state, keys):
+        def expire_handler(
+            state: InstanceState[Any], keys: Optional[Sequence[str]]
+        ) -> None:
             if keys is None or set(self._attribute_keys).intersection(keys):
                 state.dict.pop(self.key, None)
 
-        def insert_update_handler(mapper, connection, state):
+        def insert_update_handler(
+            mapper: Mapper[Any],
+            connection: Connection,
+            state: InstanceState[Any],
+        ) -> None:
             """After an insert or update, some columns may be expired due
             to server side defaults, or re-populated due to client side
             defaults.  Pop out the composite value here so that it
@@ -473,14 +553,19 @@ class Composite(
         # TODO: need a deserialize hook here
 
     @util.memoized_property
-    def _attribute_keys(self):
+    def _attribute_keys(self) -> Sequence[str]:
         return [prop.key for prop in self.props]
 
-    def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+    def get_history(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+    ) -> History:
         """Provided for userland code that uses attributes.get_history()."""
 
-        added = []
-        deleted = []
+        added: List[Any] = []
+        deleted: List[Any] = []
 
         has_history = False
         for prop in self.props:
@@ -508,16 +593,27 @@ class Composite(
         else:
             return attributes.History((), [self.composite_class(*added)], ())
 
-    def _comparator_factory(self, mapper):
+    def _comparator_factory(
+        self, mapper: Mapper[Any]
+    ) -> Composite.Comparator[_CC]:
         return self.comparator_factory(self, mapper)
 
-    class CompositeBundle(orm_util.Bundle):
-        def __init__(self, property_, expr):
+    class CompositeBundle(orm_util.Bundle[_T]):
+        def __init__(
+            self,
+            property_: Composite[_T],
+            expr: ClauseList,
+        ):
             self.property = property_
             super().__init__(property_.key, *expr)
 
-        def create_row_processor(self, query, procs, labels):
-            def proc(row):
+        def create_row_processor(
+            self,
+            query: Select[Any],
+            procs: Sequence[Callable[[Row[Any]], Any]],
+            labels: Sequence[str],
+        ) -> Callable[[Row[Any]], Any]:
+            def proc(row: Row[Any]) -> Any:
                 return self.property.composite_class(
                     *[proc(row) for proc in procs]
                 )
@@ -546,17 +642,19 @@ class Composite(
         # https://github.com/python/mypy/issues/4266
         __hash__ = None  # type: ignore
 
+        prop: RODescriptorReference[Composite[_PT]]
+
         @util.memoized_property
-        def clauses(self):
+        def clauses(self) -> ClauseList:
             return expression.ClauseList(
                 group=False, *self._comparable_elements
             )
 
-        def __clause_element__(self):
+        def __clause_element__(self) -> Composite.CompositeBundle[_PT]:
             return self.expression
 
         @util.memoized_property
-        def expression(self):
+        def expression(self) -> Composite.CompositeBundle[_PT]:
             clauses = self.clauses._annotate(
                 {
                     "parententity": self._parententity,
@@ -566,13 +664,19 @@ class Composite(
             )
             return Composite.CompositeBundle(self.prop, clauses)
 
-        def _bulk_update_tuples(self, value):
-            if isinstance(value, sql.elements.BindParameter):
+        def _bulk_update_tuples(
+            self, value: Any
+        ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
+            if isinstance(value, BindParameter):
                 value = value.value
 
+            values: Sequence[Any]
+
             if value is None:
                 values = [None for key in self.prop._attribute_keys]
-            elif isinstance(value, self.prop.composite_class):
+            elif isinstance(self.prop.composite_class, type) and isinstance(
+                value, self.prop.composite_class
+            ):
                 values = self.prop._composite_values_from_instance(value)
             else:
                 raise sa_exc.ArgumentError(
@@ -580,10 +684,10 @@ class Composite(
                     % (self.prop, value)
                 )
 
-            return zip(self._comparable_elements, values)
+            return list(zip(self._comparable_elements, values))
 
         @util.memoized_property
-        def _comparable_elements(self):
+        def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
             if self._adapt_to_entity:
                 return [
                     getattr(self._adapt_to_entity.entity, prop.key)
@@ -592,7 +696,8 @@ class Composite(
             else:
                 return self.prop._comparable_elements
 
-        def __eq__(self, other):
+        def __eq__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+            values: Sequence[Any]
             if other is None:
                 values = [None] * len(self.prop._comparable_elements)
             else:
@@ -601,13 +706,14 @@ class Composite(
                 a == b for a, b in zip(self.prop._comparable_elements, values)
             ]
             if self._adapt_to_entity:
+                assert self.adapter is not None
                 comparisons = [self.adapter(x) for x in comparisons]
             return sql.and_(*comparisons)
 
-        def __ne__(self, other):
+        def __ne__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
             return sql.not_(self.__eq__(other))
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.parent.class_.__name__) + "." + self.key
 
 
@@ -628,20 +734,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]):
 
     """
 
-    def _comparator_factory(self, mapper):
+    def _comparator_factory(
+        self, mapper: Mapper[Any]
+    ) -> Type[PropComparator[_T]]:
+
         comparator_callable = None
 
         for m in self.parent.iterate_to_root():
             p = m._props[self.key]
-            if not isinstance(p, ConcreteInheritedProperty):
+            if getattr(p, "comparator_factory", None) is not None:
                 comparator_callable = p.comparator_factory
                 break
-        return comparator_callable
+        assert comparator_callable is not None
+        return comparator_callable(p, mapper)  # type: ignore
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
 
-        def warn():
+        def warn() -> NoReturn:
             raise AttributeError(
                 "Concrete %s does not implement "
                 "attribute %r at the instance level.  Add "
@@ -650,13 +760,13 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]):
             )
 
         class NoninheritedConcreteProp:
-            def __set__(s, obj, value):
+            def __set__(s: Any, obj: Any, value: Any) -> NoReturn:
                 warn()
 
-            def __delete__(s, obj):
+            def __delete__(s: Any, obj: Any) -> NoReturn:
                 warn()
 
-            def __get__(s, obj, owner):
+            def __get__(s: Any, obj: Any, owner: Any) -> Any:
                 if obj is None:
                     return self.descriptor
                 warn()
@@ -682,14 +792,16 @@ class Synonym(DescriptorProperty[_T]):
 
     """
 
+    comparator_factory: Optional[Type[PropComparator[_T]]]
+
     def __init__(
         self,
-        name,
-        map_column=None,
-        descriptor=None,
-        comparator_factory=None,
-        doc=None,
-        info=None,
+        name: str,
+        map_column: Optional[bool] = None,
+        descriptor: Optional[Any] = None,
+        comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+        info: Optional[_InfoType] = None,
+        doc: Optional[str] = None,
     ):
         super().__init__()
 
@@ -697,21 +809,30 @@ class Synonym(DescriptorProperty[_T]):
         self.map_column = map_column
         self.descriptor = descriptor
         self.comparator_factory = comparator_factory
-        self.doc = doc or (descriptor and descriptor.__doc__) or None
+        if doc:
+            self.doc = doc
+        elif descriptor and descriptor.__doc__:
+            self.doc = descriptor.__doc__
+        else:
+            self.doc = None
         if info:
-            self.info = info
+            self.info.update(info)
 
         util.set_creation_order(self)
 
-    @property
-    def uses_objects(self):
-        return getattr(self.parent.class_, self.name).impl.uses_objects
+    if not TYPE_CHECKING:
+
+        @property
+        def uses_objects(self) -> bool:
+            return getattr(self.parent.class_, self.name).impl.uses_objects
 
     # TODO: when initialized, check _proxied_object,
     # emit a warning if its not a column-based property
 
     @util.memoized_property
-    def _proxied_object(self):
+    def _proxied_object(
+        self,
+    ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]:
         attr = getattr(self.parent.class_, self.name)
         if not hasattr(attr, "property") or not isinstance(
             attr.property, MapperProperty
@@ -720,7 +841,8 @@ class Synonym(DescriptorProperty[_T]):
             # hybrid or association proxy
             if isinstance(attr, attributes.QueryableAttribute):
                 return attr.comparator
-            elif isinstance(attr, operators.ColumnOperators):
+            elif isinstance(attr, SQLORMOperations):
+                # association proxy comes here
                 return attr
 
             raise sa_exc.InvalidRequestError(
@@ -730,7 +852,7 @@ class Synonym(DescriptorProperty[_T]):
             )
         return attr.property
 
-    def _comparator_factory(self, mapper):
+    def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]:
         prop = self._proxied_object
 
         if isinstance(prop, MapperProperty):
@@ -742,12 +864,17 @@ class Synonym(DescriptorProperty[_T]):
         else:
             return prop
 
-    def get_history(self, *arg, **kw):
-        attr = getattr(self.parent.class_, self.name)
-        return attr.impl.get_history(*arg, **kw)
+    def get_history(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+    ) -> History:
+        attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name)
+        return attr.impl.get_history(state, dict_, passive=passive)
 
     @util.preload_module("sqlalchemy.orm.properties")
-    def set_parent(self, parent, init):
+    def set_parent(self, parent: Mapper[Any], init: bool) -> None:
         properties = util.preloaded.orm_properties
 
         if self.map_column:
@@ -776,7 +903,7 @@ class Synonym(DescriptorProperty[_T]):
                     "%r for column %r"
                     % (self.key, self.name, self.name, self.key)
                 )
-            p = properties.ColumnProperty(
+            p: ColumnProperty[Any] = properties.ColumnProperty(
                 parent.persist_selectable.c[self.key]
             )
             parent._configure_property(self.name, p, init=init, setparent=True)
index 1b4f573b506bf63ff54377931ba84c1bf46d76cf..084ba969fb6d82d627d915aa7166fc428fc26c60 100644 (file)
@@ -16,6 +16,12 @@ basic add/delete mutation.
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Optional
+from typing import overload
+from typing import TYPE_CHECKING
+from typing import Union
+
 from . import attributes
 from . import exc as orm_exc
 from . import interfaces
@@ -23,17 +29,27 @@ from . import relationships
 from . import strategies
 from . import util as orm_util
 from .base import object_mapper
+from .base import PassiveFlag
 from .query import Query
 from .session import object_session
 from .. import exc
 from .. import log
 from .. import util
 from ..engine import result
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+    from ._typing import _InstanceDict
+    from .attributes import _AdaptedCollectionProtocol
+    from .attributes import AttributeEventToken
+    from .attributes import CollectionAdapter
+    from .base import LoaderCallableStatus
+    from .state import InstanceState
 
 
 @log.class_logger
 @relationships.Relationship.strategy_for(lazy="dynamic")
-class DynaLoader(strategies.AbstractRelationshipLoader):
+class DynaLoader(strategies.AbstractRelationshipLoader, log.Identified):
     def init_class_attribute(self, mapper):
         self.is_class_level = True
         if not self.uselist:
@@ -106,13 +122,47 @@ class DynamicAttributeImpl(
         else:
             return self.query_class(self, state)
 
+    @overload
     def get_collection(
         self,
-        state,
-        dict_,
-        user_data=None,
-        passive=attributes.PASSIVE_NO_INITIALIZE,
-    ):
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Literal[None] = ...,
+        passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+    ) -> CollectionAdapter:
+        ...
+
+    @overload
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: _AdaptedCollectionProtocol = ...,
+        passive: PassiveFlag = ...,
+    ) -> CollectionAdapter:
+        ...
+
+    @overload
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = ...,
+        passive: PassiveFlag = ...,
+    ) -> Union[
+        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+    ]:
+        ...
+
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+    ) -> Union[
+        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+    ]:
         if not passive & attributes.SQL_OK:
             data = self._get_collection_history(state, passive).added_items
         else:
@@ -170,15 +220,15 @@ class DynamicAttributeImpl(
 
     def set(
         self,
-        state,
-        dict_,
-        value,
-        initiator=None,
-        passive=attributes.PASSIVE_OFF,
-        check_old=None,
-        pop=False,
-        _adapt=True,
-    ):
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken] = None,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+        check_old: Any = None,
+        pop: bool = False,
+        _adapt: bool = True,
+    ) -> None:
         if initiator and initiator.parent_token is self.parent_token:
             return
 
index 331c224eef83b5cbfc1a5fa2929658f3a9805785..726ea79b5bdebc463c04e19b5f79802398c254f6 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
 
 """ORM event interfaces.
 
index f157919ab979bee0c92ab5cc72ada3789a8be335..57e5fe8c6e9fd917acbeb1847d799dcc0f958254 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """SQLAlchemy ORM exceptions."""
 
@@ -12,13 +11,22 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Optional
+from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
 
 from .. import exc as sa_exc
 from .. import util
 from ..exc import MultipleResultsFound  # noqa
 from ..exc import NoResultFound  # noqa
 
+if TYPE_CHECKING:
+    from .interfaces import LoaderStrategy
+    from .interfaces import MapperProperty
+    from .state import InstanceState
+
+_T = TypeVar("_T", bound=Any)
 
 NO_STATE = (AttributeError, KeyError)
 """Exception types that may be raised by instrumentation implementations."""
@@ -100,14 +108,14 @@ class UnmappedInstanceError(UnmappedError):
                     )
         UnmappedError.__init__(self, msg)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         return self.__class__, (None, self.args[0])
 
 
 class UnmappedClassError(UnmappedError):
     """An mapping operation was requested for an unknown class."""
 
-    def __init__(self, cls: Type[object], msg: Optional[str] = None):
+    def __init__(self, cls: Type[_T], msg: Optional[str] = None):
         if not msg:
             msg = _default_unmapped(cls)
         UnmappedError.__init__(self, msg)
@@ -137,7 +145,7 @@ class ObjectDeletedError(sa_exc.InvalidRequestError):
     """
 
     @util.preload_module("sqlalchemy.orm.base")
-    def __init__(self, state, msg=None):
+    def __init__(self, state: InstanceState[Any], msg: Optional[str] = None):
         base = util.preloaded.orm_base
 
         if not msg:
@@ -148,7 +156,7 @@ class ObjectDeletedError(sa_exc.InvalidRequestError):
 
         sa_exc.InvalidRequestError.__init__(self, msg)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         return self.__class__, (None, self.args[0])
 
 
@@ -161,11 +169,11 @@ class LoaderStrategyException(sa_exc.InvalidRequestError):
 
     def __init__(
         self,
-        applied_to_property_type,
-        requesting_property,
-        applies_to,
-        actual_strategy_type,
-        strategy_key,
+        applied_to_property_type: Type[Any],
+        requesting_property: MapperProperty[Any],
+        applies_to: Optional[Type[MapperProperty[Any]]],
+        actual_strategy_type: Optional[Type[LoaderStrategy]],
+        strategy_key: Tuple[Any, ...],
     ):
         if actual_strategy_type is None:
             sa_exc.InvalidRequestError.__init__(
@@ -174,6 +182,7 @@ class LoaderStrategyException(sa_exc.InvalidRequestError):
                 % (strategy_key, requesting_property),
             )
         else:
+            assert applies_to is not None
             sa_exc.InvalidRequestError.__init__(
                 self,
                 'Can\'t apply "%s" strategy to property "%s", '
@@ -188,7 +197,8 @@ class LoaderStrategyException(sa_exc.InvalidRequestError):
             )
 
 
-def _safe_cls_name(cls):
+def _safe_cls_name(cls: Type[Any]) -> str:
+    cls_name: Optional[str]
     try:
         cls_name = ".".join((cls.__module__, cls.__name__))
     except AttributeError:
@@ -199,7 +209,7 @@ def _safe_cls_name(cls):
 
 
 @util.preload_module("sqlalchemy.orm.base")
-def _default_unmapped(cls) -> Optional[str]:
+def _default_unmapped(cls: Type[Any]) -> Optional[str]:
     base = util.preloaded.orm_base
 
     try:
index d13265c56072c1df138cc005054b6ca69c1258a1..63b131a780e796105e624c9987aeb70431016342 100644 (file)
@@ -8,6 +8,7 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import cast
 from typing import Dict
 from typing import Iterable
 from typing import Iterator
@@ -15,6 +16,7 @@ from typing import List
 from typing import NoReturn
 from typing import Optional
 from typing import Set
+from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import TypeVar
 import weakref
@@ -66,7 +68,7 @@ class IdentityMap:
     ) -> Optional[_O]:
         raise NotImplementedError()
 
-    def keys(self):
+    def keys(self) -> Iterable[_IdentityKeyType[Any]]:
         return self._dict.keys()
 
     def values(self) -> Iterable[object]:
@@ -117,10 +119,10 @@ class IdentityMap:
 
 
 class WeakInstanceDict(IdentityMap):
-    _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]]
+    _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]]
 
     def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
-        state = self._dict[key]
+        state = cast("InstanceState[_O]", self._dict[key])
         o = state.obj()
         if o is None:
             raise KeyError(key)
@@ -140,6 +142,8 @@ class WeakInstanceDict(IdentityMap):
 
     def contains_state(self, state: InstanceState[Any]) -> bool:
         if state.key in self._dict:
+            if TYPE_CHECKING:
+                assert state.key is not None
             try:
                 return self._dict[state.key] is state
             except KeyError:
@@ -150,15 +154,16 @@ class WeakInstanceDict(IdentityMap):
     def replace(
         self, state: InstanceState[Any]
     ) -> Optional[InstanceState[Any]]:
+        assert state.key is not None
         if state.key in self._dict:
             try:
-                existing = self._dict[state.key]
+                existing = existing_non_none = self._dict[state.key]
             except KeyError:
                 # catch gc removed the key after we just checked for it
                 existing = None
             else:
-                if existing is not state:
-                    self._manage_removed_state(existing)
+                if existing_non_none is not state:
+                    self._manage_removed_state(existing_non_none)
                 else:
                     return None
         else:
@@ -170,6 +175,7 @@ class WeakInstanceDict(IdentityMap):
 
     def add(self, state: InstanceState[Any]) -> bool:
         key = state.key
+        assert key is not None
         # inline of self.__contains__
         if key in self._dict:
             try:
@@ -206,7 +212,7 @@ class WeakInstanceDict(IdentityMap):
         if key not in self._dict:
             return default
         try:
-            state = self._dict[key]
+            state = cast("InstanceState[_O]", self._dict[key])
         except KeyError:
             # catch gc removed the key after we just checked for it
             return default
@@ -216,13 +222,15 @@ class WeakInstanceDict(IdentityMap):
                 return default
             return o
 
-    def items(self) -> List[InstanceState[Any]]:
+    def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]:
         values = self.all_states()
         result = []
         for state in values:
             value = state.obj()
+            key = state.key
+            assert key is not None
             if value is not None:
-                result.append((state.key, value))
+                result.append((key, value))
         return result
 
     def values(self) -> List[object]:
@@ -244,28 +252,32 @@ class WeakInstanceDict(IdentityMap):
     def _fast_discard(self, state: InstanceState[Any]) -> None:
         # used by InstanceState for state being
         # GC'ed, inlines _managed_removed_state
+        key = state.key
+        assert key is not None
         try:
-            st = self._dict[state.key]
+            st = self._dict[key]
         except KeyError:
             # catch gc removed the key after we just checked for it
             pass
         else:
             if st is state:
-                self._dict.pop(state.key, None)
+                self._dict.pop(key, None)
 
     def discard(self, state: InstanceState[Any]) -> None:
         self.safe_discard(state)
 
     def safe_discard(self, state: InstanceState[Any]) -> None:
-        if state.key in self._dict:
+        key = state.key
+        if key in self._dict:
+            assert key is not None
             try:
-                st = self._dict[state.key]
+                st = self._dict[key]
             except KeyError:
                 # catch gc removed the key after we just checked for it
                 pass
             else:
                 if st is state:
-                    self._dict.pop(state.key, None)
+                    self._dict.pop(key, None)
                     self._manage_removed_state(state)
 
 
index 85b85215ea019043b3f2881957973d446a760d4e..4fa61b7ceef56cb8451b99c3d42419ca799b0314 100644 (file)
@@ -66,7 +66,7 @@ from ..util.typing import Protocol
 if TYPE_CHECKING:
     from ._typing import _RegistryType
     from .attributes import AttributeImpl
-    from .attributes import InstrumentedAttribute
+    from .attributes import QueryableAttribute
     from .collections import _AdaptedCollectionProtocol
     from .collections import _CollectionFactoryType
     from .decl_base import _MapperConfig
@@ -96,7 +96,7 @@ class _ManagerFactory(Protocol):
 
 class ClassManager(
     HasMemoized,
-    Dict[str, "InstrumentedAttribute[Any]"],
+    Dict[str, "QueryableAttribute[Any]"],
     Generic[_O],
     EventTarget,
 ):
@@ -117,7 +117,14 @@ class ClassManager(
     factory: Optional[_ManagerFactory]
 
     declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
-    registry: Optional[_RegistryType] = None
+
+    registry: _RegistryType
+
+    if not TYPE_CHECKING:
+        # starts as None during setup
+        registry = None
+
+    class_: Type[_O]
 
     _bases: List[ClassManager[Any]]
 
@@ -312,7 +319,7 @@ class ClassManager(
         else:
             return default
 
-    def _attr_has_impl(self, key):
+    def _attr_has_impl(self, key: str) -> bool:
         """Return True if the given attribute is fully initialized.
 
         i.e. has an impl.
@@ -366,7 +373,12 @@ class ClassManager(
     def dict_getter(self):
         return _default_dict_getter
 
-    def instrument_attribute(self, key, inst, propagated=False):
+    def instrument_attribute(
+        self,
+        key: str,
+        inst: QueryableAttribute[Any],
+        propagated: bool = False,
+    ) -> None:
         if propagated:
             if key in self.local_attrs:
                 return  # don't override local attr with inherited attr
@@ -429,7 +441,7 @@ class ClassManager(
             delattr(self.class_, self.MANAGER_ATTR)
 
     def install_descriptor(
-        self, key: str, inst: InstrumentedAttribute[Any]
+        self, key: str, inst: QueryableAttribute[Any]
     ) -> None:
         if key in (self.STATE_ATTR, self.MANAGER_ATTR):
             raise KeyError(
@@ -490,7 +502,11 @@ class ClassManager(
     # InstanceState management
 
     def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O:
-        instance = self.class_.__new__(self.class_)
+        # here, we would prefer _O to be bound to "object"
+        # so that mypy sees that __new__ is present.   currently
+        # it's bound to Any as there were other problems not having
+        # it that way but these can be revisited
+        instance = self.class_.__new__(self.class_)  # type: ignore
         if state is None:
             state = self._state_constructor(instance, self)
         self._state_setter(instance, state)
index c9c54c1b08dea9192976f072b1e0c600faabec43..b5569ce063e988726fe8e9b171a618bb6532b1b4 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
 
 """
 
@@ -33,6 +32,7 @@ from typing import Sequence
 from typing import Set
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -48,6 +48,7 @@ from .base import MANYTOMANY as MANYTOMANY  # noqa: F401
 from .base import MANYTOONE as MANYTOONE  # noqa: F401
 from .base import NotExtension as NotExtension  # noqa: F401
 from .base import ONETOMANY as ONETOMANY  # noqa: F401
+from .base import RelationshipDirection as RelationshipDirection  # noqa: F401
 from .base import SQLORMOperations
 from .. import ColumnElement
 from .. import inspection
@@ -59,7 +60,7 @@ from ..sql.base import ExecutableOption
 from ..sql.cache_key import HasCacheKey
 from ..sql.schema import Column
 from ..sql.type_api import TypeEngine
-from ..util.typing import DescriptorReference
+from ..util.typing import RODescriptorReference
 from ..util.typing import TypedDict
 
 if typing.TYPE_CHECKING:
@@ -75,13 +76,11 @@ if typing.TYPE_CHECKING:
     from .loading import _PopulatorDict
     from .mapper import Mapper
     from .path_registry import AbstractEntityRegistry
-    from .path_registry import PathRegistry
     from .query import Query
     from .session import Session
     from .state import InstanceState
     from .strategy_options import _LoadElement
     from .util import AliasedInsp
-    from .util import CascadeOptions
     from .util import ORMAdapter
     from ..engine.result import Result
     from ..sql._typing import _ColumnExpressionArgument
@@ -89,8 +88,10 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _DMLColumnArgument
     from ..sql._typing import _InfoType
     from ..sql.operators import OperatorType
-    from ..sql.util import ColumnAdapter
     from ..sql.visitors import _TraverseInternalsType
+    from ..util.typing import _AnnotationScanType
+
+_StrategyKey = Tuple[Any, ...]
 
 _T = TypeVar("_T", bound=Any)
 
@@ -104,7 +105,9 @@ class ORMStatementRole(roles.StatementRole):
     )
 
 
-class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]):
+class ORMColumnsClauseRole(
+    roles.ColumnsClauseRole, roles.TypedColumnsClauseRole[_T]
+):
     __slots__ = ()
     _role_name = "ORM mapped entity, aliased entity, or Column expression"
 
@@ -137,8 +140,8 @@ class _IntrospectsAnnotations:
         registry: RegistryType,
         cls: Type[Any],
         key: str,
-        annotation: Optional[Type[Any]],
-        is_dataclass_field: Optional[bool],
+        annotation: Optional[_AnnotationScanType],
+        is_dataclass_field: bool,
     ) -> None:
         """Perform class-specific initializaton at early declarative scanning
         time.
@@ -199,6 +202,7 @@ class MapperProperty(
         "parent",
         "key",
         "info",
+        "doc",
     )
 
     _cache_key_traversal: _TraverseInternalsType = [
@@ -206,14 +210,8 @@ class MapperProperty(
         ("key", visitors.ExtendedInternalTraversal.dp_string),
     ]
 
-    cascade: Optional[CascadeOptions] = None
-    """The set of 'cascade' attribute names.
-
-    This collection is checked before the 'cascade_iterator' method is called.
-
-    The collection typically only applies to a Relationship.
-
-    """
+    if not TYPE_CHECKING:
+        cascade = None
 
     is_property = True
     """Part of the InspectionAttr interface; states this object is a
@@ -240,6 +238,9 @@ class MapperProperty(
 
     """
 
+    doc: Optional[str]
+    """optional documentation string"""
+
     def _memoized_attr_info(self) -> _InfoType:
         """Info dictionary associated with the object, allowing user-defined
         data to be associated with this :class:`.InspectionAttr`.
@@ -268,8 +269,8 @@ class MapperProperty(
         self,
         context: ORMCompileState,
         query_entity: _MapperEntity,
-        path: PathRegistry,
-        adapter: Optional[ColumnAdapter],
+        path: AbstractEntityRegistry,
+        adapter: Optional[ORMAdapter],
         **kwargs: Any,
     ) -> None:
         """Called by Query for the purposes of constructing a SQL statement.
@@ -284,10 +285,10 @@ class MapperProperty(
         self,
         context: ORMCompileState,
         query_entity: _MapperEntity,
-        path: PathRegistry,
+        path: AbstractEntityRegistry,
         mapper: Mapper[Any],
         result: Result[Any],
-        adapter: Optional[ColumnAdapter],
+        adapter: Optional[ORMAdapter],
         populators: _PopulatorDict,
     ) -> None:
         """Produce row processing functions and append to the given
@@ -421,7 +422,7 @@ class MapperProperty(
         dest_state: InstanceState[Any],
         dest_dict: _InstanceDict,
         load: bool,
-        _recursive: Set[InstanceState[Any]],
+        _recursive: Dict[Any, object],
         _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
     ) -> None:
         """Merge the attribute represented by this ``MapperProperty``
@@ -526,7 +527,7 @@ class PropComparator(SQLORMOperations[_T]):
 
     _parententity: _InternalEntityType[Any]
     _adapt_to_entity: Optional[AliasedInsp[Any]]
-    prop: DescriptorReference[MapperProperty[_T]]
+    prop: RODescriptorReference[MapperProperty[_T]]
 
     def __init__(
         self,
@@ -539,7 +540,7 @@ class PropComparator(SQLORMOperations[_T]):
         self._adapt_to_entity = adapt_to_entity
 
     @util.non_memoized_property
-    def property(self) -> Optional[MapperProperty[_T]]:
+    def property(self) -> MapperProperty[_T]:
         """Return the :class:`.MapperProperty` associated with this
         :class:`.PropComparator`.
 
@@ -589,7 +590,7 @@ class PropComparator(SQLORMOperations[_T]):
         return self.prop.comparator._criterion_exists(criterion, **kwargs)
 
     @util.ro_non_memoized_property
-    def adapter(self) -> Optional[_ORMAdapterProto[_T]]:
+    def adapter(self) -> Optional[_ORMAdapterProto]:
         """Produce a callable that adapts column expressions
         to suit an aliased version of this comparator.
 
@@ -597,7 +598,7 @@ class PropComparator(SQLORMOperations[_T]):
         if self._adapt_to_entity is None:
             return None
         else:
-            return self._adapt_to_entity._adapt_element
+            return self._adapt_to_entity._orm_adapt_element
 
     @util.ro_non_memoized_property
     def info(self) -> _InfoType:
@@ -631,7 +632,7 @@ class PropComparator(SQLORMOperations[_T]):
         ) -> ColumnElement[Any]:
             ...
 
-    def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
+    def of_type(self, class_: _EntityType[_T]) -> PropComparator[_T]:
         r"""Redefine this object in terms of a polymorphic subclass,
         :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased`
         construct.
@@ -763,9 +764,9 @@ class StrategizedProperty(MapperProperty[_T]):
     inherit_cache = True
     strategy_wildcard_key: ClassVar[str]
 
-    strategy_key: Tuple[Any, ...]
+    strategy_key: _StrategyKey
 
-    _strategies: Dict[Tuple[Any, ...], LoaderStrategy]
+    _strategies: Dict[_StrategyKey, LoaderStrategy]
 
     def _memoized_attr__wildcard_token(self) -> Tuple[str]:
         return (
@@ -808,7 +809,7 @@ class StrategizedProperty(MapperProperty[_T]):
 
         return load
 
-    def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy:
+    def _get_strategy(self, key: _StrategyKey) -> LoaderStrategy:
         try:
             return self._strategies[key]
         except KeyError:
@@ -822,7 +823,14 @@ class StrategizedProperty(MapperProperty[_T]):
         self._strategies[key] = strategy = cls(self, key)
         return strategy
 
-    def setup(self, context, query_entity, path, adapter, **kwargs):
+    def setup(
+        self,
+        context: ORMCompileState,
+        query_entity: _MapperEntity,
+        path: AbstractEntityRegistry,
+        adapter: Optional[ORMAdapter],
+        **kwargs: Any,
+    ) -> None:
         loader = self._get_context_loader(context, path)
         if loader and loader.strategy:
             strat = self._get_strategy(loader.strategy)
@@ -833,8 +841,15 @@ class StrategizedProperty(MapperProperty[_T]):
         )
 
     def create_row_processor(
-        self, context, query_entity, path, mapper, result, adapter, populators
-    ):
+        self,
+        context: ORMCompileState,
+        query_entity: _MapperEntity,
+        path: AbstractEntityRegistry,
+        mapper: Mapper[Any],
+        result: Result[Any],
+        adapter: Optional[ORMAdapter],
+        populators: _PopulatorDict,
+    ) -> None:
         loader = self._get_context_loader(context, path)
         if loader and loader.strategy:
             strat = self._get_strategy(loader.strategy)
@@ -851,11 +866,11 @@ class StrategizedProperty(MapperProperty[_T]):
             populators,
         )
 
-    def do_init(self):
+    def do_init(self) -> None:
         self._strategies = {}
         self.strategy = self._get_strategy(self.strategy_key)
 
-    def post_instrument_class(self, mapper):
+    def post_instrument_class(self, mapper: Mapper[Any]) -> None:
         if (
             not self.parent.non_primary
             and not mapper.class_manager._attr_has_impl(self.key)
@@ -863,7 +878,7 @@ class StrategizedProperty(MapperProperty[_T]):
             self.strategy.init_class_attribute(mapper)
 
     _all_strategies: collections.defaultdict[
-        Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]]
+        Type[MapperProperty[Any]], Dict[_StrategyKey, Type[LoaderStrategy]]
     ] = collections.defaultdict(dict)
 
     @classmethod
@@ -888,6 +903,8 @@ class StrategizedProperty(MapperProperty[_T]):
 
         for prop_cls in cls.__mro__:
             if prop_cls in cls._all_strategies:
+                if TYPE_CHECKING:
+                    assert issubclass(prop_cls, MapperProperty)
                 strategies = cls._all_strategies[prop_cls]
                 try:
                     return strategies[key]
@@ -976,8 +993,8 @@ class CompileStateOption(HasCacheKey, ORMOption):
 
     _is_compile_state = True
 
-    def process_compile_state(self, compile_state):
-        """Apply a modification to a given :class:`.CompileState`.
+    def process_compile_state(self, compile_state: ORMCompileState) -> None:
+        """Apply a modification to a given :class:`.ORMCompileState`.
 
         This method is part of the implementation of a particular
         :class:`.CompileStateOption` and is only invoked internally
@@ -986,9 +1003,11 @@ class CompileStateOption(HasCacheKey, ORMOption):
         """
 
     def process_compile_state_replaced_entities(
-        self, compile_state, mapper_entities
-    ):
-        """Apply a modification to a given :class:`.CompileState`,
+        self,
+        compile_state: ORMCompileState,
+        mapper_entities: Sequence[_MapperEntity],
+    ) -> None:
+        """Apply a modification to a given :class:`.ORMCompileState`,
         given entities that were replaced by with_only_columns() or
         with_entities().
 
@@ -1011,8 +1030,10 @@ class LoaderOption(CompileStateOption):
     __slots__ = ()
 
     def process_compile_state_replaced_entities(
-        self, compile_state, mapper_entities
-    ):
+        self,
+        compile_state: ORMCompileState,
+        mapper_entities: Sequence[_MapperEntity],
+    ) -> None:
         self.process_compile_state(compile_state)
 
 
@@ -1028,7 +1049,7 @@ class CriteriaOption(CompileStateOption):
 
     _is_criteria_option = True
 
-    def get_global_criteria(self, attributes):
+    def get_global_criteria(self, attributes: Dict[str, Any]) -> None:
         """update additional entity criteria options in the given
         attributes dictionary.
 
@@ -1054,7 +1075,7 @@ class UserDefinedOption(ORMOption):
 
     """
 
-    def __init__(self, payload=None):
+    def __init__(self, payload: Optional[Any] = None):
         self.payload = payload
 
 
@@ -1132,10 +1153,10 @@ class LoaderStrategy:
         "strategy_opts",
     )
 
-    _strategy_keys: ClassVar[List[Tuple[Any, ...]]]
+    _strategy_keys: ClassVar[List[_StrategyKey]]
 
     def __init__(
-        self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...]
+        self, parent: MapperProperty[Any], strategy_key: _StrategyKey
     ):
         self.parent_property = parent
         self.is_class_level = False
@@ -1186,5 +1207,5 @@ class LoaderStrategy:
 
         """
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.parent_property)
index 75887367e72fad0001867813e253246cb54980ea..1a5ea5fe651b6332222ba03a4f4d25cff00637d9 100644 (file)
@@ -54,11 +54,15 @@ from ..sql.selectable import SelectState
 if TYPE_CHECKING:
     from ._typing import _IdentityKeyType
     from .base import LoaderCallableStatus
+    from .context import QueryContext
     from .interfaces import ORMOption
     from .mapper import Mapper
+    from .query import Query
     from .session import Session
     from .state import InstanceState
+    from ..engine.cursor import CursorResult
     from ..engine.interfaces import _ExecuteOptions
+    from ..engine.result import Result
     from ..sql import Select
 
 _T = TypeVar("_T", bound=Any)
@@ -69,7 +73,7 @@ _new_runid = util.counter()
 _PopulatorDict = Dict[str, List[Tuple[str, Any]]]
 
 
-def instances(cursor, context):
+def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]:
     """Return a :class:`.Result` given an ORM query context.
 
     :param cursor: a :class:`.CursorResult`, generated by a statement
@@ -152,7 +156,7 @@ def instances(cursor, context):
         unique_filters = [
             _no_unique
             if context.yield_per
-            else _not_hashable(ent.column.type)
+            else _not_hashable(ent.column.type)  # type: ignore
             if (not ent.use_id_for_hash and ent._non_hashable_value)
             else id
             if ent.use_id_for_hash
@@ -164,7 +168,7 @@ def instances(cursor, context):
         labels, extra, _unique_filters=unique_filters
     )
 
-    def chunks(size):
+    def chunks(size):  # type: ignore
         while True:
             yield_per = size
 
@@ -302,7 +306,11 @@ def merge_frozen_result(session, statement, frozen_result, load=True):
     "is superseded by the :func:`_orm.merge_frozen_result` function.",
 )
 @util.preload_module("sqlalchemy.orm.context")
-def merge_result(query, iterator, load=True):
+def merge_result(
+    query: Query[Any],
+    iterator: Union[FrozenResult, Iterable[Sequence[Any]], Iterable[object]],
+    load: bool = True,
+) -> Union[FrozenResult, Iterable[Any]]:
     """Merge a result into the given :class:`.Query` object's Session.
 
     See :meth:`_orm.Query.merge_result` for top-level documentation on this
@@ -375,7 +383,7 @@ def merge_result(query, iterator, load=True):
                 result.append(keyed_tuple(newrow))
 
         if frozen_result:
-            return frozen_result.with_data(result)
+            return frozen_result.with_new_rows(result)
         else:
             return iter(result)
     finally:
index 4324a000d184984f784aae4b98e12373b06b2b83..d1057ca5f3b45f4ccd4278e93f244c81825d916f 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
 
 from __future__ import annotations
 
index 337a7178b0dec7d8914129e1490f74a87731fdc2..2d3bceb92807df9a3c16b6d3f64166af47ac29d5 100644 (file)
@@ -80,6 +80,7 @@ from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql.cache_key import MemoizedHasCacheKey
+from ..sql.elements import KeyedColumnElement
 from ..sql.schema import Table
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..util import HasMemoized
@@ -108,7 +109,6 @@ if TYPE_CHECKING:
     from ..sql.base import ReadOnlyColumnCollection
     from ..sql.elements import ColumnClause
     from ..sql.elements import ColumnElement
-    from ..sql.elements import KeyedColumnElement
     from ..sql.schema import Column
     from ..sql.selectable import FromClause
     from ..sql.util import ColumnAdapter
@@ -182,6 +182,7 @@ class Mapper(
     dispatch: dispatcher[Mapper[_O]]
 
     _dispose_called = False
+    _configure_failed: Any = False
     _ready_for_configure = False
 
     @util.deprecated_params(
@@ -710,8 +711,11 @@ class Mapper(
         self.batch = batch
         self.eager_defaults = eager_defaults
         self.column_prefix = column_prefix
-        self.polymorphic_on = (
-            coercions.expect(
+
+        # interim - polymorphic_on is further refined in
+        # _configure_polymorphic_setter
+        self.polymorphic_on = (  # type: ignore
+            coercions.expect(  # type: ignore
                 roles.ColumnArgumentOrKeyRole,
                 polymorphic_on,
                 argname="polymorphic_on",
@@ -1832,12 +1836,22 @@ class Mapper(
                 )
 
     @util.preload_module("sqlalchemy.orm.descriptor_props")
-    def _configure_property(self, key, prop, init=True, setparent=True):
+    def _configure_property(
+        self,
+        key: str,
+        prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]],
+        init: bool = True,
+        setparent: bool = True,
+    ) -> MapperProperty[Any]:
         descriptor_props = util.preloaded.orm_descriptor_props
-        self._log("_configure_property(%s, %s)", key, prop.__class__.__name__)
+        self._log(
+            "_configure_property(%s, %s)", key, prop_arg.__class__.__name__
+        )
 
-        if not isinstance(prop, MapperProperty):
-            prop = self._property_from_column(key, prop)
+        if not isinstance(prop_arg, MapperProperty):
+            prop = self._property_from_column(key, prop_arg)
+        else:
+            prop = prop_arg
 
         if isinstance(prop, properties.ColumnProperty):
             col = self.persist_selectable.corresponding_column(prop.columns[0])
@@ -1950,18 +1964,23 @@ class Mapper(
         if self.configured:
             self._expire_memoizations()
 
+        return prop
+
     @util.preload_module("sqlalchemy.orm.descriptor_props")
-    def _property_from_column(self, key, prop):
+    def _property_from_column(
+        self,
+        key: str,
+        prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]],
+    ) -> MapperProperty[Any]:
         """generate/update a :class:`.ColumnProperty` given a
         :class:`_schema.Column` object."""
         descriptor_props = util.preloaded.orm_descriptor_props
         # we were passed a Column or a list of Columns;
         # generate a properties.ColumnProperty
-        columns = util.to_list(prop)
+        columns = util.to_list(prop_arg)
         column = columns[0]
-        assert isinstance(column, expression.ColumnElement)
 
-        prop = self._props.get(key, None)
+        prop = self._props.get(key)
 
         if isinstance(prop, properties.ColumnProperty):
             if (
@@ -2033,11 +2052,11 @@ class Mapper(
                 "columns get mapped." % (key, self, column.key, prop)
             )
 
-    def _check_configure(self):
+    def _check_configure(self) -> None:
         if self.registry._new_mappers:
             _configure_registries({self.registry}, cascade=True)
 
-    def _post_configure_properties(self):
+    def _post_configure_properties(self) -> None:
         """Call the ``init()`` method on all ``MapperProperties``
         attached to this mapper.
 
@@ -2068,7 +2087,9 @@ class Mapper(
         for key, value in dict_of_properties.items():
             self.add_property(key, value)
 
-    def add_property(self, key, prop):
+    def add_property(
+        self, key: str, prop: Union[Column[Any], MapperProperty[Any]]
+    ) -> None:
         """Add an individual MapperProperty to this mapper.
 
         If the mapper has not been configured yet, just adds the
@@ -2077,15 +2098,16 @@ class Mapper(
         the given MapperProperty is configured immediately.
 
         """
+        prop = self._configure_property(key, prop, init=self.configured)
+        assert isinstance(prop, MapperProperty)
         self._init_properties[key] = prop
-        self._configure_property(key, prop, init=self.configured)
 
-    def _expire_memoizations(self):
+    def _expire_memoizations(self) -> None:
         for mapper in self.iterate_to_root():
             mapper._reset_memoizations()
 
     @property
-    def _log_desc(self):
+    def _log_desc(self) -> str:
         return (
             "("
             + self.class_.__name__
@@ -2099,16 +2121,16 @@ class Mapper(
             + ")"
         )
 
-    def _log(self, msg, *args):
+    def _log(self, msg: str, *args: Any) -> None:
         self.logger.info("%s " + msg, *((self._log_desc,) + args))
 
-    def _log_debug(self, msg, *args):
+    def _log_debug(self, msg: str, *args: Any) -> None:
         self.logger.debug("%s " + msg, *((self._log_desc,) + args))
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "Mapper[%s%s(%s)]" % (
             self.class_.__name__,
             self.non_primary and " (non-primary)" or "",
@@ -2155,7 +2177,9 @@ class Mapper(
                 "Mapper '%s' has no property '%s'" % (self, key)
             ) from err
 
-    def get_property_by_column(self, column):
+    def get_property_by_column(
+        self, column: ColumnElement[_T]
+    ) -> MapperProperty[_T]:
         """Given a :class:`_schema.Column` object, return the
         :class:`.MapperProperty` which maps this column."""
 
@@ -2795,7 +2819,7 @@ class Mapper(
 
         return result
 
-    def _is_userland_descriptor(self, assigned_name, obj):
+    def _is_userland_descriptor(self, assigned_name: str, obj: Any) -> bool:
         if isinstance(
             obj,
             (
@@ -3603,7 +3627,9 @@ def configure_mappers():
     _configure_registries(_all_registries(), cascade=True)
 
 
-def _configure_registries(registries, cascade):
+def _configure_registries(
+    registries: Set[_RegistryType], cascade: bool
+) -> None:
     for reg in registries:
         if reg._new_mappers:
             break
@@ -3637,7 +3663,9 @@ def _configure_registries(registries, cascade):
 
 
 @util.preload_module("sqlalchemy.orm.decl_api")
-def _do_configure_registries(registries, cascade):
+def _do_configure_registries(
+    registries: Set[_RegistryType], cascade: bool
+) -> None:
 
     registry = util.preloaded.orm_decl_api.registry
 
@@ -3688,7 +3716,7 @@ def _do_configure_registries(registries, cascade):
 
 
 @util.preload_module("sqlalchemy.orm.decl_api")
-def _dispose_registries(registries, cascade):
+def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None:
 
     registry = util.preloaded.orm_decl_api.registry
 
index 361cea9757731aa326f440125539949b700d11e1..36c14a6727226bce97426909e8b4e8f4eff2c613 100644 (file)
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
     from ..sql.cache_key import _CacheKeyTraversalType
     from ..sql.elements import BindParameter
     from ..sql.visitors import anon_map
+    from ..util.typing import _LiteralStar
     from ..util.typing import TypeGuard
 
     def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]:
@@ -80,7 +81,7 @@ def _unreduce_path(path: _SerializedPath) -> PathRegistry:
     return PathRegistry.deserialize(path)
 
 
-_WILDCARD_TOKEN = "*"
+_WILDCARD_TOKEN: _LiteralStar = "*"
 _DEFAULT_TOKEN = "_sa_default"
 
 
@@ -115,6 +116,7 @@ class PathRegistry(HasCacheKey):
     is_token = False
     is_root = False
     has_entity = False
+    is_property = False
     is_entity = False
 
     path: _PathRepresentation
@@ -175,7 +177,40 @@ class PathRegistry(HasCacheKey):
     def __hash__(self) -> int:
         return id(self)
 
-    def __getitem__(self, key: Any) -> PathRegistry:
+    @overload
+    def __getitem__(self, entity: str) -> TokenRegistry:
+        ...
+
+    @overload
+    def __getitem__(self, entity: int) -> _PathElementType:
+        ...
+
+    @overload
+    def __getitem__(self, entity: slice) -> _PathRepresentation:
+        ...
+
+    @overload
+    def __getitem__(
+        self, entity: _InternalEntityType[Any]
+    ) -> AbstractEntityRegistry:
+        ...
+
+    @overload
+    def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
+        ...
+
+    def __getitem__(
+        self,
+        entity: Union[
+            str, int, slice, _InternalEntityType[Any], MapperProperty[Any]
+        ],
+    ) -> Union[
+        TokenRegistry,
+        _PathElementType,
+        _PathRepresentation,
+        PropRegistry,
+        AbstractEntityRegistry,
+    ]:
         raise NotImplementedError()
 
     # TODO: what are we using this for?
@@ -343,18 +378,8 @@ class RootRegistry(CreatesToken):
     is_root = True
     is_unnatural = False
 
-    @overload
-    def __getitem__(self, entity: str) -> TokenRegistry:
-        ...
-
-    @overload
-    def __getitem__(
-        self, entity: _InternalEntityType[Any]
-    ) -> AbstractEntityRegistry:
-        ...
-
-    def __getitem__(
-        self, entity: Union[str, _InternalEntityType[Any]]
+    def _getitem(
+        self, entity: Any
     ) -> Union[TokenRegistry, AbstractEntityRegistry]:
         if entity in PathToken._intern:
             if TYPE_CHECKING:
@@ -368,6 +393,9 @@ class RootRegistry(CreatesToken):
                     f"invalid argument for RootRegistry.__getitem__: {entity}"
                 )
 
+    if not TYPE_CHECKING:
+        __getitem__ = _getitem
+
 
 PathRegistry.root = RootRegistry()
 
@@ -441,12 +469,15 @@ class TokenRegistry(PathRegistry):
         else:
             yield self
 
-    def __getitem__(self, entity: Any) -> Any:
+    def _getitem(self, entity: Any) -> Any:
         try:
             return self.path[entity]
         except TypeError as err:
             raise IndexError(f"{entity}") from err
 
+    if not TYPE_CHECKING:
+        __getitem__ = _getitem
+
 
 class PropRegistry(PathRegistry):
     __slots__ = (
@@ -463,6 +494,7 @@ class PropRegistry(PathRegistry):
         "is_unnatural",
     )
     inherit_cache = True
+    is_property = True
 
     prop: MapperProperty[Any]
     mapper: Optional[Mapper[Any]]
@@ -557,21 +589,7 @@ class PropRegistry(PathRegistry):
         assert self.entity is not None
         return self[self.entity]
 
-    @overload
-    def __getitem__(self, entity: slice) -> _PathRepresentation:
-        ...
-
-    @overload
-    def __getitem__(self, entity: int) -> _PathElementType:
-        ...
-
-    @overload
-    def __getitem__(
-        self, entity: _InternalEntityType[Any]
-    ) -> AbstractEntityRegistry:
-        ...
-
-    def __getitem__(
+    def _getitem(
         self, entity: Union[int, slice, _InternalEntityType[Any]]
     ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]:
         if isinstance(entity, (int, slice)):
@@ -579,6 +597,9 @@ class PropRegistry(PathRegistry):
         else:
             return SlotsEntityRegistry(self, entity)
 
+    if not TYPE_CHECKING:
+        __getitem__ = _getitem
+
 
 class AbstractEntityRegistry(CreatesToken):
     __slots__ = (
@@ -642,6 +663,10 @@ class AbstractEntityRegistry(CreatesToken):
             # self.natural_path = parent.natural_path + (entity, )
             self.natural_path = self.path
 
+    @property
+    def root_entity(self) -> _InternalEntityType[Any]:
+        return cast("_InternalEntityType[Any]", self.path[0])
+
     @property
     def entity_path(self) -> PathRegistry:
         return self
@@ -653,23 +678,7 @@ class AbstractEntityRegistry(CreatesToken):
     def __bool__(self) -> bool:
         return True
 
-    @overload
-    def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
-        ...
-
-    @overload
-    def __getitem__(self, entity: str) -> TokenRegistry:
-        ...
-
-    @overload
-    def __getitem__(self, entity: int) -> _PathElementType:
-        ...
-
-    @overload
-    def __getitem__(self, entity: slice) -> _PathRepresentation:
-        ...
-
-    def __getitem__(
+    def _getitem(
         self, entity: Any
     ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]:
         if isinstance(entity, (int, slice)):
@@ -679,6 +688,9 @@ class AbstractEntityRegistry(CreatesToken):
         else:
             return PropRegistry(self, entity)
 
+    if not TYPE_CHECKING:
+        __getitem__ = _getitem
+
 
 class SlotsEntityRegistry(AbstractEntityRegistry):
     # for aliased class, return lightweight, no-cycles created
@@ -715,10 +727,28 @@ class CachingEntityRegistry(AbstractEntityRegistry):
     def pop(self, key: Any, default: Any) -> Any:
         return self._cache.pop(key, default)
 
-    def __getitem__(self, entity: Any) -> Any:
+    def _getitem(self, entity: Any) -> Any:
         if isinstance(entity, (int, slice)):
             return self.path[entity]
         elif isinstance(entity, PathToken):
             return TokenRegistry(self, entity)
         else:
             return self._cache[entity]
+
+    if not TYPE_CHECKING:
+        __getitem__ = _getitem
+
+
+if TYPE_CHECKING:
+
+    def path_is_entity(
+        path: PathRegistry,
+    ) -> TypeGuard[AbstractEntityRegistry]:
+        ...
+
+    def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]:
+        ...
+
+else:
+    path_is_entity = operator.attrgetter("is_entity")
+    path_is_property = operator.attrgetter("is_property")
index 0ca0559b4523b9c73ca56bde3812f7332cb61d16..911617d6d0e1432acc763b97c2a4677ff8a67182 100644 (file)
@@ -16,8 +16,10 @@ from __future__ import annotations
 
 from typing import Any
 from typing import cast
+from typing import Dict
 from typing import List
 from typing import Optional
+from typing import Sequence
 from typing import Set
 from typing import Type
 from typing import TYPE_CHECKING
@@ -25,7 +27,6 @@ from typing import TypeVar
 
 from . import attributes
 from . import strategy_options
-from .base import SQLCoreOperations
 from .descriptor_props import Composite
 from .descriptor_props import ConcreteInheritedProperty
 from .descriptor_props import Synonym
@@ -44,20 +45,34 @@ from .. import util
 from ..sql import coercions
 from ..sql import roles
 from ..sql import sqltypes
+from ..sql.elements import SQLCoreOperations
 from ..sql.schema import Column
 from ..sql.schema import SchemaConst
 from ..util.typing import de_optionalize_union_types
 from ..util.typing import de_stringify_annotation
 from ..util.typing import is_fwd_ref
 from ..util.typing import NoneType
+from ..util.typing import Self
 
 if TYPE_CHECKING:
+    from ._typing import _IdentityKeyType
+    from ._typing import _InstanceDict
     from ._typing import _ORMColumnExprArgument
+    from ._typing import _RegistryType
+    from .mapper import Mapper
+    from .session import Session
+    from .state import _InstallLoaderCallableProto
+    from .state import InstanceState
     from ..sql._typing import _InfoType
-    from ..sql.elements import KeyedColumnElement
+    from ..sql.elements import ColumnElement
+    from ..sql.elements import NamedColumn
+    from ..sql.operators import OperatorType
+    from ..util.typing import _AnnotationScanType
+    from ..util.typing import RODescriptorReference
 
 _T = TypeVar("_T", bound=Any)
 _PT = TypeVar("_PT", bound=Any)
+_NC = TypeVar("_NC", bound="NamedColumn[Any]")
 
 __all__ = [
     "ColumnProperty",
@@ -85,11 +100,15 @@ class ColumnProperty(
     inherit_cache = True
     _links_to_entity = False
 
-    columns: List[KeyedColumnElement[Any]]
-    _orig_columns: List[KeyedColumnElement[Any]]
+    columns: List[NamedColumn[Any]]
+    _orig_columns: List[NamedColumn[Any]]
 
     _is_polymorphic_discriminator: bool
 
+    _mapped_by_synonym: Optional[str]
+
+    comparator_factory: Type[PropComparator[_T]]
+
     __slots__ = (
         "_orig_columns",
         "columns",
@@ -100,7 +119,6 @@ class ColumnProperty(
         "descriptor",
         "active_history",
         "expire_on_flush",
-        "doc",
         "_creation_order",
         "_is_polymorphic_discriminator",
         "_mapped_by_synonym",
@@ -117,7 +135,7 @@ class ColumnProperty(
         group: Optional[str] = None,
         deferred: bool = False,
         raiseload: bool = False,
-        comparator_factory: Optional[Type[PropComparator]] = None,
+        comparator_factory: Optional[Type[PropComparator[_T]]] = None,
         descriptor: Optional[Any] = None,
         active_history: bool = False,
         expire_on_flush: bool = True,
@@ -150,7 +168,7 @@ class ColumnProperty(
         self.expire_on_flush = expire_on_flush
 
         if info is not None:
-            self.info = info
+            self.info.update(info)
 
         if doc is not None:
             self.doc = doc
@@ -173,8 +191,13 @@ class ColumnProperty(
             self.strategy_key += (("raiseload", True),)
 
     def declarative_scan(
-        self, registry, cls, key, annotation, is_dataclass_field
-    ):
+        self,
+        registry: _RegistryType,
+        cls: Type[Any],
+        key: str,
+        annotation: Optional[_AnnotationScanType],
+        is_dataclass_field: bool,
+    ) -> None:
         column = self.columns[0]
         if column.key is None:
             column.key = key
@@ -186,20 +209,23 @@ class ColumnProperty(
         return self
 
     @property
-    def columns_to_assign(self) -> List[Column]:
+    def columns_to_assign(self) -> List[Column[Any]]:
+        # mypy doesn't care about the isinstance here
         return [
-            c
+            c  # type: ignore
             for c in self.columns
             if isinstance(c, Column) and c.table is None
         ]
 
-    def _memoized_attr__renders_in_subqueries(self):
+    def _memoized_attr__renders_in_subqueries(self) -> bool:
         return ("deferred", True) not in self.strategy_key or (
-            self not in self.parent._readonly_props
+            self not in self.parent._readonly_props  # type: ignore
         )
 
     @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
-    def _memoized_attr__deferred_column_loader(self):
+    def _memoized_attr__deferred_column_loader(
+        self,
+    ) -> _InstallLoaderCallableProto[Any]:
         state = util.preloaded.orm_state
         strategies = util.preloaded.orm_strategies
         return state.InstanceState._instance_level_callable_processor(
@@ -209,7 +235,9 @@ class ColumnProperty(
         )
 
     @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
-    def _memoized_attr__raise_column_loader(self):
+    def _memoized_attr__raise_column_loader(
+        self,
+    ) -> _InstallLoaderCallableProto[Any]:
         state = util.preloaded.orm_state
         strategies = util.preloaded.orm_strategies
         return state.InstanceState._instance_level_callable_processor(
@@ -218,7 +246,7 @@ class ColumnProperty(
             self.key,
         )
 
-    def __clause_element__(self):
+    def __clause_element__(self) -> roles.ColumnsClauseRole:
         """Allow the ColumnProperty to work in expression before it is turned
         into an instrumented attribute.
         """
@@ -226,7 +254,7 @@ class ColumnProperty(
         return self.expression
 
     @property
-    def expression(self):
+    def expression(self) -> roles.ColumnsClauseRole:
         """Return the primary column or expression for this ColumnProperty.
 
         E.g.::
@@ -247,7 +275,7 @@ class ColumnProperty(
         """
         return self.columns[0]
 
-    def instrument_class(self, mapper):
+    def instrument_class(self, mapper: Mapper[Any]) -> None:
         if not self.instrument:
             return
 
@@ -259,7 +287,7 @@ class ColumnProperty(
             doc=self.doc,
         )
 
-    def do_init(self):
+    def do_init(self) -> None:
         super().do_init()
 
         if len(self.columns) > 1 and set(self.parent.primary_key).issuperset(
@@ -275,32 +303,25 @@ class ColumnProperty(
                 % (self.parent, self.columns[1], self.columns[0], self.key)
             )
 
-    def copy(self):
+    def copy(self) -> ColumnProperty[_T]:
         return ColumnProperty(
+            *self.columns,
             deferred=self.deferred,
             group=self.group,
             active_history=self.active_history,
-            *self.columns,
-        )
-
-    def _getcommitted(
-        self, state, dict_, column, passive=attributes.PASSIVE_OFF
-    ):
-        return state.get_impl(self.key).get_committed_value(
-            state, dict_, passive=passive
         )
 
     def merge(
         self,
-        session,
-        source_state,
-        source_dict,
-        dest_state,
-        dest_dict,
-        load,
-        _recursive,
-        _resolve_conflict_map,
-    ):
+        session: Session,
+        source_state: InstanceState[Any],
+        source_dict: _InstanceDict,
+        dest_state: InstanceState[Any],
+        dest_dict: _InstanceDict,
+        load: bool,
+        _recursive: Dict[Any, object],
+        _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+    ) -> None:
         if not self.instrument:
             return
         elif self.key in source_dict:
@@ -335,9 +356,13 @@ class ColumnProperty(
 
         """
 
-        __slots__ = "__clause_element__", "info", "expressions"
+        if not TYPE_CHECKING:
+            # prevent pylance from being clever about slots
+            __slots__ = "__clause_element__", "info", "expressions"
+
+        prop: RODescriptorReference[ColumnProperty[_PT]]
 
-        def _orm_annotate_column(self, column):
+        def _orm_annotate_column(self, column: _NC) -> _NC:
             """annotate and possibly adapt a column to be returned
             as the mapped-attribute exposed version of the column.
 
@@ -351,7 +376,7 @@ class ColumnProperty(
             """
 
             pe = self._parententity
-            annotations = {
+            annotations: Dict[str, Any] = {
                 "entity_namespace": pe,
                 "parententity": pe,
                 "parentmapper": pe,
@@ -377,22 +402,29 @@ class ColumnProperty(
                 {"compile_state_plugin": "orm", "plugin_subject": pe}
             )
 
-        def _memoized_method___clause_element__(self):
+        if TYPE_CHECKING:
+
+            def __clause_element__(self) -> NamedColumn[_PT]:
+                ...
+
+        def _memoized_method___clause_element__(
+            self,
+        ) -> NamedColumn[_PT]:
             if self.adapter:
                 return self.adapter(self.prop.columns[0], self.prop.key)
             else:
                 return self._orm_annotate_column(self.prop.columns[0])
 
-        def _memoized_attr_info(self):
+        def _memoized_attr_info(self) -> _InfoType:
             """The .info dictionary for this attribute."""
 
             ce = self.__clause_element__()
             try:
-                return ce.info
+                return ce.info  # type: ignore
             except AttributeError:
                 return self.prop.info
 
-        def _memoized_attr_expressions(self):
+        def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]:
             """The full sequence of columns referenced by this
             attribute, adjusted for any aliasing in progress.
 
@@ -409,21 +441,25 @@ class ColumnProperty(
                     self._orm_annotate_column(col) for col in self.prop.columns
                 ]
 
-        def _fallback_getattr(self, key):
+        def _fallback_getattr(self, key: str) -> Any:
             """proxy attribute access down to the mapped column.
 
             this allows user-defined comparison methods to be accessed.
             """
             return getattr(self.__clause_element__(), key)
 
-        def operate(self, op, *other, **kwargs):
-            return op(self.__clause_element__(), *other, **kwargs)
+        def operate(
+            self, op: OperatorType, *other: Any, **kwargs: Any
+        ) -> ColumnElement[Any]:
+            return op(self.__clause_element__(), *other, **kwargs)  # type: ignore[return-value]  # noqa: E501
 
-        def reverse_operate(self, op, other, **kwargs):
+        def reverse_operate(
+            self, op: OperatorType, other: Any, **kwargs: Any
+        ) -> ColumnElement[Any]:
             col = self.__clause_element__()
-            return op(col._bind_param(op, other), col, **kwargs)
+            return op(col._bind_param(op, other), col, **kwargs)  # type: ignore[return-value]  # noqa: E501
 
-    def __str__(self):
+    def __str__(self) -> str:
         if not self.parent or not self.key:
             return object.__repr__(self)
         return str(self.parent.class_.__name__) + "." + self.key
@@ -460,7 +496,7 @@ class MappedColumn(
     column: Column[_T]
     foreign_keys: Optional[Set[ForeignKey]]
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any):
         self.deferred = kw.pop("deferred", False)
         self.column = cast("Column[_T]", Column(*arg, **kw))
         self.foreign_keys = self.column.foreign_keys
@@ -470,8 +506,8 @@ class MappedColumn(
         )
         util.set_creation_order(self)
 
-    def _copy(self, **kw):
-        new = self.__class__.__new__(self.__class__)
+    def _copy(self: Self, **kw: Any) -> Self:
+        new = cast(Self, self.__class__.__new__(self.__class__))
         new.column = self.column._copy(**kw)
         new.deferred = self.deferred
         new.foreign_keys = new.column.foreign_keys
@@ -487,22 +523,31 @@ class MappedColumn(
             return None
 
     @property
-    def columns_to_assign(self) -> List[Column]:
+    def columns_to_assign(self) -> List[Column[Any]]:
         return [self.column]
 
-    def __clause_element__(self):
+    def __clause_element__(self) -> Column[_T]:
         return self.column
 
-    def operate(self, op, *other, **kwargs):
-        return op(self.__clause_element__(), *other, **kwargs)
+    def operate(
+        self, op: OperatorType, *other: Any, **kwargs: Any
+    ) -> ColumnElement[Any]:
+        return op(self.__clause_element__(), *other, **kwargs)  # type: ignore[return-value]  # noqa: E501
 
-    def reverse_operate(self, op, other, **kwargs):
+    def reverse_operate(
+        self, op: OperatorType, other: Any, **kwargs: Any
+    ) -> ColumnElement[Any]:
         col = self.__clause_element__()
-        return op(col._bind_param(op, other), col, **kwargs)
+        return op(col._bind_param(op, other), col, **kwargs)  # type: ignore[return-value]  # noqa: E501
 
     def declarative_scan(
-        self, registry, cls, key, annotation, is_dataclass_field
-    ):
+        self,
+        registry: _RegistryType,
+        cls: Type[Any],
+        key: str,
+        annotation: Optional[_AnnotationScanType],
+        is_dataclass_field: bool,
+    ) -> None:
         column = self.column
         if column.key is None:
             column.key = key
@@ -526,38 +571,48 @@ class MappedColumn(
 
     @util.preload_module("sqlalchemy.orm.decl_base")
     def declarative_scan_for_composite(
-        self, registry, cls, key, param_name, param_annotation
-    ):
+        self,
+        registry: _RegistryType,
+        cls: Type[Any],
+        key: str,
+        param_name: str,
+        param_annotation: _AnnotationScanType,
+    ) -> None:
         decl_base = util.preloaded.orm_decl_base
         decl_base._undefer_column_name(param_name, self.column)
         self._init_column_for_annotation(cls, registry, param_annotation)
 
-    def _init_column_for_annotation(self, cls, registry, argument):
+    def _init_column_for_annotation(
+        self,
+        cls: Type[Any],
+        registry: _RegistryType,
+        argument: _AnnotationScanType,
+    ) -> None:
         sqltype = self.column.type
 
         nullable = False
 
         if hasattr(argument, "__origin__"):
-            nullable = NoneType in argument.__args__
+            nullable = NoneType in argument.__args__  # type: ignore
 
         if not self._has_nullable:
             self.column.nullable = nullable
 
         if sqltype._isnull and not self.column.foreign_keys:
-            sqltype = None
+            new_sqltype = None
             our_type = de_optionalize_union_types(argument)
 
             if is_fwd_ref(our_type):
                 our_type = de_stringify_annotation(cls, our_type)
 
             if registry.type_annotation_map:
-                sqltype = registry.type_annotation_map.get(our_type)
-            if sqltype is None:
-                sqltype = sqltypes._type_map_get(our_type)
+                new_sqltype = registry.type_annotation_map.get(our_type)
+            if new_sqltype is None:
+                new_sqltype = sqltypes._type_map_get(our_type)  # type: ignore
 
-            if sqltype is None:
+            if new_sqltype is None:
                 raise sa_exc.ArgumentError(
                     f"Could not locate SQLAlchemy Core "
                     f"type for Python type: {our_type}"
                 )
-            self.column.type = sqltype
+            self.column.type = new_sqltype  # type: ignore
index a60a167ac81fce0c1d270c0606df7c16a0ba97ad..419891708cdff8a503c2e197a552b1c8fa4b7ce3 100644 (file)
@@ -23,18 +23,23 @@ from __future__ import annotations
 import collections.abc as collections_abc
 import operator
 from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
 from typing import Generic
 from typing import Iterable
 from typing import List
+from typing import Mapping
 from typing import Optional
 from typing import overload
 from typing import Sequence
 from typing import Tuple
+from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
-from . import exc as orm_exc
+from . import attributes
 from . import interfaces
 from . import loading
 from . import util as orm_util
@@ -44,7 +49,6 @@ from .context import _column_descriptions
 from .context import _determine_last_joined_entity
 from .context import _legacy_filter_by_entity_zero
 from .context import FromStatement
-from .context import LABEL_STYLE_LEGACY_ORM
 from .context import ORMCompileState
 from .context import QueryContext
 from .interfaces import ORMColumnDescription
@@ -60,6 +64,8 @@ from .. import sql
 from .. import util
 from ..engine import Result
 from ..engine import Row
+from ..event import dispatcher
+from ..event import EventTarget
 from ..sql import coercions
 from ..sql import expression
 from ..sql import roles
@@ -71,8 +77,10 @@ from ..sql._typing import _TP
 from ..sql.annotation import SupportsCloneAnnotations
 from ..sql.base import _entity_namespace_key
 from ..sql.base import _generative
+from ..sql.base import _NoArg
 from ..sql.base import Executable
 from ..sql.base import Generative
+from ..sql.elements import BooleanClauseList
 from ..sql.expression import Exists
 from ..sql.selectable import _MemoizedSelectEntities
 from ..sql.selectable import _SelectFromElements
@@ -81,17 +89,31 @@ from ..sql.selectable import HasHints
 from ..sql.selectable import HasPrefixes
 from ..sql.selectable import HasSuffixes
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectLabelStyle
 from ..util.typing import Literal
+from ..util.typing import Self
 
 if TYPE_CHECKING:
     from ._typing import _EntityType
+    from ._typing import _ExternalEntityType
+    from ._typing import _InternalEntityType
+    from .mapper import Mapper
+    from .path_registry import PathRegistry
+    from .session import _PKIdentityArgument
     from .session import Session
+    from .state import InstanceState
+    from ..engine.cursor import CursorResult
+    from ..engine.interfaces import _ImmutableExecuteOptions
+    from ..engine.result import FrozenResult
     from ..engine.result import ScalarResult
     from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _ColumnsClauseArgument
+    from ..sql._typing import _DMLColumnArgument
+    from ..sql._typing import _JoinTargetArgument
     from ..sql._typing import _MAYBE_ENTITY
     from ..sql._typing import _no_kw
     from ..sql._typing import _NOT_ENTITY
+    from ..sql._typing import _OnClauseArgument
     from ..sql._typing import _PropagateAttrsType
     from ..sql._typing import _T0
     from ..sql._typing import _T1
@@ -102,18 +124,25 @@ if TYPE_CHECKING:
     from ..sql._typing import _T6
     from ..sql._typing import _T7
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
-    from ..sql.roles import TypedColumnsClauseRole
+    from ..sql.base import CacheableOptions
+    from ..sql.base import ExecutableOption
+    from ..sql.elements import ColumnElement
+    from ..sql.elements import Label
+    from ..sql.selectable import _JoinTargetElement
     from ..sql.selectable import _SetupJoinsElement
     from ..sql.selectable import Alias
+    from ..sql.selectable import CTE
     from ..sql.selectable import ExecutableReturnsRows
+    from ..sql.selectable import FromClause
     from ..sql.selectable import ScalarSelect
     from ..sql.selectable import Subquery
 
+
 __all__ = ["Query", "QueryContext"]
 
 _T = TypeVar("_T", bound=Any)
 
-SelfQuery = TypeVar("SelfQuery", bound="Query")
+SelfQuery = TypeVar("SelfQuery", bound="Query[Any]")
 
 
 @inspection._self_inspects
@@ -124,6 +153,7 @@ class Query(
     HasPrefixes,
     HasSuffixes,
     HasHints,
+    EventTarget,
     log.Identified,
     Generative,
     Executable,
@@ -150,40 +180,47 @@ class Query(
     """
 
     # elements that are in Core and can be cached in the same way
-    _where_criteria = ()
-    _having_criteria = ()
+    _where_criteria: Tuple[ColumnElement[Any], ...] = ()
+    _having_criteria: Tuple[ColumnElement[Any], ...] = ()
 
-    _order_by_clauses = ()
-    _group_by_clauses = ()
-    _limit_clause = None
-    _offset_clause = None
+    _order_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+    _group_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+    _limit_clause: Optional[ColumnElement[Any]] = None
+    _offset_clause: Optional[ColumnElement[Any]] = None
 
-    _distinct = False
-    _distinct_on = ()
+    _distinct: bool = False
+    _distinct_on: Tuple[ColumnElement[Any], ...] = ()
 
-    _for_update_arg = None
-    _correlate = ()
-    _auto_correlate = True
-    _from_obj = ()
+    _for_update_arg: Optional[ForUpdateArg] = None
+    _correlate: Tuple[FromClause, ...] = ()
+    _auto_correlate: bool = True
+    _from_obj: Tuple[FromClause, ...] = ()
     _setup_joins: Tuple[_SetupJoinsElement, ...] = ()
 
-    _label_style = LABEL_STYLE_LEGACY_ORM
+    _label_style: SelectLabelStyle = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
 
     _memoized_select_entities = ()
 
-    _compile_options = ORMCompileState.default_compile_options
+    _compile_options: Union[
+        Type[CacheableOptions], CacheableOptions
+    ] = ORMCompileState.default_compile_options
 
+    _with_options: Tuple[ExecutableOption, ...]
     load_options = QueryContext.default_load_options + {
         "_legacy_uniquing": True
     }
 
-    _params = util.EMPTY_DICT
+    _params: util.immutabledict[str, Any] = util.EMPTY_DICT
 
     # local Query builder state, not needed for
     # compilation or execution
     _enable_assertions = True
 
-    _statement = None
+    _statement: Optional[ExecutableReturnsRows] = None
+
+    session: Session
+
+    dispatch: dispatcher[Query[_T]]
 
     # mirrors that of ClauseElement, used to propagate the "orm"
     # plugin as well as the "subject" of the plugin, e.g. the mapper
@@ -224,14 +261,23 @@ class Query(
 
         """
 
-        self.session = session
+        # session is usually present.  There's one case in subqueryloader
+        # where it stores a Query without a Session and also there are tests
+        # for the query(Entity).with_session(session) API which is likely in
+        # some old recipes, however these are legacy as select() can now be
+        # used.
+        self.session = session  # type: ignore
         self._set_entities(entities)
 
-    def _set_propagate_attrs(self, values):
-        self._propagate_attrs = util.immutabledict(values)
+    def _set_propagate_attrs(
+        self: SelfQuery, values: Mapping[str, Any]
+    ) -> SelfQuery:
+        self._propagate_attrs = util.immutabledict(values)  # type: ignore
         return self
 
-    def _set_entities(self, entities):
+    def _set_entities(
+        self, entities: Iterable[_ColumnsClauseArgument[Any]]
+    ) -> None:
         self._raw_columns = [
             coercions.expect(
                 roles.ColumnsClauseRole,
@@ -242,15 +288,7 @@ class Query(
             for ent in util.to_list(entities)
         ]
 
-    @overload
-    def tuples(self: Query[Row[_TP]]) -> Query[_TP]:
-        ...
-
-    @overload
     def tuples(self: Query[_O]) -> Query[Tuple[_O]]:
-        ...
-
-    def tuples(self) -> Query[Any]:
         """return a tuple-typed form of this :class:`.Query`.
 
         This method invokes the :meth:`.Query.only_return_tuples`
@@ -270,29 +308,27 @@ class Query(
         .. versionadded:: 2.0
 
         """
-        return self.only_return_tuples(True)
+        return self.only_return_tuples(True)  # type: ignore
 
-    def _entity_from_pre_ent_zero(self):
+    def _entity_from_pre_ent_zero(self) -> Optional[_InternalEntityType[Any]]:
         if not self._raw_columns:
             return None
 
         ent = self._raw_columns[0]
 
         if "parententity" in ent._annotations:
-            return ent._annotations["parententity"]
-        elif isinstance(ent, ORMColumnsClauseRole):
-            return ent.entity
+            return ent._annotations["parententity"]  # type: ignore
         elif "bundle" in ent._annotations:
-            return ent._annotations["bundle"]
+            return ent._annotations["bundle"]  # type: ignore
         else:
             # label, other SQL expression
             for element in visitors.iterate(ent):
                 if "parententity" in element._annotations:
-                    return element._annotations["parententity"]
+                    return element._annotations["parententity"]  # type: ignore  # noqa: E501
             else:
                 return None
 
-    def _only_full_mapper_zero(self, methname):
+    def _only_full_mapper_zero(self, methname: str) -> Mapper[Any]:
         if (
             len(self._raw_columns) != 1
             or "parententity" not in self._raw_columns[0]._annotations
@@ -303,9 +339,11 @@ class Query(
                 "a single mapped class." % methname
             )
 
-        return self._raw_columns[0]._annotations["parententity"]
+        return self._raw_columns[0]._annotations["parententity"]  # type: ignore  # noqa: E501
 
-    def _set_select_from(self, obj, set_base_alias):
+    def _set_select_from(
+        self, obj: Iterable[_FromClauseArgument], set_base_alias: bool
+    ) -> None:
         fa = [
             coercions.expect(
                 roles.StrictFromClauseRole,
@@ -320,19 +358,22 @@ class Query(
         self._from_obj = tuple(fa)
 
     @_generative
-    def _set_lazyload_from(self: SelfQuery, state) -> SelfQuery:
+    def _set_lazyload_from(
+        self: SelfQuery, state: InstanceState[Any]
+    ) -> SelfQuery:
         self.load_options += {"_lazy_loaded_from": state}
         return self
 
-    def _get_condition(self):
-        return self._no_criterion_condition(
-            "get", order_by=False, distinct=False
-        )
+    def _get_condition(self) -> None:
+        """used by legacy BakedQuery"""
+        self._no_criterion_condition("get", order_by=False, distinct=False)
 
-    def _get_existing_condition(self):
+    def _get_existing_condition(self) -> None:
         self._no_criterion_assertion("get", order_by=False, distinct=False)
 
-    def _no_criterion_assertion(self, meth, order_by=True, distinct=True):
+    def _no_criterion_assertion(
+        self, meth: str, order_by: bool = True, distinct: bool = True
+    ) -> None:
         if not self._enable_assertions:
             return
         if (
@@ -351,7 +392,9 @@ class Query(
                 "Query with existing criterion. " % meth
             )
 
-    def _no_criterion_condition(self, meth, order_by=True, distinct=True):
+    def _no_criterion_condition(
+        self, meth: str, order_by: bool = True, distinct: bool = True
+    ) -> None:
         self._no_criterion_assertion(meth, order_by, distinct)
 
         self._from_obj = self._setup_joins = ()
@@ -362,7 +405,7 @@ class Query(
 
         self._order_by_clauses = self._group_by_clauses = ()
 
-    def _no_clauseelement_condition(self, meth):
+    def _no_clauseelement_condition(self, meth: str) -> None:
         if not self._enable_assertions:
             return
         if self._order_by_clauses:
@@ -372,7 +415,7 @@ class Query(
             )
         self._no_criterion_condition(meth)
 
-    def _no_statement_condition(self, meth):
+    def _no_statement_condition(self, meth: str) -> None:
         if not self._enable_assertions:
             return
         if self._statement is not None:
@@ -384,7 +427,7 @@ class Query(
                 % meth
             )
 
-    def _no_limit_offset(self, meth):
+    def _no_limit_offset(self, meth: str) -> None:
         if not self._enable_assertions:
             return
         if self._limit_clause is not None or self._offset_clause is not None:
@@ -395,21 +438,21 @@ class Query(
             )
 
     @property
-    def _has_row_limiting_clause(self):
+    def _has_row_limiting_clause(self) -> bool:
         return (
             self._limit_clause is not None or self._offset_clause is not None
         )
 
     def _get_options(
-        self,
-        populate_existing=None,
-        version_check=None,
-        only_load_props=None,
-        refresh_state=None,
-        identity_token=None,
-    ):
-        load_options = {}
-        compile_options = {}
+        self: SelfQuery,
+        populate_existing: Optional[bool] = None,
+        version_check: Optional[bool] = None,
+        only_load_props: Optional[Sequence[str]] = None,
+        refresh_state: Optional[InstanceState[Any]] = None,
+        identity_token: Optional[Any] = None,
+    ) -> SelfQuery:
+        load_options: Dict[str, Any] = {}
+        compile_options: Dict[str, Any] = {}
 
         if version_check:
             load_options["_version_check"] = version_check
@@ -430,11 +473,18 @@ class Query(
 
         return self
 
-    def _clone(self):
-        return self._generate()
+    def _clone(self: Self, **kw: Any) -> Self:
+        return self._generate()  # type: ignore
+
+    def _get_select_statement_only(self) -> Select[_T]:
+        if self._statement is not None:
+            raise sa_exc.InvalidRequestError(
+                "Can't call this method on a Query that uses from_statement()"
+            )
+        return cast("Select[_T]", self.statement)
 
     @property
-    def statement(self):
+    def statement(self) -> Union[Select[_T], FromStatement[_T]]:
         """The full SELECT statement represented by this Query.
 
         The statement by default will not have disambiguating labels
@@ -474,14 +524,15 @@ class Query(
 
         return stmt
 
-    def _final_statement(self, legacy_query_style=True):
+    def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]:
         """Return the 'final' SELECT statement for this :class:`.Query`.
 
+        This is used by the testing suite only and is fairly inefficient.
+
         This is the Core-only select() that will be rendered by a complete
         compilation of this query, and is what .statement used to return
         in 1.3.
 
-        This method creates a complete compile state so is fairly expensive.
 
         """
 
@@ -489,9 +540,11 @@ class Query(
 
         return q._compile_state(
             use_legacy_query_style=legacy_query_style
-        ).statement
+        ).statement  # type: ignore
 
-    def _statement_20(self, for_statement=False, use_legacy_query_style=True):
+    def _statement_20(
+        self, for_statement: bool = False, use_legacy_query_style: bool = True
+    ) -> Union[Select[_T], FromStatement[_T]]:
         # TODO: this event needs to be deprecated, as it currently applies
         # only to ORM query and occurs at this spot that is now more
         # or less an artificial spot
@@ -500,7 +553,7 @@ class Query(
                 new_query = fn(self)
                 if new_query is not None and new_query is not self:
                     self = new_query
-                    if not fn._bake_ok:
+                    if not fn._bake_ok:  # type: ignore
                         self._compile_options += {"_bake_ok": False}
 
         compile_options = self._compile_options
@@ -509,6 +562,8 @@ class Query(
             "_use_legacy_query_style": use_legacy_query_style,
         }
 
+        stmt: Union[Select[_T], FromStatement[_T]]
+
         if self._statement is not None:
             stmt = FromStatement(self._raw_columns, self._statement)
             stmt.__dict__.update(
@@ -541,10 +596,10 @@ class Query(
 
     def subquery(
         self,
-        name=None,
-        with_labels=False,
-        reduce_columns=False,
-    ):
+        name: Optional[str] = None,
+        with_labels: bool = False,
+        reduce_columns: bool = False,
+    ) -> Subquery:
         """Return the full SELECT statement represented by
         this :class:`_query.Query`, embedded within an
         :class:`_expression.Alias`.
@@ -571,13 +626,21 @@ class Query(
         if with_labels:
             q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
 
-        q = q.statement
+        stmt = q._get_select_statement_only()
+
+        if TYPE_CHECKING:
+            assert isinstance(stmt, Select)
 
         if reduce_columns:
-            q = q.reduce_columns()
-        return q.alias(name=name)
+            stmt = stmt.reduce_columns()
+        return stmt.subquery(name=name)
 
-    def cte(self, name=None, recursive=False, nesting=False):
+    def cte(
+        self,
+        name: Optional[str] = None,
+        recursive: bool = False,
+        nesting: bool = False,
+    ) -> CTE:
         r"""Return the full SELECT statement represented by this
         :class:`_query.Query` represented as a common table expression (CTE).
 
@@ -632,11 +695,13 @@ class Query(
             :meth:`_expression.HasCTE.cte`
 
         """
-        return self.enable_eagerloads(False).statement.cte(
-            name=name, recursive=recursive, nesting=nesting
+        return (
+            self.enable_eagerloads(False)
+            ._get_select_statement_only()
+            .cte(name=name, recursive=recursive, nesting=nesting)
         )
 
-    def label(self, name):
+    def label(self, name: Optional[str]) -> Label[Any]:
         """Return the full SELECT statement represented by this
         :class:`_query.Query`, converted
         to a scalar subquery with a label of the given name.
@@ -645,7 +710,11 @@ class Query(
 
         """
 
-        return self.enable_eagerloads(False).statement.label(name)
+        return (
+            self.enable_eagerloads(False)
+            ._get_select_statement_only()
+            .label(name)
+        )
 
     @overload
     def as_scalar(
@@ -704,10 +773,14 @@ class Query(
 
         """
 
-        return self.enable_eagerloads(False).statement.scalar_subquery()
+        return (
+            self.enable_eagerloads(False)
+            ._get_select_statement_only()
+            .scalar_subquery()
+        )
 
     @property
-    def selectable(self):
+    def selectable(self) -> Union[Select[_T], FromStatement[_T]]:
         """Return the :class:`_expression.Select` object emitted by this
         :class:`_query.Query`.
 
@@ -718,7 +791,7 @@ class Query(
         """
         return self.__clause_element__()
 
-    def __clause_element__(self):
+    def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]:
         return (
             self._with_compile_options(
                 _enable_eagerloads=False, _render_for_subquery=True
@@ -759,7 +832,7 @@ class Query(
         return self
 
     @property
-    def is_single_entity(self):
+    def is_single_entity(self) -> bool:
         """Indicates if this :class:`_query.Query`
         returns tuples or single entities.
 
@@ -785,7 +858,7 @@ class Query(
         )
 
     @_generative
-    def enable_eagerloads(self: SelfQuery, value) -> SelfQuery:
+    def enable_eagerloads(self: SelfQuery, value: bool) -> SelfQuery:
         """Control whether or not eager joins and subqueries are
         rendered.
 
@@ -804,7 +877,7 @@ class Query(
         return self
 
     @_generative
-    def _with_compile_options(self: SelfQuery, **opt) -> SelfQuery:
+    def _with_compile_options(self: SelfQuery, **opt: Any) -> SelfQuery:
         self._compile_options += opt
         return self
 
@@ -813,13 +886,15 @@ class Query(
         alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
         "instead.",
     )
-    def with_labels(self):
-        return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+    def with_labels(self: SelfQuery) -> SelfQuery:
+        return self.set_label_style(
+            SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL
+        )
 
     apply_labels = with_labels
 
     @property
-    def get_label_style(self):
+    def get_label_style(self) -> SelectLabelStyle:
         """
         Retrieve the current label style.
 
@@ -828,7 +903,7 @@ class Query(
         """
         return self._label_style
 
-    def set_label_style(self, style):
+    def set_label_style(self: SelfQuery, style: SelectLabelStyle) -> SelfQuery:
         """Apply column labels to the return value of Query.statement.
 
         Indicates that this Query's `statement` accessor should return
@@ -864,7 +939,7 @@ class Query(
         return self
 
     @_generative
-    def enable_assertions(self: SelfQuery, value) -> SelfQuery:
+    def enable_assertions(self: SelfQuery, value: bool) -> SelfQuery:
         """Control whether assertions are generated.
 
         When set to False, the returned Query will
@@ -887,7 +962,7 @@ class Query(
         return self
 
     @property
-    def whereclause(self):
+    def whereclause(self) -> Optional[ColumnElement[bool]]:
         """A readonly attribute which returns the current WHERE criterion for
         this Query.
 
@@ -895,12 +970,12 @@ class Query(
         criterion has been established.
 
         """
-        return sql.elements.BooleanClauseList._construct_for_whereclause(
+        return BooleanClauseList._construct_for_whereclause(
             self._where_criteria
         )
 
     @_generative
-    def _with_current_path(self: SelfQuery, path) -> SelfQuery:
+    def _with_current_path(self: SelfQuery, path: PathRegistry) -> SelfQuery:
         """indicate that this query applies to objects loaded
         within a certain path.
 
@@ -913,7 +988,7 @@ class Query(
         return self
 
     @_generative
-    def yield_per(self: SelfQuery, count) -> SelfQuery:
+    def yield_per(self: SelfQuery, count: int) -> SelfQuery:
         r"""Yield only ``count`` rows at a time.
 
         The purpose of this method is when fetching very large result sets
@@ -938,7 +1013,7 @@ class Query(
         ":meth:`_orm.Query.get`",
         alternative="The method is now available as :meth:`_orm.Session.get`",
     )
-    def get(self, ident):
+    def get(self, ident: _PKIdentityArgument) -> Optional[Any]:
         """Return an instance based on the given primary key identifier,
         or ``None`` if not found.
 
@@ -1022,7 +1097,12 @@ class Query(
         # it
         return self._get_impl(ident, loading.load_on_pk_identity)
 
-    def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
+    def _get_impl(
+        self,
+        primary_key_identity: _PKIdentityArgument,
+        db_load_fn: Callable[..., Any],
+        identity_token: Optional[Any] = None,
+    ) -> Optional[Any]:
         mapper = self._only_full_mapper_zero("get")
         return self.session._get_impl(
             mapper,
@@ -1036,7 +1116,7 @@ class Query(
         )
 
     @property
-    def lazy_loaded_from(self):
+    def lazy_loaded_from(self) -> Optional[InstanceState[Any]]:
         """An :class:`.InstanceState` that is using this :class:`_query.Query`
         for a lazy load operation.
 
@@ -1050,14 +1130,17 @@ class Query(
             :attr:`.ORMExecuteState.lazy_loaded_from`
 
         """
-        return self.load_options._lazy_loaded_from
+        return self.load_options._lazy_loaded_from  # type: ignore
 
     @property
-    def _current_path(self):
-        return self._compile_options._current_path
+    def _current_path(self) -> PathRegistry:
+        return self._compile_options._current_path  # type: ignore
 
     @_generative
-    def correlate(self: SelfQuery, *fromclauses) -> SelfQuery:
+    def correlate(
+        self: SelfQuery,
+        *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+    ) -> SelfQuery:
         """Return a :class:`.Query` construct which will correlate the given
         FROM clauses to that of an enclosing :class:`.Query` or
         :func:`~.expression.select`.
@@ -1082,13 +1165,13 @@ class Query(
         if fromclauses and fromclauses[0] in {None, False}:
             self._correlate = ()
         else:
-            self._correlate = set(self._correlate).union(
+            self._correlate = self._correlate + tuple(
                 coercions.expect(roles.FromClauseRole, f) for f in fromclauses
             )
         return self
 
     @_generative
-    def autoflush(self: SelfQuery, setting) -> SelfQuery:
+    def autoflush(self: SelfQuery, setting: bool) -> SelfQuery:
         """Return a Query with a specific 'autoflush' setting.
 
         As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method
@@ -1116,7 +1199,7 @@ class Query(
         return self
 
     @_generative
-    def _with_invoke_all_eagers(self: SelfQuery, value) -> SelfQuery:
+    def _with_invoke_all_eagers(self: SelfQuery, value: bool) -> SelfQuery:
         """Set the 'invoke all eagers' flag which causes joined- and
         subquery loaders to traverse into already-loaded related objects
         and collections.
@@ -1132,7 +1215,14 @@ class Query(
         alternative="Use the :func:`_orm.with_parent` standalone construct.",
     )
     @util.preload_module("sqlalchemy.orm.relationships")
-    def with_parent(self, instance, property=None, from_entity=None):  # noqa
+    def with_parent(
+        self: SelfQuery,
+        instance: object,
+        property: Optional[  # noqa: A002
+            attributes.QueryableAttribute[Any]
+        ] = None,
+        from_entity: Optional[_ExternalEntityType[Any]] = None,
+    ) -> SelfQuery:
         """Add filtering criterion that relates the given instance
         to a child object or collection, using its attribute state
         as well as an established :func:`_orm.relationship()`
@@ -1150,7 +1240,7 @@ class Query(
           An instance which has some :func:`_orm.relationship`.
 
         :param property:
-          String property name, or class-bound attribute, which indicates
+          Class bound attribute which indicates
           what relationship from the instance should be used to reconcile the
           parent/child relationship.
 
@@ -1172,21 +1262,27 @@ class Query(
             for prop in mapper.iterate_properties:
                 if (
                     isinstance(prop, relationships.Relationship)
-                    and prop.mapper is entity_zero.mapper
+                    and prop.mapper is entity_zero.mapper  # type: ignore
                 ):
-                    property = prop  # noqa
+                    property = prop  # type: ignore  # noqa: A001
                     break
             else:
                 raise sa_exc.InvalidRequestError(
                     "Could not locate a property which relates instances "
                     "of class '%s' to instances of class '%s'"
                     % (
-                        entity_zero.mapper.class_.__name__,
+                        entity_zero.mapper.class_.__name__,  # type: ignore
                         instance.__class__.__name__,
                     )
                 )
 
-        return self.filter(with_parent(instance, property, entity_zero.entity))
+        return self.filter(
+            with_parent(
+                instance,
+                property,  # type: ignore
+                entity_zero.entity,  # type: ignore
+            )
+        )
 
     @_generative
     def add_entity(
@@ -1211,7 +1307,7 @@ class Query(
         return self
 
     @_generative
-    def with_session(self: SelfQuery, session) -> SelfQuery:
+    def with_session(self: SelfQuery, session: Session) -> SelfQuery:
         """Return a :class:`_query.Query` that will use the given
         :class:`.Session`.
 
@@ -1237,7 +1333,9 @@ class Query(
         self.session = session
         return self
 
-    def _legacy_from_self(self, *entities):
+    def _legacy_from_self(
+        self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
+    ) -> SelfQuery:
         # used for query.count() as well as for the same
         # function in BakedQuery, as well as some old tests in test_baked.py.
 
@@ -1255,13 +1353,13 @@ class Query(
         return q
 
     @_generative
-    def _set_enable_single_crit(self: SelfQuery, val) -> SelfQuery:
+    def _set_enable_single_crit(self: SelfQuery, val: bool) -> SelfQuery:
         self._compile_options += {"_enable_single_crit": val}
         return self
 
     @_generative
     def _from_selectable(
-        self: SelfQuery, fromclause, set_entity_from=True
+        self: SelfQuery, fromclause: FromClause, set_entity_from: bool = True
     ) -> SelfQuery:
         for attr in (
             "_where_criteria",
@@ -1292,7 +1390,7 @@ class Query(
         "is deprecated and will be removed in a "
         "future release.  Please use :meth:`_query.Query.with_entities`",
     )
-    def values(self, *columns):
+    def values(self, *columns: _ColumnsClauseArgument[Any]) -> Iterable[Any]:
         """Return an iterator yielding result tuples corresponding
         to the given list of columns
 
@@ -1304,7 +1402,7 @@ class Query(
         q._set_entities(columns)
         if not q.load_options._yield_per:
             q.load_options += {"_yield_per": 10}
-        return iter(q)
+        return iter(q)  # type: ignore
 
     _values = values
 
@@ -1315,25 +1413,24 @@ class Query(
         "future release.  Please use :meth:`_query.Query.with_entities` "
         "in combination with :meth:`_query.Query.scalar`",
     )
-    def value(self, column):
+    def value(self, column: _ColumnExpressionArgument[Any]) -> Any:
         """Return a scalar result corresponding to the given
         column expression.
 
         """
         try:
-            return next(self.values(column))[0]
+            return next(self.values(column))[0]  # type: ignore
         except StopIteration:
             return None
 
     @overload
-    def with_entities(
-        self, _entity: _EntityType[_O], **kwargs: Any
-    ) -> Query[_O]:
+    def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]:
         ...
 
     @overload
     def with_entities(
-        self, _colexpr: TypedColumnsClauseRole[_T]
+        self,
+        _colexpr: roles.TypedColumnsClauseRole[_T],
     ) -> RowReturningQuery[Tuple[_T]]:
         ...
 
@@ -1418,14 +1515,14 @@ class Query(
 
     @overload
     def with_entities(
-        self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
-    ) -> SelfQuery:
+        self, *entities: _ColumnsClauseArgument[Any]
+    ) -> Query[Any]:
         ...
 
     @_generative
     def with_entities(
-        self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any
-    ) -> SelfQuery:
+        self, *entities: _ColumnsClauseArgument[Any], **__kw: Any
+    ) -> Query[Any]:
         r"""Return a new :class:`_query.Query`
         replacing the SELECT list with the
         given entities.
@@ -1451,12 +1548,18 @@ class Query(
         """
         if __kw:
             raise _no_kw()
-        _MemoizedSelectEntities._generate_for_statement(self)
+
+        # Query has all the same fields as Select for this operation
+        # this could in theory be based on a protocol but not sure if it's
+        # worth it
+        _MemoizedSelectEntities._generate_for_statement(self)  # type: ignore
         self._set_entities(entities)
         return self
 
     @_generative
-    def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]:
+    def add_columns(
+        self, *column: _ColumnExpressionArgument[Any]
+    ) -> Query[Any]:
         """Add one or more column expressions to the list
         of result columns to be returned."""
 
@@ -1479,7 +1582,7 @@ class Query(
         "is deprecated and will be removed in a "
         "future release.  Please use :meth:`_query.Query.add_columns`",
     )
-    def add_column(self, column) -> Query[Any]:
+    def add_column(self, column: _ColumnExpressionArgument[Any]) -> Query[Any]:
         """Add a column expression to the list of result columns to be
         returned.
 
@@ -1487,7 +1590,7 @@ class Query(
         return self.add_columns(column)
 
     @_generative
-    def options(self: SelfQuery, *args) -> SelfQuery:
+    def options(self: SelfQuery, *args: ExecutableOption) -> SelfQuery:
         """Return a new :class:`_query.Query` object,
         applying the given list of
         mapper options.
@@ -1505,18 +1608,21 @@ class Query(
 
         opts = tuple(util.flatten_iterator(args))
         if self._compile_options._current_path:
+            # opting for lower method overhead for the checks
             for opt in opts:
-                if opt._is_legacy_option:
-                    opt.process_query_conditionally(self)
+                if not opt._is_core and opt._is_legacy_option:  # type: ignore
+                    opt.process_query_conditionally(self)  # type: ignore
         else:
             for opt in opts:
-                if opt._is_legacy_option:
-                    opt.process_query(self)
+                if not opt._is_core and opt._is_legacy_option:  # type: ignore
+                    opt.process_query(self)  # type: ignore
 
         self._with_options += opts
         return self
 
-    def with_transformation(self, fn):
+    def with_transformation(
+        self, fn: Callable[[Query[Any]], Query[Any]]
+    ) -> Query[Any]:
         """Return a new :class:`_query.Query` object transformed by
         the given function.
 
@@ -1535,7 +1641,7 @@ class Query(
         """
         return fn(self)
 
-    def get_execution_options(self):
+    def get_execution_options(self) -> _ImmutableExecuteOptions:
         """Get the non-SQL options which will take effect during execution.
 
         .. versionadded:: 1.3
@@ -1547,7 +1653,7 @@ class Query(
         return self._execution_options
 
     @_generative
-    def execution_options(self: SelfQuery, **kwargs) -> SelfQuery:
+    def execution_options(self: SelfQuery, **kwargs: Any) -> SelfQuery:
         """Set non-SQL options which take effect during execution.
 
         Options allowed here include all of those accepted by
@@ -1596,11 +1702,17 @@ class Query(
     @_generative
     def with_for_update(
         self: SelfQuery,
-        read=False,
-        nowait=False,
-        of=None,
-        skip_locked=False,
-        key_share=False,
+        *,
+        nowait: bool = False,
+        read: bool = False,
+        of: Optional[
+            Union[
+                _ColumnExpressionArgument[Any],
+                Sequence[_ColumnExpressionArgument[Any]],
+            ]
+        ] = None,
+        skip_locked: bool = False,
+        key_share: bool = False,
     ) -> SelfQuery:
         """return a new :class:`_query.Query`
         with the specified options for the
@@ -1659,7 +1771,9 @@ class Query(
         return self
 
     @_generative
-    def params(self: SelfQuery, *args, **kwargs) -> SelfQuery:
+    def params(
+        self: SelfQuery, __params: Optional[Dict[str, Any]] = None, **kw: Any
+    ) -> SelfQuery:
         r"""Add values for bind parameters which may have been
         specified in filter().
 
@@ -1669,17 +1783,14 @@ class Query(
         contain unicode keys in which case \**kwargs cannot be used.
 
         """
-        if len(args) == 1:
-            kwargs.update(args[0])
-        elif len(args) > 0:
-            raise sa_exc.ArgumentError(
-                "params() takes zero or one positional argument, "
-                "which is a dictionary."
-            )
-        self._params = self._params.union(kwargs)
+        if __params:
+            kw.update(__params)
+        self._params = self._params.union(kw)
         return self
 
-    def where(self: SelfQuery, *criterion) -> SelfQuery:
+    def where(
+        self: SelfQuery, *criterion: _ColumnExpressionArgument[bool]
+    ) -> SelfQuery:
         """A synonym for :meth:`.Query.filter`.
 
         .. versionadded:: 1.4
@@ -1716,16 +1827,18 @@ class Query(
             :meth:`_query.Query.filter_by` - filter on keyword expressions.
 
         """
-        for criterion in list(criterion):
-            criterion = coercions.expect(
-                roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+        for crit in list(criterion):
+            crit = coercions.expect(
+                roles.WhereHavingRole, crit, apply_propagate_attrs=self
             )
 
-            self._where_criteria += (criterion,)
+            self._where_criteria += (crit,)
         return self
 
     @util.memoized_property
-    def _last_joined_entity(self):
+    def _last_joined_entity(
+        self,
+    ) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]:
         if self._setup_joins:
             return _determine_last_joined_entity(
                 self._setup_joins,
@@ -1733,7 +1846,7 @@ class Query(
         else:
             return None
 
-    def _filter_by_zero(self):
+    def _filter_by_zero(self) -> Any:
         """for the filter_by() method, return the target entity for which
         we will attempt to derive an expression from based on string name.
 
@@ -1800,13 +1913,6 @@ class Query(
 
         """
         from_entity = self._filter_by_zero()
-        if from_entity is None:
-            raise sa_exc.InvalidRequestError(
-                "Can't use filter_by when the first entity '%s' of a query "
-                "is not a mapped class. Please use the filter method instead, "
-                "or change the order of the entities in the query"
-                % self._query_entity_zero()
-            )
 
         clauses = [
             _entity_namespace_key(from_entity, key) == value
@@ -1815,9 +1921,12 @@ class Query(
         return self.filter(*clauses)
 
     @_generative
-    @_assertions(_no_statement_condition, _no_limit_offset)
     def order_by(
-        self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+        self: SelfQuery,
+        __first: Union[
+            Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+        ] = _NoArg.NO_ARG,
+        *clauses: _ColumnExpressionArgument[Any],
     ) -> SelfQuery:
         """Apply one or more ORDER BY criteria to the query and return
         the newly resulting :class:`_query.Query`.
@@ -1844,20 +1953,27 @@ class Query(
 
         """
 
-        if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+        for assertion in (self._no_statement_condition, self._no_limit_offset):
+            assertion("order_by")
+
+        if not clauses and (__first is None or __first is False):
             self._order_by_clauses = ()
-        else:
+        elif __first is not _NoArg.NO_ARG:
             criterion = tuple(
                 coercions.expect(roles.OrderByRole, clause)
-                for clause in clauses
+                for clause in (__first,) + clauses
             )
             self._order_by_clauses += criterion
+
         return self
 
     @_generative
-    @_assertions(_no_statement_condition, _no_limit_offset)
     def group_by(
-        self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+        self: SelfQuery,
+        __first: Union[
+            Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+        ] = _NoArg.NO_ARG,
+        *clauses: _ColumnExpressionArgument[Any],
     ) -> SelfQuery:
         """Apply one or more GROUP BY criterion to the query and return
         the newly resulting :class:`_query.Query`.
@@ -1878,12 +1994,15 @@ class Query(
 
         """
 
-        if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+        for assertion in (self._no_statement_condition, self._no_limit_offset):
+            assertion("group_by")
+
+        if not clauses and (__first is None or __first is False):
             self._group_by_clauses = ()
-        else:
+        elif __first is not _NoArg.NO_ARG:
             criterion = tuple(
                 coercions.expect(roles.GroupByRole, clause)
-                for clause in clauses
+                for clause in (__first,) + clauses
             )
             self._group_by_clauses += criterion
         return self
@@ -1916,8 +2035,9 @@ class Query(
             self._having_criteria += (having_criteria,)
         return self
 
-    def _set_op(self, expr_fn, *q):
-        return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
+    def _set_op(self: SelfQuery, expr_fn: Any, *q: Query[Any]) -> SelfQuery:
+        list_of_queries = (self,) + q
+        return self._from_selectable(expr_fn(*(list_of_queries)).subquery())
 
     def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
         """Produce a UNION of this Query against one or more queries.
@@ -2006,7 +2126,12 @@ class Query(
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
     def join(
-        self: SelfQuery, target, onclause=None, *, isouter=False, full=False
+        self: SelfQuery,
+        target: _JoinTargetArgument,
+        onclause: Optional[_OnClauseArgument] = None,
+        *,
+        isouter: bool = False,
+        full: bool = False,
     ) -> SelfQuery:
         r"""Create a SQL JOIN against this :class:`_query.Query`
         object's criterion
@@ -2193,20 +2318,23 @@ class Query(
 
         """
 
-        target = coercions.expect(
+        join_target = coercions.expect(
             roles.JoinTargetRole,
             target,
             apply_propagate_attrs=self,
             legacy=True,
         )
         if onclause is not None:
-            onclause = coercions.expect(
+            onclause_element = coercions.expect(
                 roles.OnClauseRole, onclause, legacy=True
             )
+        else:
+            onclause_element = None
+
         self._setup_joins += (
             (
-                target,
-                onclause,
+                join_target,
+                onclause_element,
                 None,
                 {
                     "isouter": isouter,
@@ -2218,7 +2346,13 @@ class Query(
         self.__dict__.pop("_last_joined_entity", None)
         return self
 
-    def outerjoin(self, target, onclause=None, *, full=False):
+    def outerjoin(
+        self: SelfQuery,
+        target: _JoinTargetArgument,
+        onclause: Optional[_OnClauseArgument] = None,
+        *,
+        full: bool = False,
+    ) -> SelfQuery:
         """Create a left outer join against this ``Query`` object's criterion
         and apply generatively, returning the newly resulting ``Query``.
 
@@ -2295,7 +2429,7 @@ class Query(
         self._set_select_from(from_obj, False)
         return self
 
-    def __getitem__(self, item):
+    def __getitem__(self, item: Any) -> Any:
         return orm_util._getitem(
             self,
             item,
@@ -2303,7 +2437,11 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition)
-    def slice(self: SelfQuery, start, stop) -> SelfQuery:
+    def slice(
+        self: SelfQuery,
+        start: int,
+        stop: int,
+    ) -> SelfQuery:
         """Computes the "slice" of the :class:`_query.Query` represented by
         the given indices and returns the resulting :class:`_query.Query`.
 
@@ -2341,7 +2479,9 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition)
-    def limit(self: SelfQuery, limit) -> SelfQuery:
+    def limit(
+        self: SelfQuery, limit: Union[int, _ColumnExpressionArgument[int]]
+    ) -> SelfQuery:
         """Apply a ``LIMIT`` to the query and return the newly resulting
         ``Query``.
 
@@ -2351,7 +2491,9 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition)
-    def offset(self: SelfQuery, offset) -> SelfQuery:
+    def offset(
+        self: SelfQuery, offset: Union[int, _ColumnExpressionArgument[int]]
+    ) -> SelfQuery:
         """Apply an ``OFFSET`` to the query and return the newly resulting
         ``Query``.
 
@@ -2361,7 +2503,9 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition)
-    def distinct(self: SelfQuery, *expr) -> SelfQuery:
+    def distinct(
+        self: SelfQuery, *expr: _ColumnExpressionArgument[Any]
+    ) -> SelfQuery:
         r"""Apply a ``DISTINCT`` to the query and return the newly resulting
         ``Query``.
 
@@ -2415,7 +2559,7 @@ class Query(
 
                 :ref:`faq_query_deduplicating`
         """
-        return self._iter().all()
+        return self._iter().all()  # type: ignore
 
     @_generative
     @_assertions(_no_clauseelement_condition)
@@ -2462,9 +2606,9 @@ class Query(
         """
         # replicates limit(1) behavior
         if self._statement is not None:
-            return self._iter().first()
+            return self._iter().first()  # type: ignore
         else:
-            return self.limit(1)._iter().first()
+            return self.limit(1)._iter().first()  # type: ignore
 
     def one_or_none(self) -> Optional[_T]:
         """Return at most one result or raise an exception.
@@ -2490,7 +2634,7 @@ class Query(
             :meth:`_query.Query.one`
 
         """
-        return self._iter().one_or_none()
+        return self._iter().one_or_none()  # type: ignore
 
     def one(self) -> _T:
         """Return exactly one result or raise an exception.
@@ -2537,18 +2681,18 @@ class Query(
             if not isinstance(ret, collections_abc.Sequence):
                 return ret
             return ret[0]
-        except orm_exc.NoResultFound:
+        except sa_exc.NoResultFound:
             return None
 
     def __iter__(self) -> Iterable[_T]:
-        return self._iter().__iter__()
+        return self._iter().__iter__()  # type: ignore
 
     def _iter(self) -> Union[ScalarResult[_T], Result[_T]]:
         # new style execution.
         params = self._params
 
         statement = self._statement_20()
-        result = self.session.execute(
+        result: Union[ScalarResult[_T], Result[_T]] = self.session.execute(
             statement,
             params,
             execution_options={"_sa_orm_load_options": self.load_options},
@@ -2556,7 +2700,7 @@ class Query(
 
         # legacy: automatically set scalars, unique
         if result._attributes.get("is_single_entity", False):
-            result = result.scalars()
+            result = cast("Result[_T]", result).scalars()
 
         if (
             result._attributes.get("filtered", False)
@@ -2580,7 +2724,7 @@ class Query(
 
         return str(statement.compile(bind))
 
-    def _get_bind_args(self, statement, fn, **kw):
+    def _get_bind_args(self, statement: Any, fn: Any, **kw: Any) -> Any:
         return fn(clause=statement, **kw)
 
     @property
@@ -2634,7 +2778,11 @@ class Query(
 
         return _column_descriptions(self, legacy=True)
 
-    def instances(self, result_proxy: Result, context=None) -> Any:
+    def instances(
+        self,
+        result_proxy: CursorResult[Any],
+        context: Optional[QueryContext] = None,
+    ) -> Any:
         """Return an ORM result given a :class:`_engine.CursorResult` and
         :class:`.QueryContext`.
 
@@ -2661,7 +2809,7 @@ class Query(
 
         # legacy: automatically set scalars, unique
         if result._attributes.get("is_single_entity", False):
-            result = result.scalars()
+            result = result.scalars()  # type: ignore
 
         if result._attributes.get("filtered", False):
             result = result.unique()
@@ -2675,7 +2823,13 @@ class Query(
         ":func:`_orm.merge_frozen_result` function.",
         enable_warnings=False,  # warnings occur via loading.merge_result
     )
-    def merge_result(self, iterator, load=True):
+    def merge_result(
+        self,
+        iterator: Union[
+            FrozenResult[Any], Iterable[Sequence[Any]], Iterable[object]
+        ],
+        load: bool = True,
+    ) -> Union[FrozenResult[Any], Iterable[Any]]:
         """Merge a result into this :class:`_query.Query` object's Session.
 
         Given an iterator returned by a :class:`_query.Query`
@@ -2743,7 +2897,8 @@ class Query(
             self.enable_eagerloads(False)
             .add_columns(sql.literal_column("1"))
             .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
-            .statement.with_only_columns(1)
+            ._get_select_statement_only()
+            .with_only_columns(1)
         )
 
         ezero = self._entity_from_pre_ent_zero()
@@ -2752,7 +2907,7 @@ class Query(
 
         return sql.exists(inner)
 
-    def count(self):
+    def count(self) -> int:
         r"""Return a count of rows this the SQL formed by this :class:`Query`
         would return.
 
@@ -2806,9 +2961,11 @@ class Query(
 
         """
         col = sql.func.count(sql.literal_column("*"))
-        return self._legacy_from_self(col).enable_eagerloads(False).scalar()
+        return (  # type: ignore
+            self._legacy_from_self(col).enable_eagerloads(False).scalar()
+        )
 
-    def delete(self, synchronize_session="evaluate"):
+    def delete(self, synchronize_session: str = "evaluate") -> int:
         r"""Perform a DELETE with an arbitrary WHERE clause.
 
         Deletes rows matched by this query from the database.
@@ -2850,20 +3007,28 @@ class Query(
 
                 self = bulk_del.query
 
-        delete_ = sql.delete(*self._raw_columns)
+        delete_ = sql.delete(*self._raw_columns)  # type: ignore
         delete_._where_criteria = self._where_criteria
-        result = self.session.execute(
-            delete_,
-            self._params,
-            execution_options={"synchronize_session": synchronize_session},
+        result: CursorResult[Any] = cast(
+            "CursorResult[Any]",
+            self.session.execute(
+                delete_,
+                self._params,
+                execution_options={"synchronize_session": synchronize_session},
+            ),
         )
-        bulk_del.result = result
+        bulk_del.result = result  # type: ignore
         self.session.dispatch.after_bulk_delete(bulk_del)
         result.close()
 
         return result.rowcount
 
-    def update(self, values, synchronize_session="evaluate", update_args=None):
+    def update(
+        self,
+        values: Dict[_DMLColumnArgument, Any],
+        synchronize_session: str = "evaluate",
+        update_args: Optional[Dict[Any, Any]] = None,
+    ) -> int:
         r"""Perform an UPDATE with an arbitrary WHERE clause.
 
         Updates rows matched by this query in the database.
@@ -2926,28 +3091,33 @@ class Query(
                     bulk_ud.query = new_query
             self = bulk_ud.query
 
-        upd = sql.update(*self._raw_columns)
+        upd = sql.update(*self._raw_columns)  # type: ignore
 
         ppo = update_args.pop("preserve_parameter_order", False)
         if ppo:
-            upd = upd.ordered_values(*values)
+            upd = upd.ordered_values(*values)  # type: ignore
         else:
             upd = upd.values(values)
         if update_args:
             upd = upd.with_dialect_options(**update_args)
 
         upd._where_criteria = self._where_criteria
-        result = self.session.execute(
-            upd,
-            self._params,
-            execution_options={"synchronize_session": synchronize_session},
+        result: CursorResult[Any] = cast(
+            "CursorResult[Any]",
+            self.session.execute(
+                upd,
+                self._params,
+                execution_options={"synchronize_session": synchronize_session},
+            ),
         )
-        bulk_ud.result = result
+        bulk_ud.result = result  # type: ignore
         self.session.dispatch.after_bulk_update(bulk_ud)
         result.close()
         return result.rowcount
 
-    def _compile_state(self, for_statement=False, **kw):
+    def _compile_state(
+        self, for_statement: bool = False, **kw: Any
+    ) -> ORMCompileState:
         """Create an out-of-compiler ORMCompileState object.
 
         The ORMCompileState object is normally created directly as a result
@@ -2971,13 +3141,14 @@ class Query(
         # ORMSelectCompileState.  We could also base this on
         # query._statement is not None as we have the ORM Query here
         # however this is the more general path.
-        compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
-            stmt, "orm"
+        compile_state_cls = cast(
+            ORMCompileState,
+            ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"),
         )
 
         return compile_state_cls.create_for_statement(stmt, None)
 
-    def _compile_context(self, for_statement=False):
+    def _compile_context(self, for_statement: bool = False) -> QueryContext:
         compile_state = self._compile_state(for_statement=for_statement)
         context = QueryContext(
             compile_state,
@@ -3006,7 +3177,7 @@ class AliasOption(interfaces.LoaderOption):
 
         """
 
-    def process_compile_state(self, compile_state: ORMCompileState):
+    def process_compile_state(self, compile_state: ORMCompileState) -> None:
         pass
 
 
@@ -3017,12 +3188,12 @@ class BulkUD:
 
     """
 
-    def __init__(self, query):
+    def __init__(self, query: Query[Any]):
         self.query = query.enable_eagerloads(False)
         self._validate_query_state()
         self.mapper = self.query._entity_from_pre_ent_zero()
 
-    def _validate_query_state(self):
+    def _validate_query_state(self) -> None:
         for attr, methname, notset, op in (
             ("_limit_clause", "limit()", None, operator.is_),
             ("_offset_clause", "offset()", None, operator.is_),
@@ -3049,14 +3220,19 @@ class BulkUD:
                 )
 
     @property
-    def session(self):
+    def session(self) -> Session:
         return self.query.session
 
 
 class BulkUpdate(BulkUD):
     """BulkUD which handles UPDATEs."""
 
-    def __init__(self, query, values, update_kwargs):
+    def __init__(
+        self,
+        query: Query[Any],
+        values: Dict[_DMLColumnArgument, Any],
+        update_kwargs: Optional[Dict[Any, Any]],
+    ):
         super(BulkUpdate, self).__init__(query)
         self.values = values
         self.update_kwargs = update_kwargs
@@ -3067,4 +3243,7 @@ class BulkDelete(BulkUD):
 
 
 class RowReturningQuery(Query[Row[_TP]]):
-    pass
+    if TYPE_CHECKING:
+
+        def tuples(self) -> Query[_TP]:  # type: ignore
+            ...
index 8273775ae1e651ad49996979eeed7a4f9f4636cc..1186f0f541965e6b50131475cc2ec394b70050ad 100644 (file)
@@ -17,13 +17,23 @@ from __future__ import annotations
 
 import collections
 from collections import abc
+import dataclasses
 import re
 import typing
 from typing import Any
 from typing import Callable
+from typing import cast
+from typing import Collection
 from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NamedTuple
+from typing import NoReturn
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TypeVar
@@ -32,14 +42,19 @@ import weakref
 
 from . import attributes
 from . import strategy_options
+from ._typing import insp_is_aliased_class
+from ._typing import is_has_collection_adapter
 from .base import _is_mapped_class
 from .base import class_mapper
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
 from .base import state_str
 from .interfaces import _IntrospectsAnnotations
 from .interfaces import MANYTOMANY
 from .interfaces import MANYTOONE
 from .interfaces import ONETOMANY
 from .interfaces import PropComparator
+from .interfaces import RelationshipDirection
 from .interfaces import StrategizedProperty
 from .util import _extract_mapped_subtype
 from .util import _orm_annotate
@@ -60,6 +75,7 @@ from ..sql import visitors
 from ..sql._typing import _ColumnExpressionArgument
 from ..sql._typing import _HasClauseElement
 from ..sql.elements import ColumnClause
+from ..sql.elements import ColumnElement
 from ..sql.util import _deep_deannotate
 from ..sql.util import _shallow_annotate
 from ..sql.util import adapt_criterion_to_null
@@ -71,15 +87,42 @@ from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
     from ._typing import _EntityType
+    from ._typing import _ExternalEntityType
+    from ._typing import _IdentityKeyType
+    from ._typing import _InstanceDict
     from ._typing import _InternalEntityType
+    from ._typing import _O
+    from ._typing import _RegistryType
+    from .clsregistry import _class_resolver
+    from .clsregistry import _ModNS
+    from .dependency import DependencyProcessor
     from .mapper import Mapper
+    from .query import Query
+    from .session import Session
+    from .state import InstanceState
+    from .strategies import LazyLoader
     from .util import AliasedClass
     from .util import AliasedInsp
-    from ..sql.elements import ColumnElement
+    from ..sql._typing import _CoreAdapterProto
+    from ..sql._typing import _EquivalentColumnMap
+    from ..sql._typing import _InfoType
+    from ..sql.annotation import _AnnotationDict
+    from ..sql.elements import BinaryExpression
+    from ..sql.elements import BindParameter
+    from ..sql.elements import ClauseElement
+    from ..sql.schema import Table
+    from ..sql.selectable import FromClause
+    from ..util.typing import _AnnotationScanType
+    from ..util.typing import RODescriptorReference
 
 _T = TypeVar("_T", bound=Any)
+_T1 = TypeVar("_T1", bound=Any)
+_T2 = TypeVar("_T2", bound=Any)
+
 _PT = TypeVar("_PT", bound=Any)
 
+_PT2 = TypeVar("_PT2", bound=Any)
+
 
 _RelationshipArgumentType = Union[
     str,
@@ -111,7 +154,10 @@ _RelationshipJoinConditionArgument = Union[
     str, _ColumnExpressionArgument[bool]
 ]
 _ORMOrderByArgument = Union[
-    Literal[False], str, _ColumnExpressionArgument[Any]
+    Literal[False],
+    str,
+    _ColumnExpressionArgument[Any],
+    Iterable[Union[str, _ColumnExpressionArgument[Any]]],
 ]
 _ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
 _ORMColCollectionArgument = Union[
@@ -120,7 +166,19 @@ _ORMColCollectionArgument = Union[
 ]
 
 
-def remote(expr):
+_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any])
+
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
+_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+
+def remote(expr: _CEA) -> _CEA:
     """Annotate a portion of a primaryjoin expression
     with a 'remote' annotation.
 
@@ -134,12 +192,12 @@ def remote(expr):
         :func:`.foreign`
 
     """
-    return _annotate_columns(
+    return _annotate_columns(  # type: ignore
         coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True}
     )
 
 
-def foreign(expr):
+def foreign(expr: _CEA) -> _CEA:
     """Annotate a portion of a primaryjoin expression
     with a 'foreign' annotation.
 
@@ -154,11 +212,71 @@ def foreign(expr):
 
     """
 
-    return _annotate_columns(
+    return _annotate_columns(  # type: ignore
         coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True}
     )
 
 
+@dataclasses.dataclass
+class _RelationshipArg(Generic[_T1, _T2]):
+    """stores a user-defined parameter value that must be resolved and
+    parsed later at mapper configuration time.
+
+    """
+
+    __slots__ = "name", "argument", "resolved"
+    name: str
+    argument: _T1
+    resolved: Optional[_T2]
+
+    def _is_populated(self) -> bool:
+        return self.argument is not None
+
+    def _resolve_against_registry(
+        self, clsregistry_resolver: Callable[[str, bool], _class_resolver]
+    ) -> None:
+        attr_value = self.argument
+
+        if isinstance(attr_value, str):
+            self.resolved = clsregistry_resolver(
+                attr_value, self.name == "secondary"
+            )()
+        elif callable(attr_value) and not _is_mapped_class(attr_value):
+            self.resolved = attr_value()
+        else:
+            self.resolved = attr_value
+
+
+class _RelationshipArgs(NamedTuple):
+    """stores user-passed parameters that are resolved at mapper configuration
+    time.
+
+    """
+
+    secondary: _RelationshipArg[
+        Optional[Union[FromClause, str]],
+        Optional[FromClause],
+    ]
+    primaryjoin: _RelationshipArg[
+        Optional[_RelationshipJoinConditionArgument],
+        Optional[ColumnElement[Any]],
+    ]
+    secondaryjoin: _RelationshipArg[
+        Optional[_RelationshipJoinConditionArgument],
+        Optional[ColumnElement[Any]],
+    ]
+    order_by: _RelationshipArg[
+        _ORMOrderByArgument,
+        Union[Literal[None, False], Tuple[ColumnElement[Any], ...]],
+    ]
+    foreign_keys: _RelationshipArg[
+        Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+    ]
+    remote_side: _RelationshipArg[
+        Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+    ]
+
+
 @log.class_logger
 class Relationship(
     _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified
@@ -184,6 +302,10 @@ class Relationship(
     _links_to_entity = True
     _is_relationship = True
 
+    _overlaps: Sequence[str]
+
+    _lazy_strategy: LazyLoader
+
     _persistence_only = dict(
         passive_deletes=False,
         passive_updates=True,
@@ -192,56 +314,87 @@ class Relationship(
         cascade_backrefs=False,
     )
 
-    _dependency_processor = None
+    _dependency_processor: Optional[DependencyProcessor] = None
+
+    primaryjoin: ColumnElement[bool]
+    secondaryjoin: Optional[ColumnElement[bool]]
+    secondary: Optional[FromClause]
+    _join_condition: JoinCondition
+    order_by: Union[Literal[False], Tuple[ColumnElement[Any], ...]]
+
+    _user_defined_foreign_keys: Set[ColumnElement[Any]]
+    _calculated_foreign_keys: Set[ColumnElement[Any]]
+
+    remote_side: Set[ColumnElement[Any]]
+    local_columns: Set[ColumnElement[Any]]
+
+    synchronize_pairs: _ColumnPairs
+    secondary_synchronize_pairs: Optional[_ColumnPairs]
+
+    local_remote_pairs: Optional[_ColumnPairs]
+
+    direction: RelationshipDirection
+
+    _init_args: _RelationshipArgs
 
     def __init__(
         self,
         argument: Optional[_RelationshipArgumentType[_T]] = None,
-        secondary=None,
+        secondary: Optional[Union[FromClause, str]] = None,
         *,
-        uselist=None,
-        collection_class=None,
-        primaryjoin=None,
-        secondaryjoin=None,
-        back_populates=None,
-        order_by=False,
-        backref=None,
-        cascade_backrefs=False,
-        overlaps=None,
-        post_update=False,
-        cascade="save-update, merge",
-        viewonly=False,
+        uselist: Optional[bool] = None,
+        collection_class: Optional[
+            Union[Type[Collection[Any]], Callable[[], Collection[Any]]]
+        ] = None,
+        primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+        secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+        back_populates: Optional[str] = None,
+        order_by: _ORMOrderByArgument = False,
+        backref: Optional[_ORMBackrefArgument] = None,
+        overlaps: Optional[str] = None,
+        post_update: bool = False,
+        cascade: str = "save-update, merge",
+        viewonly: bool = False,
         lazy: _LazyLoadArgumentType = "select",
-        passive_deletes=False,
-        passive_updates=True,
-        active_history=False,
-        enable_typechecks=True,
-        foreign_keys=None,
-        remote_side=None,
-        join_depth=None,
-        comparator_factory=None,
-        single_parent=False,
-        innerjoin=False,
-        distinct_target_key=None,
-        load_on_pending=False,
-        query_class=None,
-        info=None,
-        omit_join=None,
-        sync_backref=None,
-        doc=None,
-        bake_queries=True,
-        _local_remote_pairs=None,
-        _legacy_inactive_history_style=False,
+        passive_deletes: Union[Literal["all"], bool] = False,
+        passive_updates: bool = True,
+        active_history: bool = False,
+        enable_typechecks: bool = True,
+        foreign_keys: Optional[_ORMColCollectionArgument] = None,
+        remote_side: Optional[_ORMColCollectionArgument] = None,
+        join_depth: Optional[int] = None,
+        comparator_factory: Optional[
+            Type[Relationship.Comparator[Any]]
+        ] = None,
+        single_parent: bool = False,
+        innerjoin: bool = False,
+        distinct_target_key: Optional[bool] = None,
+        load_on_pending: bool = False,
+        query_class: Optional[Type[Query[Any]]] = None,
+        info: Optional[_InfoType] = None,
+        omit_join: Literal[None, False] = None,
+        sync_backref: Optional[bool] = None,
+        doc: Optional[str] = None,
+        bake_queries: Literal[True] = True,
+        cascade_backrefs: Literal[False] = False,
+        _local_remote_pairs: Optional[_ColumnPairs] = None,
+        _legacy_inactive_history_style: bool = False,
     ):
         super(Relationship, self).__init__()
 
         self.uselist = uselist
         self.argument = argument
-        self.secondary = secondary
-        self.primaryjoin = primaryjoin
-        self.secondaryjoin = secondaryjoin
+
+        self._init_args = _RelationshipArgs(
+            _RelationshipArg("secondary", secondary, None),
+            _RelationshipArg("primaryjoin", primaryjoin, None),
+            _RelationshipArg("secondaryjoin", secondaryjoin, None),
+            _RelationshipArg("order_by", order_by, None),
+            _RelationshipArg("foreign_keys", foreign_keys, None),
+            _RelationshipArg("remote_side", remote_side, None),
+        )
+
         self.post_update = post_update
-        self.direction = None
         self.viewonly = viewonly
         if viewonly:
             self._warn_for_persistence_only_flags(
@@ -258,7 +411,6 @@ class Relationship(
         self.sync_backref = sync_backref
         self.lazy = lazy
         self.single_parent = single_parent
-        self._user_defined_foreign_keys = foreign_keys
         self.collection_class = collection_class
         self.passive_deletes = passive_deletes
 
@@ -269,7 +421,6 @@ class Relationship(
             )
 
         self.passive_updates = passive_updates
-        self.remote_side = remote_side
         self.enable_typechecks = enable_typechecks
         self.query_class = query_class
         self.innerjoin = innerjoin
@@ -292,23 +443,22 @@ class Relationship(
         self.local_remote_pairs = _local_remote_pairs
         self.load_on_pending = load_on_pending
         self.comparator_factory = comparator_factory or Relationship.Comparator
-        self.comparator = self.comparator_factory(self, None)
         util.set_creation_order(self)
 
         if info is not None:
-            self.info = info
+            self.info.update(info)
 
         self.strategy_key = (("lazy", self.lazy),)
 
-        self._reverse_property = set()
+        self._reverse_property: Set[Relationship[Any]] = set()
+
         if overlaps:
-            self._overlaps = set(re.split(r"\s*,\s*", overlaps))
+            self._overlaps = set(re.split(r"\s*,\s*", overlaps))  # type: ignore  # noqa: E501
         else:
             self._overlaps = ()
 
-        self.cascade = cascade
-
-        self.order_by = order_by
+        # mypy ignoring the @property setter
+        self.cascade = cascade  # type: ignore
 
         self.back_populates = back_populates
 
@@ -322,7 +472,7 @@ class Relationship(
         else:
             self.backref = backref
 
-    def _warn_for_persistence_only_flags(self, **kw):
+    def _warn_for_persistence_only_flags(self, **kw: Any) -> None:
         for k, v in kw.items():
             if v != self._persistence_only[k]:
                 # we are warning here rather than warn deprecated as this is a
@@ -340,7 +490,7 @@ class Relationship(
                     "in a future release." % (k,)
                 )
 
-    def instrument_class(self, mapper):
+    def instrument_class(self, mapper: Mapper[Any]) -> None:
         attributes.register_descriptor(
             mapper.class_,
             self.key,
@@ -378,13 +528,16 @@ class Relationship(
             "_extra_criteria",
         )
 
+        prop: RODescriptorReference[Relationship[_PT]]
+        _of_type: Optional[_EntityType[_PT]]
+
         def __init__(
             self,
-            prop,
-            parentmapper,
-            adapt_to_entity=None,
-            of_type=None,
-            extra_criteria=(),
+            prop: Relationship[_PT],
+            parentmapper: _InternalEntityType[Any],
+            adapt_to_entity: Optional[AliasedInsp[Any]] = None,
+            of_type: Optional[_EntityType[_PT]] = None,
+            extra_criteria: Tuple[ColumnElement[bool], ...] = (),
         ):
             """Construction of :class:`.Relationship.Comparator`
             is internal to the ORM's attribute mechanics.
@@ -399,15 +552,17 @@ class Relationship(
                 self._of_type = None
             self._extra_criteria = extra_criteria
 
-        def adapt_to_entity(self, adapt_to_entity):
+        def adapt_to_entity(
+            self, adapt_to_entity: AliasedInsp[Any]
+        ) -> Relationship.Comparator[Any]:
             return self.__class__(
-                self.property,
+                self.prop,
                 self._parententity,
                 adapt_to_entity=adapt_to_entity,
                 of_type=self._of_type,
             )
 
-        entity: _InternalEntityType
+        entity: _InternalEntityType[_PT]
         """The target entity referred to by this
         :class:`.Relationship.Comparator`.
 
@@ -419,7 +574,7 @@ class Relationship(
 
         """
 
-        mapper: Mapper[Any]
+        mapper: Mapper[_PT]
         """The target :class:`_orm.Mapper` referred to by this
         :class:`.Relationship.Comparator`.
 
@@ -428,22 +583,22 @@ class Relationship(
 
         """
 
-        def _memoized_attr_entity(self) -> _InternalEntityType:
+        def _memoized_attr_entity(self) -> _InternalEntityType[_PT]:
             if self._of_type:
-                return inspect(self._of_type)
+                return inspect(self._of_type)  # type: ignore
             else:
                 return self.prop.entity
 
-        def _memoized_attr_mapper(self) -> Mapper[Any]:
+        def _memoized_attr_mapper(self) -> Mapper[_PT]:
             return self.entity.mapper
 
-        def _source_selectable(self):
+        def _source_selectable(self) -> FromClause:
             if self._adapt_to_entity:
                 return self._adapt_to_entity.selectable
             else:
                 return self.property.parent._with_polymorphic_selectable
 
-        def __clause_element__(self):
+        def __clause_element__(self) -> ColumnElement[bool]:
             adapt_from = self._source_selectable()
             if self._of_type:
                 of_type_entity = inspect(self._of_type)
@@ -457,7 +612,7 @@ class Relationship(
                 dest,
                 secondary,
                 target_adapter,
-            ) = self.property._create_joins(
+            ) = self.prop._create_joins(
                 source_selectable=adapt_from,
                 source_polymorphic=True,
                 of_type_entity=of_type_entity,
@@ -469,7 +624,7 @@ class Relationship(
             else:
                 return pj
 
-        def of_type(self, cls):
+        def of_type(self, class_: _EntityType[_PT]) -> PropComparator[_PT]:
             r"""Redefine this object in terms of a polymorphic subclass.
 
             See :meth:`.PropComparator.of_type` for an example.
@@ -477,16 +632,16 @@ class Relationship(
 
             """
             return Relationship.Comparator(
-                self.property,
+                self.prop,
                 self._parententity,
                 adapt_to_entity=self._adapt_to_entity,
-                of_type=cls,
+                of_type=class_,
                 extra_criteria=self._extra_criteria,
             )
 
         def and_(
             self, *criteria: _ColumnExpressionArgument[bool]
-        ) -> PropComparator[bool]:
+        ) -> PropComparator[Any]:
             """Add AND criteria.
 
             See :meth:`.PropComparator.and_` for an example.
@@ -500,14 +655,14 @@ class Relationship(
             )
 
             return Relationship.Comparator(
-                self.property,
+                self.prop,
                 self._parententity,
                 adapt_to_entity=self._adapt_to_entity,
                 of_type=self._of_type,
                 extra_criteria=self._extra_criteria + exprs,
             )
 
-        def in_(self, other):
+        def in_(self, other: Any) -> NoReturn:
             """Produce an IN clause - this is not implemented
             for :func:`_orm.relationship`-based attributes at this time.
 
@@ -522,7 +677,7 @@ class Relationship(
         # https://github.com/python/mypy/issues/4266
         __hash__ = None  # type: ignore
 
-        def __eq__(self, other):
+        def __eq__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
             """Implement the ``==`` operator.
 
             In a many-to-one context, such as::
@@ -559,7 +714,7 @@ class Relationship(
               or many-to-many context produce a NOT EXISTS clause.
 
             """
-            if isinstance(other, (util.NoneType, expression.Null)):
+            if other is None or isinstance(other, expression.Null):
                 if self.property.direction in [ONETOMANY, MANYTOMANY]:
                     return ~self._criterion_exists()
                 else:
@@ -585,8 +740,18 @@ class Relationship(
             criterion: Optional[_ColumnExpressionArgument[bool]] = None,
             **kwargs: Any,
         ) -> Exists:
+
+            where_criteria = (
+                coercions.expect(roles.WhereHavingRole, criterion)
+                if criterion is not None
+                else None
+            )
+
             if getattr(self, "_of_type", None):
-                info = inspect(self._of_type)
+                info: Optional[_InternalEntityType[Any]] = inspect(
+                    self._of_type
+                )
+                assert info is not None
                 target_mapper, to_selectable, is_aliased_class = (
                     info.mapper,
                     info.selectable,
@@ -597,10 +762,10 @@ class Relationship(
 
                 single_crit = target_mapper._single_table_criterion
                 if single_crit is not None:
-                    if criterion is not None:
-                        criterion = single_crit & criterion
+                    if where_criteria is not None:
+                        where_criteria = single_crit & where_criteria
                     else:
-                        criterion = single_crit
+                        where_criteria = single_crit
             else:
                 is_aliased_class = False
                 to_selectable = None
@@ -624,10 +789,10 @@ class Relationship(
 
             for k in kwargs:
                 crit = getattr(self.property.mapper.class_, k) == kwargs[k]
-                if criterion is None:
-                    criterion = crit
+                if where_criteria is None:
+                    where_criteria = crit
                 else:
-                    criterion = criterion & crit
+                    where_criteria = where_criteria & crit
 
             # annotate the *local* side of the join condition, in the case
             # of pj + sj this is the full primaryjoin, in the case of just
@@ -638,24 +803,24 @@ class Relationship(
                 j = _orm_annotate(pj, exclude=self.property.remote_side)
 
             if (
-                criterion is not None
+                where_criteria is not None
                 and target_adapter
                 and not is_aliased_class
             ):
                 # limit this adapter to annotated only?
-                criterion = target_adapter.traverse(criterion)
+                where_criteria = target_adapter.traverse(where_criteria)
 
             # only have the "joined left side" of what we
             # return be subject to Query adaption.  The right
             # side of it is used for an exists() subquery and
             # should not correlate or otherwise reach out
             # to anything in the enclosing query.
-            if criterion is not None:
-                criterion = criterion._annotate(
+            if where_criteria is not None:
+                where_criteria = where_criteria._annotate(
                     {"no_replacement_traverse": True}
                 )
 
-            crit = j & sql.True_._ifnone(criterion)
+            crit = j & sql.True_._ifnone(where_criteria)
 
             if secondary is not None:
                 ex = (
@@ -673,7 +838,11 @@ class Relationship(
                 )
             return ex
 
-        def any(self, criterion=None, **kwargs):
+        def any(
+            self,
+            criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+            **kwargs: Any,
+        ) -> ColumnElement[bool]:
             """Produce an expression that tests a collection against
             particular criterion, using EXISTS.
 
@@ -722,7 +891,11 @@ class Relationship(
 
             return self._criterion_exists(criterion, **kwargs)
 
-        def has(self, criterion=None, **kwargs):
+        def has(
+            self,
+            criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+            **kwargs: Any,
+        ) -> ColumnElement[bool]:
             """Produce an expression that tests a scalar reference against
             particular criterion, using EXISTS.
 
@@ -756,7 +929,9 @@ class Relationship(
                 )
             return self._criterion_exists(criterion, **kwargs)
 
-        def contains(self, other, **kwargs):
+        def contains(
+            self, other: _ColumnExpressionArgument[Any], **kwargs: Any
+        ) -> ColumnElement[bool]:
             """Return a simple expression that tests a collection for
             containment of a particular item.
 
@@ -815,38 +990,45 @@ class Relationship(
             kwargs may be ignored by this operator but are required for API
             conformance.
             """
-            if not self.property.uselist:
+            if not self.prop.uselist:
                 raise sa_exc.InvalidRequestError(
                     "'contains' not implemented for scalar "
                     "attributes.  Use =="
                 )
-            clause = self.property._optimized_compare(
+
+            clause = self.prop._optimized_compare(
                 other, adapt_source=self.adapter
             )
 
-            if self.property.secondaryjoin is not None:
+            if self.prop.secondaryjoin is not None:
                 clause.negation_clause = self.__negated_contains_or_equals(
                     other
                 )
 
             return clause
 
-        def __negated_contains_or_equals(self, other):
-            if self.property.direction == MANYTOONE:
+        def __negated_contains_or_equals(
+            self, other: Any
+        ) -> ColumnElement[bool]:
+            if self.prop.direction == MANYTOONE:
                 state = attributes.instance_state(other)
 
-                def state_bindparam(local_col, state, remote_col):
+                def state_bindparam(
+                    local_col: ColumnElement[Any],
+                    state: InstanceState[Any],
+                    remote_col: ColumnElement[Any],
+                ) -> BindParameter[Any]:
                     dict_ = state.dict
                     return sql.bindparam(
                         local_col.key,
                         type_=local_col.type,
                         unique=True,
-                        callable_=self.property._get_attr_w_warn_on_none(
-                            self.property.mapper, state, dict_, remote_col
+                        callable_=self.prop._get_attr_w_warn_on_none(
+                            self.prop.mapper, state, dict_, remote_col
                         ),
                     )
 
-                def adapt(col):
+                def adapt(col: _CE) -> _CE:
                     if self.adapter:
                         return self.adapter(col)
                     else:
@@ -876,7 +1058,7 @@ class Relationship(
 
             return ~self._criterion_exists(criterion)
 
-        def __ne__(self, other):
+        def __ne__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
             """Implement the ``!=`` operator.
 
             In a many-to-one context, such as::
@@ -915,7 +1097,7 @@ class Relationship(
               or many-to-many context produce an EXISTS clause.
 
             """
-            if isinstance(other, (util.NoneType, expression.Null)):
+            if other is None or isinstance(other, expression.Null):
                 if self.property.direction == MANYTOONE:
                     return _orm_annotate(
                         ~self.property._optimized_compare(
@@ -934,12 +1116,10 @@ class Relationship(
             else:
                 return _orm_annotate(self.__negated_contains_or_equals(other))
 
-        def _memoized_attr_property(self):
+        def _memoized_attr_property(self) -> Relationship[_PT]:
             self.prop.parent._check_configure()
             return self.prop
 
-    comparator: Comparator[_T]
-
     def _with_parent(
         self,
         instance: object,
@@ -947,10 +1127,11 @@ class Relationship(
         from_entity: Optional[_EntityType[Any]] = None,
     ) -> ColumnElement[bool]:
         assert instance is not None
-        adapt_source = None
+        adapt_source: Optional[_CoreAdapterProto] = None
         if from_entity is not None:
-            insp = inspect(from_entity)
-            if insp.is_aliased_class:
+            insp: Optional[_InternalEntityType[Any]] = inspect(from_entity)
+            assert insp is not None
+            if insp_is_aliased_class(insp):
                 adapt_source = insp._adapter.adapt_clause
         return self._optimized_compare(
             instance,
@@ -961,11 +1142,11 @@ class Relationship(
 
     def _optimized_compare(
         self,
-        state,
-        value_is_parent=False,
-        adapt_source=None,
-        alias_secondary=True,
-    ):
+        state: Any,
+        value_is_parent: bool = False,
+        adapt_source: Optional[_CoreAdapterProto] = None,
+        alias_secondary: bool = True,
+    ) -> ColumnElement[bool]:
         if state is not None:
             try:
                 state = inspect(state)
@@ -1005,7 +1186,7 @@ class Relationship(
 
         dict_ = attributes.instance_dict(state.obj())
 
-        def visit_bindparam(bindparam):
+        def visit_bindparam(bindparam: BindParameter[Any]) -> None:
             if bindparam._identifying_key in bind_to_col:
                 bindparam.callable = self._get_attr_w_warn_on_none(
                     mapper,
@@ -1027,7 +1208,13 @@ class Relationship(
             criterion = adapt_source(criterion)
         return criterion
 
-    def _get_attr_w_warn_on_none(self, mapper, state, dict_, column):
+    def _get_attr_w_warn_on_none(
+        self,
+        mapper: Mapper[Any],
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        column: ColumnElement[Any],
+    ) -> Callable[[], Any]:
         """Create the callable that is used in a many-to-one expression.
 
         E.g.::
@@ -1077,9 +1264,14 @@ class Relationship(
         # this feature was added explicitly for use in this method.
         state._track_last_known_value(prop.key)
 
-        def _go():
-            last_known = to_return = state._last_known_values[prop.key]
-            existing_is_available = last_known is not attributes.NO_VALUE
+        lkv_fixed = state._last_known_values
+
+        def _go() -> Any:
+            assert lkv_fixed is not None
+            last_known = to_return = lkv_fixed[prop.key]
+            existing_is_available = (
+                last_known is not LoaderCallableStatus.NO_VALUE
+            )
 
             # we support that the value may have changed.  so here we
             # try to get the most recent value including re-fetching.
@@ -1089,19 +1281,19 @@ class Relationship(
                 state,
                 dict_,
                 column,
-                passive=attributes.PASSIVE_OFF
+                passive=PassiveFlag.PASSIVE_OFF
                 if state.persistent
-                else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK,
+                else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK,
             )
 
-            if current_value is attributes.NEVER_SET:
+            if current_value is LoaderCallableStatus.NEVER_SET:
                 if not existing_is_available:
                     raise sa_exc.InvalidRequestError(
                         "Can't resolve value for column %s on object "
                         "%s; no value has been set for this column"
                         % (column, state_str(state))
                     )
-            elif current_value is attributes.PASSIVE_NO_RESULT:
+            elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                 if not existing_is_available:
                     raise sa_exc.InvalidRequestError(
                         "Can't resolve value for column %s on object "
@@ -1121,7 +1313,11 @@ class Relationship(
 
         return _go
 
-    def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
+    def _lazy_none_clause(
+        self,
+        reverse_direction: bool = False,
+        adapt_source: Optional[_CoreAdapterProto] = None,
+    ) -> ColumnElement[bool]:
         if not reverse_direction:
             criterion, bind_to_col = (
                 self._lazy_strategy._lazywhere,
@@ -1139,20 +1335,20 @@ class Relationship(
             criterion = adapt_source(criterion)
         return criterion
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.parent.class_.__name__) + "." + self.key
 
     def merge(
         self,
-        session,
-        source_state,
-        source_dict,
-        dest_state,
-        dest_dict,
-        load,
-        _recursive,
-        _resolve_conflict_map,
-    ):
+        session: Session,
+        source_state: InstanceState[Any],
+        source_dict: _InstanceDict,
+        dest_state: InstanceState[Any],
+        dest_dict: _InstanceDict,
+        load: bool,
+        _recursive: Dict[Any, object],
+        _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+    ) -> None:
 
         if load:
             for r in self._reverse_property:
@@ -1167,6 +1363,8 @@ class Relationship(
 
         if self.uselist:
             impl = source_state.get_impl(self.key)
+
+            assert is_has_collection_adapter(impl)
             instances_iterable = impl.get_collection(source_state, source_dict)
 
             # if this is a CollectionAttributeImpl, then empty should
@@ -1204,9 +1402,9 @@ class Relationship(
                 for c in dest_list:
                     coll.append_without_event(c)
             else:
-                dest_state.get_impl(self.key).set(
-                    dest_state, dest_dict, dest_list, _adapt=False
-                )
+                dest_impl = dest_state.get_impl(self.key)
+                assert is_has_collection_adapter(dest_impl)
+                dest_impl.set(dest_state, dest_dict, dest_list, _adapt=False)
         else:
             current = source_dict[self.key]
             if current is not None:
@@ -1231,8 +1429,12 @@ class Relationship(
                 )
 
     def _value_as_iterable(
-        self, state, dict_, key, passive=attributes.PASSIVE_OFF
-    ):
+        self,
+        state: InstanceState[_O],
+        dict_: _InstanceDict,
+        key: str,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+    ) -> Sequence[Tuple[InstanceState[_O], _O]]:
         """Return a list of tuples (state, obj) for the given
         key.
 
@@ -1241,9 +1443,9 @@ class Relationship(
 
         impl = state.manager[key].impl
         x = impl.get(state, dict_, passive=passive)
-        if x is attributes.PASSIVE_NO_RESULT or x is None:
+        if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None:
             return []
-        elif hasattr(impl, "get_collection"):
+        elif is_has_collection_adapter(impl):
             return [
                 (attributes.instance_state(o), o)
                 for o in impl.get_collection(state, dict_, x, passive=passive)
@@ -1252,19 +1454,23 @@ class Relationship(
             return [(attributes.instance_state(x), x)]
 
     def cascade_iterator(
-        self, type_, state, dict_, visited_states, halt_on=None
-    ):
+        self,
+        type_: str,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        visited_states: Set[InstanceState[Any]],
+        halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+    ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]:
         # assert type_ in self._cascade
 
         # only actively lazy load on the 'delete' cascade
         if type_ != "delete" or self.passive_deletes:
-            passive = attributes.PASSIVE_NO_INITIALIZE
+            passive = PassiveFlag.PASSIVE_NO_INITIALIZE
         else:
-            passive = attributes.PASSIVE_OFF
+            passive = PassiveFlag.PASSIVE_OFF
 
         if type_ == "save-update":
             tuples = state.manager[self.key].impl.get_all_pending(state, dict_)
-
         else:
             tuples = self._value_as_iterable(
                 state, dict_, self.key, passive=passive
@@ -1285,6 +1491,7 @@ class Relationship(
                 # see [ticket:2229]
                 continue
 
+            assert instance_state is not None
             instance_dict = attributes.instance_dict(c)
 
             if halt_on and halt_on(instance_state):
@@ -1308,14 +1515,16 @@ class Relationship(
             yield c, instance_mapper, instance_state, instance_dict
 
     @property
-    def _effective_sync_backref(self):
+    def _effective_sync_backref(self) -> bool:
         if self.viewonly:
             return False
         else:
             return self.sync_backref is not False
 
     @staticmethod
-    def _check_sync_backref(rel_a, rel_b):
+    def _check_sync_backref(
+        rel_a: Relationship[Any], rel_b: Relationship[Any]
+    ) -> None:
         if rel_a.viewonly and rel_b.sync_backref:
             raise sa_exc.InvalidRequestError(
                 "Relationship %s cannot specify sync_backref=True since %s "
@@ -1328,7 +1537,7 @@ class Relationship(
         ):
             rel_b.sync_backref = False
 
-    def _add_reverse_property(self, key):
+    def _add_reverse_property(self, key: str) -> None:
         other = self.mapper.get_property(key, _configure_mappers=False)
         if not isinstance(other, Relationship):
             raise sa_exc.InvalidRequestError(
@@ -1361,7 +1570,8 @@ class Relationship(
             )
 
         if (
-            self.direction in (ONETOMANY, MANYTOONE)
+            other._configure_started
+            and self.direction in (ONETOMANY, MANYTOONE)
             and self.direction == other.direction
         ):
             raise sa_exc.ArgumentError(
@@ -1372,7 +1582,7 @@ class Relationship(
             )
 
     @util.memoized_property
-    def entity(self) -> Union["Mapper", "AliasedInsp"]:
+    def entity(self) -> _InternalEntityType[_T]:
         """Return the target mapped entity, which is an inspect() of the
         class or aliased class that is referred towards.
 
@@ -1388,7 +1598,7 @@ class Relationship(
         """
         return self.entity.mapper
 
-    def do_init(self):
+    def do_init(self) -> None:
         self._check_conflicts()
         self._process_dependent_arguments()
         self._setup_entity()
@@ -1399,14 +1609,16 @@ class Relationship(
         self._generate_backref()
         self._join_condition._warn_for_conflicting_sync_targets()
         super(Relationship, self).do_init()
-        self._lazy_strategy = self._get_strategy((("lazy", "select"),))
+        self._lazy_strategy = cast(
+            "LazyLoader", self._get_strategy((("lazy", "select"),))
+        )
 
-    def _setup_registry_dependencies(self):
+    def _setup_registry_dependencies(self) -> None:
         self.parent.mapper.registry._set_depends_on(
             self.entity.mapper.registry
         )
 
-    def _process_dependent_arguments(self):
+    def _process_dependent_arguments(self) -> None:
         """Convert incoming configuration arguments to their
         proper form.
 
@@ -1417,78 +1629,80 @@ class Relationship(
         # accept callables for other attributes which may require
         # deferred initialization.  This technique is used
         # by declarative "string configs" and some recipes.
+        init_args = self._init_args
+
         for attr in (
             "order_by",
             "primaryjoin",
             "secondaryjoin",
             "secondary",
-            "_user_defined_foreign_keys",
+            "foreign_keys",
             "remote_side",
         ):
-            attr_value = getattr(self, attr)
-
-            if isinstance(attr_value, str):
-                setattr(
-                    self,
-                    attr,
-                    self._clsregistry_resolve_arg(
-                        attr_value, favor_tables=attr == "secondary"
-                    )(),
-                )
-            elif callable(attr_value) and not _is_mapped_class(attr_value):
-                setattr(self, attr, attr_value())
+
+            rel_arg = getattr(init_args, attr)
+
+            rel_arg._resolve_against_registry(self._clsregistry_resolvers[1])
 
         # remove "annotations" which are present if mapped class
         # descriptors are used to create the join expression.
         for attr in "primaryjoin", "secondaryjoin":
-            val = getattr(self, attr)
+            rel_arg = getattr(init_args, attr)
+            val = rel_arg.resolved
             if val is not None:
-                setattr(
-                    self,
-                    attr,
-                    _orm_deannotate(
-                        coercions.expect(
-                            roles.ColumnArgumentRole, val, argname=attr
-                        )
-                    ),
+                rel_arg.resolved = _orm_deannotate(
+                    coercions.expect(
+                        roles.ColumnArgumentRole, val, argname=attr
+                    )
                 )
 
-        if self.secondary is not None and _is_mapped_class(self.secondary):
+        secondary = init_args.secondary.resolved
+        if secondary is not None and _is_mapped_class(secondary):
             raise sa_exc.ArgumentError(
                 "secondary argument %s passed to to relationship() %s must "
                 "be a Table object or other FROM clause; can't send a mapped "
                 "class directly as rows in 'secondary' are persisted "
                 "independently of a class that is mapped "
-                "to that same table." % (self.secondary, self)
+                "to that same table." % (secondary, self)
             )
 
         # ensure expressions in self.order_by, foreign_keys,
         # remote_side are all columns, not strings.
-        if self.order_by is not False and self.order_by is not None:
+        if (
+            init_args.order_by.resolved is not False
+            and init_args.order_by.resolved is not None
+        ):
             self.order_by = tuple(
                 coercions.expect(
                     roles.ColumnArgumentRole, x, argname="order_by"
                 )
-                for x in util.to_list(self.order_by)
+                for x in util.to_list(init_args.order_by.resolved)
             )
+        else:
+            self.order_by = False
 
         self._user_defined_foreign_keys = util.column_set(
             coercions.expect(
                 roles.ColumnArgumentRole, x, argname="foreign_keys"
             )
-            for x in util.to_column_set(self._user_defined_foreign_keys)
+            for x in util.to_column_set(init_args.foreign_keys.resolved)
         )
 
         self.remote_side = util.column_set(
             coercions.expect(
                 roles.ColumnArgumentRole, x, argname="remote_side"
             )
-            for x in util.to_column_set(self.remote_side)
+            for x in util.to_column_set(init_args.remote_side.resolved)
         )
 
     def declarative_scan(
-        self, registry, cls, key, annotation, is_dataclass_field
-    ):
+        self,
+        registry: _RegistryType,
+        cls: Type[Any],
+        key: str,
+        annotation: Optional[_AnnotationScanType],
+        is_dataclass_field: bool,
+    ) -> None:
         argument = _extract_mapped_subtype(
             annotation,
             cls,
@@ -1502,17 +1716,19 @@ class Relationship(
 
         if hasattr(argument, "__origin__"):
 
-            collection_class = argument.__origin__
+            collection_class = argument.__origin__  # type: ignore
             if issubclass(collection_class, abc.Collection):
                 if self.collection_class is None:
                     self.collection_class = collection_class
             else:
                 self.uselist = False
-            if argument.__args__:
-                if issubclass(argument.__origin__, typing.Mapping):
-                    type_arg = argument.__args__[1]
+            if argument.__args__:  # type: ignore
+                if issubclass(
+                    argument.__origin__, typing.Mapping  # type: ignore
+                ):
+                    type_arg = argument.__args__[1]  # type: ignore
                 else:
-                    type_arg = argument.__args__[0]
+                    type_arg = argument.__args__[0]  # type: ignore
                 if hasattr(type_arg, "__forward_arg__"):
                     str_argument = type_arg.__forward_arg__
                     argument = str_argument
@@ -1523,12 +1739,12 @@ class Relationship(
                     f"Generic alias {argument} requires an argument"
                 )
         elif hasattr(argument, "__forward_arg__"):
-            argument = argument.__forward_arg__
+            argument = argument.__forward_arg__  # type: ignore
 
         self.argument = argument
 
     @util.preload_module("sqlalchemy.orm.mapper")
-    def _setup_entity(self, __argument=None):
+    def _setup_entity(self, __argument: Any = None) -> None:
         if "entity" in self.__dict__:
             return
 
@@ -1539,42 +1755,51 @@ class Relationship(
         else:
             argument = self.argument
 
+        resolved_argument: _ExternalEntityType[Any]
+
         if isinstance(argument, str):
-            argument = self._clsregistry_resolve_name(argument)()
+            # we might want to cleanup clsregistry API to make this
+            # more straightforward
+            resolved_argument = cast(
+                "_ExternalEntityType[Any]",
+                self._clsregistry_resolve_name(argument)(),
+            )
         elif callable(argument) and not isinstance(
             argument, (type, mapperlib.Mapper)
         ):
-            argument = argument()
+            resolved_argument = argument()
         else:
-            argument = argument
+            resolved_argument = argument
 
-        if isinstance(argument, type):
-            entity = class_mapper(argument, configure=False)
+        entity: _InternalEntityType[Any]
+
+        if isinstance(resolved_argument, type):
+            entity = class_mapper(resolved_argument, configure=False)
         else:
             try:
-                entity = inspect(argument)
+                entity = inspect(resolved_argument)
             except sa_exc.NoInspectionAvailable:
-                entity = None
+                entity = None  # type: ignore
 
             if not hasattr(entity, "mapper"):
                 raise sa_exc.ArgumentError(
                     "relationship '%s' expects "
                     "a class or a mapper argument (received: %s)"
-                    % (self.key, type(argument))
+                    % (self.key, type(resolved_argument))
                 )
 
         self.entity = entity  # type: ignore
         self.target = self.entity.persist_selectable
 
-    def _setup_join_conditions(self):
+    def _setup_join_conditions(self) -> None:
         self._join_condition = jc = JoinCondition(
             parent_persist_selectable=self.parent.persist_selectable,
             child_persist_selectable=self.entity.persist_selectable,
             parent_local_selectable=self.parent.local_table,
             child_local_selectable=self.entity.local_table,
-            primaryjoin=self.primaryjoin,
-            secondary=self.secondary,
-            secondaryjoin=self.secondaryjoin,
+            primaryjoin=self._init_args.primaryjoin.resolved,
+            secondary=self._init_args.secondary.resolved,
+            secondaryjoin=self._init_args.secondaryjoin.resolved,
             parent_equivalents=self.parent._equivalent_columns,
             child_equivalents=self.mapper._equivalent_columns,
             consider_as_foreign_keys=self._user_defined_foreign_keys,
@@ -1587,6 +1812,7 @@ class Relationship(
         )
         self.primaryjoin = jc.primaryjoin
         self.secondaryjoin = jc.secondaryjoin
+        self.secondary = jc.secondary
         self.direction = jc.direction
         self.local_remote_pairs = jc.local_remote_pairs
         self.remote_side = jc.remote_columns
@@ -1596,21 +1822,30 @@ class Relationship(
         self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
 
     @property
-    def _clsregistry_resolve_arg(self):
+    def _clsregistry_resolve_arg(
+        self,
+    ) -> Callable[[str, bool], _class_resolver]:
         return self._clsregistry_resolvers[1]
 
     @property
-    def _clsregistry_resolve_name(self):
+    def _clsregistry_resolve_name(
+        self,
+    ) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]:
         return self._clsregistry_resolvers[0]
 
     @util.memoized_property
     @util.preload_module("sqlalchemy.orm.clsregistry")
-    def _clsregistry_resolvers(self):
+    def _clsregistry_resolvers(
+        self,
+    ) -> Tuple[
+        Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+        Callable[[str, bool], _class_resolver],
+    ]:
         _resolver = util.preloaded.orm_clsregistry._resolver
 
         return _resolver(self.parent.class_, self)
 
-    def _check_conflicts(self):
+    def _check_conflicts(self) -> None:
         """Test that this relationship is legal, warn about
         inheritance conflicts."""
         if self.parent.non_primary and not class_mapper(
@@ -1637,10 +1872,10 @@ class Relationship(
         return self._cascade
 
     @cascade.setter
-    def cascade(self, cascade: Union[str, CascadeOptions]):
+    def cascade(self, cascade: Union[str, CascadeOptions]) -> None:
         self._set_cascade(cascade)
 
-    def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]):
+    def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None:
         cascade = CascadeOptions(cascade_arg)
 
         if self.viewonly:
@@ -1655,7 +1890,7 @@ class Relationship(
         if self._dependency_processor:
             self._dependency_processor.cascade = cascade
 
-    def _check_cascade_settings(self, cascade):
+    def _check_cascade_settings(self, cascade: CascadeOptions) -> None:
         if (
             cascade.delete_orphan
             and not self.single_parent
@@ -1699,7 +1934,7 @@ class Relationship(
                 (self.key, self.parent.class_)
             )
 
-    def _persists_for(self, mapper):
+    def _persists_for(self, mapper: Mapper[Any]) -> bool:
         """Return True if this property will persist values on behalf
         of the given mapper.
 
@@ -1710,16 +1945,15 @@ class Relationship(
             and mapper.relationships[self.key] is self
         )
 
-    def _columns_are_mapped(self, *cols):
+    def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool:
         """Return True if all columns in the given collection are
         mapped by the tables referenced by this :class:`.Relationship`.
 
         """
+
+        secondary = self._init_args.secondary.resolved
         for c in cols:
-            if (
-                self.secondary is not None
-                and self.secondary.c.contains_column(c)
-            ):
+            if secondary is not None and secondary.c.contains_column(c):
                 continue
             if not self.parent.persist_selectable.c.contains_column(
                 c
@@ -1727,13 +1961,14 @@ class Relationship(
                 return False
         return True
 
-    def _generate_backref(self):
+    def _generate_backref(self) -> None:
         """Interpret the 'backref' instruction to create a
         :func:`_orm.relationship` complementary to this one."""
 
         if self.parent.non_primary:
             return
         if self.backref is not None and not self.back_populates:
+            kwargs: Dict[str, Any]
             if isinstance(self.backref, str):
                 backref_key, kwargs = self.backref, {}
             else:
@@ -1805,7 +2040,7 @@ class Relationship(
             self._add_reverse_property(self.back_populates)
 
     @util.preload_module("sqlalchemy.orm.dependency")
-    def _post_init(self):
+    def _post_init(self) -> None:
         dependency = util.preloaded.orm_dependency
 
         if self.uselist is None:
@@ -1816,7 +2051,7 @@ class Relationship(
             )(self)
 
     @util.memoized_property
-    def _use_get(self):
+    def _use_get(self) -> bool:
         """memoize the 'use_get' attribute of this RelationshipLoader's
         lazyloader."""
 
@@ -1824,18 +2059,25 @@ class Relationship(
         return strategy.use_get
 
     @util.memoized_property
-    def _is_self_referential(self):
+    def _is_self_referential(self) -> bool:
         return self.mapper.common_parent(self.parent)
 
     def _create_joins(
         self,
-        source_polymorphic=False,
-        source_selectable=None,
-        dest_selectable=None,
-        of_type_entity=None,
-        alias_secondary=False,
-        extra_criteria=(),
-    ):
+        source_polymorphic: bool = False,
+        source_selectable: Optional[FromClause] = None,
+        dest_selectable: Optional[FromClause] = None,
+        of_type_entity: Optional[_InternalEntityType[Any]] = None,
+        alias_secondary: bool = False,
+        extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+    ) -> Tuple[
+        ColumnElement[bool],
+        Optional[ColumnElement[bool]],
+        FromClause,
+        FromClause,
+        Optional[FromClause],
+        Optional[ClauseAdapter],
+    ]:
 
         aliased = False
 
@@ -1905,38 +2147,56 @@ class Relationship(
         )
 
 
-def _annotate_columns(element, annotations):
-    def clone(elem):
+def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE:
+    def clone(elem: _CE) -> _CE:
         if isinstance(elem, expression.ColumnClause):
-            elem = elem._annotate(annotations.copy())
+            elem = elem._annotate(annotations.copy())  # type: ignore
         elem._copy_internals(clone=clone)
         return elem
 
     if element is not None:
         element = clone(element)
-    clone = None  # remove gc cycles
+    clone = None  # type: ignore # remove gc cycles
     return element
 
 
 class JoinCondition:
+
+    primaryjoin_initial: Optional[ColumnElement[bool]]
+    primaryjoin: ColumnElement[bool]
+    secondaryjoin: Optional[ColumnElement[bool]]
+    secondary: Optional[FromClause]
+    prop: Relationship[Any]
+
+    synchronize_pairs: _ColumnPairs
+    secondary_synchronize_pairs: _ColumnPairs
+    direction: RelationshipDirection
+
+    parent_persist_selectable: FromClause
+    child_persist_selectable: FromClause
+    parent_local_selectable: FromClause
+    child_local_selectable: FromClause
+
+    _local_remote_pairs: Optional[_ColumnPairs]
+
     def __init__(
         self,
-        parent_persist_selectable,
-        child_persist_selectable,
-        parent_local_selectable,
-        child_local_selectable,
-        primaryjoin=None,
-        secondary=None,
-        secondaryjoin=None,
-        parent_equivalents=None,
-        child_equivalents=None,
-        consider_as_foreign_keys=None,
-        local_remote_pairs=None,
-        remote_side=None,
-        self_referential=False,
-        prop=None,
-        support_sync=True,
-        can_be_synced_fn=lambda *c: True,
+        parent_persist_selectable: FromClause,
+        child_persist_selectable: FromClause,
+        parent_local_selectable: FromClause,
+        child_local_selectable: FromClause,
+        primaryjoin: Optional[ColumnElement[bool]] = None,
+        secondary: Optional[FromClause] = None,
+        secondaryjoin: Optional[ColumnElement[bool]] = None,
+        parent_equivalents: Optional[_EquivalentColumnMap] = None,
+        child_equivalents: Optional[_EquivalentColumnMap] = None,
+        consider_as_foreign_keys: Any = None,
+        local_remote_pairs: Optional[_ColumnPairs] = None,
+        remote_side: Any = None,
+        self_referential: Any = False,
+        prop: Optional[Relationship[Any]] = None,
+        support_sync: bool = True,
+        can_be_synced_fn: Callable[..., bool] = lambda *c: True,
     ):
         self.parent_persist_selectable = parent_persist_selectable
         self.parent_local_selectable = parent_local_selectable
@@ -1944,7 +2204,7 @@ class JoinCondition:
         self.child_local_selectable = child_local_selectable
         self.parent_equivalents = parent_equivalents
         self.child_equivalents = child_equivalents
-        self.primaryjoin = primaryjoin
+        self.primaryjoin_initial = primaryjoin
         self.secondaryjoin = secondaryjoin
         self.secondary = secondary
         self.consider_as_foreign_keys = consider_as_foreign_keys
@@ -1954,7 +2214,10 @@ class JoinCondition:
         self.self_referential = self_referential
         self.support_sync = support_sync
         self.can_be_synced_fn = can_be_synced_fn
+
         self._determine_joins()
+        assert self.primaryjoin is not None
+
         self._sanitize_joins()
         self._annotate_fks()
         self._annotate_remote()
@@ -1968,7 +2231,7 @@ class JoinCondition:
         self._check_remote_side()
         self._log_joins()
 
-    def _log_joins(self):
+    def _log_joins(self) -> None:
         if self.prop is None:
             return
         log = self.prop.logger
@@ -2008,7 +2271,7 @@ class JoinCondition:
         )
         log.info("%s relationship direction %s", self.prop, self.direction)
 
-    def _sanitize_joins(self):
+    def _sanitize_joins(self) -> None:
         """remove the parententity annotation from our join conditions which
         can leak in here based on some declarative patterns and maybe others.
 
@@ -2026,7 +2289,7 @@ class JoinCondition:
                 self.secondaryjoin, values=("parententity", "proxy_key")
             )
 
-    def _determine_joins(self):
+    def _determine_joins(self) -> None:
         """Determine the 'primaryjoin' and 'secondaryjoin' attributes,
         if not passed to the constructor already.
 
@@ -2056,21 +2319,25 @@ class JoinCondition:
                         a_subset=self.child_local_selectable,
                         consider_as_foreign_keys=consider_as_foreign_keys,
                     )
-                if self.primaryjoin is None:
+                if self.primaryjoin_initial is None:
                     self.primaryjoin = join_condition(
                         self.parent_persist_selectable,
                         self.secondary,
                         a_subset=self.parent_local_selectable,
                         consider_as_foreign_keys=consider_as_foreign_keys,
                     )
+                else:
+                    self.primaryjoin = self.primaryjoin_initial
             else:
-                if self.primaryjoin is None:
+                if self.primaryjoin_initial is None:
                     self.primaryjoin = join_condition(
                         self.parent_persist_selectable,
                         self.child_persist_selectable,
                         a_subset=self.parent_local_selectable,
                         consider_as_foreign_keys=consider_as_foreign_keys,
                     )
+                else:
+                    self.primaryjoin = self.primaryjoin_initial
         except sa_exc.NoForeignKeysError as nfe:
             if self.secondary is not None:
                 raise sa_exc.NoForeignKeysError(
@@ -2118,15 +2385,16 @@ class JoinCondition:
                 ) from afe
 
     @property
-    def primaryjoin_minus_local(self):
+    def primaryjoin_minus_local(self) -> ColumnElement[bool]:
         return _deep_deannotate(self.primaryjoin, values=("local", "remote"))
 
     @property
-    def secondaryjoin_minus_local(self):
+    def secondaryjoin_minus_local(self) -> ColumnElement[bool]:
+        assert self.secondaryjoin is not None
         return _deep_deannotate(self.secondaryjoin, values=("local", "remote"))
 
     @util.memoized_property
-    def primaryjoin_reverse_remote(self):
+    def primaryjoin_reverse_remote(self) -> ColumnElement[bool]:
         """Return the primaryjoin condition suitable for the
         "reverse" direction.
 
@@ -2138,7 +2406,7 @@ class JoinCondition:
         """
         if self._has_remote_annotations:
 
-            def replace(element):
+            def replace(element: _CE, **kw: Any) -> Optional[_CE]:
                 if "remote" in element._annotations:
                     v = dict(element._annotations)
                     del v["remote"]
@@ -2150,6 +2418,8 @@ class JoinCondition:
                     v["remote"] = True
                     return element._with_annotations(v)
 
+                return None
+
             return visitors.replacement_traverse(self.primaryjoin, {}, replace)
         else:
             if self._has_foreign_annotations:
@@ -2160,7 +2430,7 @@ class JoinCondition:
             else:
                 return _deep_deannotate(self.primaryjoin)
 
-    def _has_annotation(self, clause, annotation):
+    def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool:
         for col in visitors.iterate(clause, {}):
             if annotation in col._annotations:
                 return True
@@ -2168,14 +2438,14 @@ class JoinCondition:
             return False
 
     @util.memoized_property
-    def _has_foreign_annotations(self):
+    def _has_foreign_annotations(self) -> bool:
         return self._has_annotation(self.primaryjoin, "foreign")
 
     @util.memoized_property
-    def _has_remote_annotations(self):
+    def _has_remote_annotations(self) -> bool:
         return self._has_annotation(self.primaryjoin, "remote")
 
-    def _annotate_fks(self):
+    def _annotate_fks(self) -> None:
         """Annotate the primaryjoin and secondaryjoin
         structures with 'foreign' annotations marking columns
         considered as foreign.
@@ -2189,10 +2459,11 @@ class JoinCondition:
         else:
             self._annotate_present_fks()
 
-    def _annotate_from_fk_list(self):
-        def check_fk(col):
-            if col in self.consider_as_foreign_keys:
-                return col._annotate({"foreign": True})
+    def _annotate_from_fk_list(self) -> None:
+        def check_fk(element: _CE, **kw: Any) -> Optional[_CE]:
+            if element in self.consider_as_foreign_keys:
+                return element._annotate({"foreign": True})
+            return None
 
         self.primaryjoin = visitors.replacement_traverse(
             self.primaryjoin, {}, check_fk
@@ -2202,13 +2473,15 @@ class JoinCondition:
                 self.secondaryjoin, {}, check_fk
             )
 
-    def _annotate_present_fks(self):
+    def _annotate_present_fks(self) -> None:
         if self.secondary is not None:
             secondarycols = util.column_set(self.secondary.c)
         else:
             secondarycols = set()
 
-        def is_foreign(a, b):
+        def is_foreign(
+            a: ColumnElement[Any], b: ColumnElement[Any]
+        ) -> Optional[ColumnElement[Any]]:
             if isinstance(a, schema.Column) and isinstance(b, schema.Column):
                 if a.references(b):
                     return a
@@ -2221,7 +2494,9 @@ class JoinCondition:
                 elif b in secondarycols and a not in secondarycols:
                     return b
 
-        def visit_binary(binary):
+            return None
+
+        def visit_binary(binary: BinaryExpression[Any]) -> None:
             if not isinstance(
                 binary.left, sql.ColumnElement
             ) or not isinstance(binary.right, sql.ColumnElement):
@@ -2248,16 +2523,17 @@ class JoinCondition:
                 self.secondaryjoin, {}, {"binary": visit_binary}
             )
 
-    def _refers_to_parent_table(self):
+    def _refers_to_parent_table(self) -> bool:
         """Return True if the join condition contains column
         comparisons where both columns are in both tables.
 
         """
         pt = self.parent_persist_selectable
         mt = self.child_persist_selectable
-        result = [False]
+        result = False
 
-        def visit_binary(binary):
+        def visit_binary(binary: BinaryExpression[Any]) -> None:
+            nonlocal result
             c, f = binary.left, binary.right
             if (
                 isinstance(c, expression.ColumnClause)
@@ -2267,19 +2543,19 @@ class JoinCondition:
                 and mt.is_derived_from(c.table)
                 and mt.is_derived_from(f.table)
             ):
-                result[0] = True
+                result = True
 
         visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary})
-        return result[0]
+        return result
 
-    def _tables_overlap(self):
+    def _tables_overlap(self) -> bool:
         """Return True if parent/child tables have some overlap."""
 
         return selectables_overlap(
             self.parent_persist_selectable, self.child_persist_selectable
         )
 
-    def _annotate_remote(self):
+    def _annotate_remote(self) -> None:
         """Annotate the primaryjoin and secondaryjoin
         structures with 'remote' annotations marking columns
         considered as part of the 'remote' side.
@@ -2301,30 +2577,38 @@ class JoinCondition:
         else:
             self._annotate_remote_distinct_selectables()
 
-    def _annotate_remote_secondary(self):
+    def _annotate_remote_secondary(self) -> None:
         """annotate 'remote' in primaryjoin, secondaryjoin
         when 'secondary' is present.
 
         """
 
-        def repl(element):
-            if self.secondary.c.contains_column(element):
+        assert self.secondary is not None
+        fixed_secondary = self.secondary
+
+        def repl(element: _CE, **kw: Any) -> Optional[_CE]:
+            if fixed_secondary.c.contains_column(element):
                 return element._annotate({"remote": True})
+            return None
 
         self.primaryjoin = visitors.replacement_traverse(
             self.primaryjoin, {}, repl
         )
+
+        assert self.secondaryjoin is not None
         self.secondaryjoin = visitors.replacement_traverse(
             self.secondaryjoin, {}, repl
         )
 
-    def _annotate_selfref(self, fn, remote_side_given):
+    def _annotate_selfref(
+        self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool
+    ) -> None:
         """annotate 'remote' in primaryjoin, secondaryjoin
         when the relationship is detected as self-referential.
 
         """
 
-        def visit_binary(binary):
+        def visit_binary(binary: BinaryExpression[Any]) -> None:
             equated = binary.left.compare(binary.right)
             if isinstance(binary.left, expression.ColumnClause) and isinstance(
                 binary.right, expression.ColumnClause
@@ -2341,7 +2625,7 @@ class JoinCondition:
             self.primaryjoin, {}, {"binary": visit_binary}
         )
 
-    def _annotate_remote_from_args(self):
+    def _annotate_remote_from_args(self) -> None:
         """annotate 'remote' in primaryjoin, secondaryjoin
         when the 'remote_side' or '_local_remote_pairs'
         arguments are used.
@@ -2363,17 +2647,18 @@ class JoinCondition:
             self._annotate_selfref(lambda col: col in remote_side, True)
         else:
 
-            def repl(element):
+            def repl(element: _CE, **kw: Any) -> Optional[_CE]:
                 # use set() to avoid generating ``__eq__()`` expressions
                 # against each element
                 if element in set(remote_side):
                     return element._annotate({"remote": True})
+                return None
 
             self.primaryjoin = visitors.replacement_traverse(
                 self.primaryjoin, {}, repl
             )
 
-    def _annotate_remote_with_overlap(self):
+    def _annotate_remote_with_overlap(self) -> None:
         """annotate 'remote' in primaryjoin, secondaryjoin
         when the parent/child tables have some set of
         tables in common, though is not a fully self-referential
@@ -2381,7 +2666,7 @@ class JoinCondition:
 
         """
 
-        def visit_binary(binary):
+        def visit_binary(binary: BinaryExpression[Any]) -> None:
             binary.left, binary.right = proc_left_right(
                 binary.left, binary.right
             )
@@ -2393,7 +2678,9 @@ class JoinCondition:
             self.prop is not None and self.prop.mapper is not self.prop.parent
         )
 
-        def proc_left_right(left, right):
+        def proc_left_right(
+            left: ColumnElement[Any], right: ColumnElement[Any]
+        ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]:
             if isinstance(left, expression.ColumnClause) and isinstance(
                 right, expression.ColumnClause
             ):
@@ -2420,32 +2707,33 @@ class JoinCondition:
             self.primaryjoin, {}, {"binary": visit_binary}
         )
 
-    def _annotate_remote_distinct_selectables(self):
+    def _annotate_remote_distinct_selectables(self) -> None:
         """annotate 'remote' in primaryjoin, secondaryjoin
         when the parent/child tables are entirely
         separate.
 
         """
 
-        def repl(element):
+        def repl(element: _CE, **kw: Any) -> Optional[_CE]:
             if self.child_persist_selectable.c.contains_column(element) and (
                 not self.parent_local_selectable.c.contains_column(element)
                 or self.child_local_selectable.c.contains_column(element)
             ):
                 return element._annotate({"remote": True})
+            return None
 
         self.primaryjoin = visitors.replacement_traverse(
             self.primaryjoin, {}, repl
         )
 
-    def _warn_non_column_elements(self):
+    def _warn_non_column_elements(self) -> None:
         util.warn(
             "Non-simple column elements in primary "
             "join condition for property %s - consider using "
             "remote() annotations to mark the remote side." % self.prop
         )
 
-    def _annotate_local(self):
+    def _annotate_local(self) -> None:
         """Annotate the primaryjoin and secondaryjoin
         structures with 'local' annotations.
 
@@ -2466,29 +2754,31 @@ class JoinCondition:
         else:
             local_side = util.column_set(self.parent_persist_selectable.c)
 
-        def locals_(elem):
-            if "remote" not in elem._annotations and elem in local_side:
-                return elem._annotate({"local": True})
+        def locals_(element: _CE, **kw: Any) -> Optional[_CE]:
+            if "remote" not in element._annotations and element in local_side:
+                return element._annotate({"local": True})
+            return None
 
         self.primaryjoin = visitors.replacement_traverse(
             self.primaryjoin, {}, locals_
         )
 
-    def _annotate_parentmapper(self):
+    def _annotate_parentmapper(self) -> None:
         if self.prop is None:
             return
 
-        def parentmappers_(elem):
-            if "remote" in elem._annotations:
-                return elem._annotate({"parentmapper": self.prop.mapper})
-            elif "local" in elem._annotations:
-                return elem._annotate({"parentmapper": self.prop.parent})
+        def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]:
+            if "remote" in element._annotations:
+                return element._annotate({"parentmapper": self.prop.mapper})
+            elif "local" in element._annotations:
+                return element._annotate({"parentmapper": self.prop.parent})
+            return None
 
         self.primaryjoin = visitors.replacement_traverse(
             self.primaryjoin, {}, parentmappers_
         )
 
-    def _check_remote_side(self):
+    def _check_remote_side(self) -> None:
         if not self.local_remote_pairs:
             raise sa_exc.ArgumentError(
                 "Relationship %s could "
@@ -2501,7 +2791,9 @@ class JoinCondition:
                 "the relationship." % (self.prop,)
             )
 
-    def _check_foreign_cols(self, join_condition, primary):
+    def _check_foreign_cols(
+        self, join_condition: ColumnElement[bool], primary: bool
+    ) -> None:
         """Check the foreign key columns collected and emit error
         messages."""
 
@@ -2567,7 +2859,7 @@ class JoinCondition:
             )
             raise sa_exc.ArgumentError(err)
 
-    def _determine_direction(self):
+    def _determine_direction(self) -> None:
         """Determine if this relationship is one to many, many to one,
         many to many.
 
@@ -2651,7 +2943,9 @@ class JoinCondition:
                     "nor the child's mapped tables" % self.prop
                 )
 
-    def _deannotate_pairs(self, collection):
+    def _deannotate_pairs(
+        self, collection: _ColumnPairIterable
+    ) -> _MutableColumnPairs:
         """provide deannotation for the various lists of
         pairs, so that using them in hashes doesn't incur
         high-overhead __eq__() comparisons against
@@ -2660,13 +2954,22 @@ class JoinCondition:
         """
         return [(x._deannotate(), y._deannotate()) for x, y in collection]
 
-    def _setup_pairs(self):
-        sync_pairs = []
-        lrp = util.OrderedSet([])
-        secondary_sync_pairs = []
-
-        def go(joincond, collection):
-            def visit_binary(binary, left, right):
+    def _setup_pairs(self) -> None:
+        sync_pairs: _MutableColumnPairs = []
+        lrp: util.OrderedSet[
+            Tuple[ColumnElement[Any], ColumnElement[Any]]
+        ] = util.OrderedSet([])
+        secondary_sync_pairs: _MutableColumnPairs = []
+
+        def go(
+            joincond: ColumnElement[bool],
+            collection: _MutableColumnPairs,
+        ) -> None:
+            def visit_binary(
+                binary: BinaryExpression[Any],
+                left: ColumnElement[Any],
+                right: ColumnElement[Any],
+            ) -> None:
                 if (
                     "remote" in right._annotations
                     and "remote" not in left._annotations
@@ -2703,9 +3006,12 @@ class JoinCondition:
             secondary_sync_pairs
         )
 
-    _track_overlapping_sync_targets = weakref.WeakKeyDictionary()
+    _track_overlapping_sync_targets: weakref.WeakKeyDictionary[
+        ColumnElement[Any],
+        weakref.WeakKeyDictionary[Relationship[Any], ColumnElement[Any]],
+    ] = weakref.WeakKeyDictionary()
 
-    def _warn_for_conflicting_sync_targets(self):
+    def _warn_for_conflicting_sync_targets(self) -> None:
         if not self.support_sync:
             return
 
@@ -2793,18 +3099,20 @@ class JoinCondition:
                 self._track_overlapping_sync_targets[to_][self.prop] = from_
 
     @util.memoized_property
-    def remote_columns(self):
+    def remote_columns(self) -> Set[ColumnElement[Any]]:
         return self._gather_join_annotations("remote")
 
     @util.memoized_property
-    def local_columns(self):
+    def local_columns(self) -> Set[ColumnElement[Any]]:
         return self._gather_join_annotations("local")
 
     @util.memoized_property
-    def foreign_key_columns(self):
+    def foreign_key_columns(self) -> Set[ColumnElement[Any]]:
         return self._gather_join_annotations("foreign")
 
-    def _gather_join_annotations(self, annotation):
+    def _gather_join_annotations(
+        self, annotation: str
+    ) -> Set[ColumnElement[Any]]:
         s = set(
             self._gather_columns_with_annotation(self.primaryjoin, annotation)
         )
@@ -2816,24 +3124,32 @@ class JoinCondition:
             )
         return {x._deannotate() for x in s}
 
-    def _gather_columns_with_annotation(self, clause, *annotation):
-        annotation = set(annotation)
+    def _gather_columns_with_annotation(
+        self, clause: ColumnElement[Any], *annotation: Iterable[str]
+    ) -> Set[ColumnElement[Any]]:
+        annotation_set = set(annotation)
         return set(
             [
-                col
+                cast(ColumnElement[Any], col)
                 for col in visitors.iterate(clause, {})
-                if annotation.issubset(col._annotations)
+                if annotation_set.issubset(col._annotations)
             ]
         )
 
     def join_targets(
         self,
-        source_selectable,
-        dest_selectable,
-        aliased,
-        single_crit=None,
-        extra_criteria=(),
-    ):
+        source_selectable: Optional[FromClause],
+        dest_selectable: FromClause,
+        aliased: bool,
+        single_crit: Optional[ColumnElement[bool]] = None,
+        extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+    ) -> Tuple[
+        ColumnElement[bool],
+        Optional[ColumnElement[bool]],
+        Optional[FromClause],
+        Optional[ClauseAdapter],
+        FromClause,
+    ]:
         """Given a source and destination selectable, create a
         join between them.
 
@@ -2923,9 +3239,15 @@ class JoinCondition:
             dest_selectable,
         )
 
-    def create_lazy_clause(self, reverse_direction=False):
-        binds = util.column_dict()
-        equated_columns = util.column_dict()
+    def create_lazy_clause(
+        self, reverse_direction: bool = False
+    ) -> Tuple[
+        ColumnElement[bool],
+        Dict[str, ColumnElement[Any]],
+        Dict[ColumnElement[Any], ColumnElement[Any]],
+    ]:
+        binds: Dict[ColumnElement[Any], BindParameter[Any]] = {}
+        equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {}
 
         has_secondary = self.secondaryjoin is not None
 
@@ -2941,21 +3263,23 @@ class JoinCondition:
             for l, r in self.local_remote_pairs:
                 equated_columns[l] = r
 
-        def col_to_bind(col):
+        def col_to_bind(
+            element: ColumnElement[Any], **kw: Any
+        ) -> Optional[BindParameter[Any]]:
 
             if (
-                (not reverse_direction and "local" in col._annotations)
+                (not reverse_direction and "local" in element._annotations)
                 or reverse_direction
                 and (
-                    (has_secondary and col in lookup)
-                    or (not has_secondary and "remote" in col._annotations)
+                    (has_secondary and element in lookup)
+                    or (not has_secondary and "remote" in element._annotations)
                 )
             ):
-                if col not in binds:
-                    binds[col] = sql.bindparam(
-                        None, None, type_=col.type, unique=True
+                if element not in binds:
+                    binds[element] = sql.bindparam(
+                        None, None, type_=element.type, unique=True
                     )
-                return binds[col]
+                return binds[element]
             return None
 
         lazywhere = self.primaryjoin
@@ -2982,8 +3306,8 @@ class _ColInAnnotations:
 
     __slots__ = ("name",)
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         self.name = name
 
-    def __call__(self, c):
+    def __call__(self, c: ClauseElement) -> bool:
         return self.name in c._annotations
index b5491248b81e8ffde45c82c7b0140b237f73cec1..d72e78c9e69be26a0aabde5d7a23766b59d36402 100644 (file)
@@ -118,6 +118,7 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _T7
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
+    from ..sql.base import ExecutableOption
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import TypedReturnsRows
@@ -765,7 +766,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
         self.session.dispatch.after_transaction_create(self.session, self)
 
     def _raise_for_prerequisite_state(
-        self, operation_name: str, state: SessionTransactionState
+        self, operation_name: str, state: _StateChangeState
     ) -> NoReturn:
         if state is SessionTransactionState.DEACTIVE:
             if self._rollback_exception:
@@ -3183,7 +3184,7 @@ class Session(_SessionClassMethods, EventTarget):
         primary_key_identity: _PKIdentityArgument,
         db_load_fn: Callable[..., _O],
         *,
-        options: Optional[Sequence[ORMOption]] = None,
+        options: Optional[Sequence[ExecutableOption]] = None,
         populate_existing: bool = False,
         with_for_update: Optional[ForUpdateArg] = None,
         identity_token: Optional[Any] = None,
@@ -3377,7 +3378,7 @@ class Session(_SessionClassMethods, EventTarget):
         *,
         options: Optional[Sequence[ORMOption]] = None,
         load: bool,
-        _recursive: Dict[InstanceState[Any], object],
+        _recursive: Dict[Any, object],
         _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
     ) -> _O:
         mapper: Mapper[_O] = _state_mapper(state)
index cb8b1f4aad9b66b1e74271ff4a93ef116e5f4eb4..af9f4870662931423e7dd581f2764d2f2f1c4735 100644 (file)
@@ -82,6 +82,22 @@ class _InstanceDictProto(Protocol):
         ...
 
 
+class _InstallLoaderCallableProto(Protocol[_O]):
+    """used at result loading time to install a _LoaderCallable callable
+    upon a specific InstanceState, which will be used to populate an
+    attribute when that attribute is accessed.
+
+    Concrete examples are per-instance deferred column loaders and
+    relationship lazy loaders.
+
+    """
+
+    def __call__(
+        self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
+    ) -> None:
+        ...
+
+
 @inspection._self_inspects
 class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
     """tracks state information at the instance level.
@@ -658,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
     @classmethod
     def _instance_level_callable_processor(
         cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
-    ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]:
+    ) -> _InstallLoaderCallableProto[_O]:
         impl = manager[key].impl
         if is_collection_impl(impl):
             fixed_impl = impl
index b7bf96558534e5aa7ad9923e6aab76b6bdeed734..764b5dfa6bbf3ed3a6c8a96beef9d18554509de1 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+
 """State tracking utilities used by :class:`_orm.Session`.
 
 """
@@ -14,6 +15,9 @@ import contextlib
 from enum import Enum
 from typing import Any
 from typing import Callable
+from typing import cast
+from typing import Iterator
+from typing import NoReturn
 from typing import Optional
 from typing import Tuple
 from typing import TypeVar
@@ -48,9 +52,11 @@ class _StateChange:
 
     _next_state: _StateChangeState = _StateChangeStates.ANY
     _state: _StateChangeState = _StateChangeStates.NO_CHANGE
-    _current_fn: Optional[Callable] = None
+    _current_fn: Optional[Callable[..., Any]] = None
 
-    def _raise_for_prerequisite_state(self, operation_name, state):
+    def _raise_for_prerequisite_state(
+        self, operation_name: str, state: _StateChangeState
+    ) -> NoReturn:
         raise sa_exc.IllegalStateChangeError(
             f"Can't run operation '{operation_name}()' when Session "
             f"is in state {state!r}"
@@ -80,16 +86,19 @@ class _StateChange:
             prerequisite_states is not _StateChangeStates.ANY
         )
 
+        prerequisite_state_collection = cast(
+            "Tuple[_StateChangeState, ...]", prerequisite_states
+        )
         expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE
 
         @util.decorator
-        def _go(fn, self, *arg, **kw):
+        def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
 
             current_state = self._state
 
             if (
                 has_prerequisite_states
-                and current_state not in prerequisite_states
+                and current_state not in prerequisite_state_collection
             ):
                 self._raise_for_prerequisite_state(fn.__name__, current_state)
 
@@ -159,7 +168,7 @@ class _StateChange:
         return _go
 
     @contextlib.contextmanager
-    def _expect_state(self, expected: _StateChangeState):
+    def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
         """called within a method that changes states.
 
         method must also use the ``@declare_states()`` decorator.
index 0ba22e7a7cbf9ae9fc1eba1c2cb1cc1be84e6de1..5dc80e4f285d069f55ae508c4b32ef86f3b304d2 100644 (file)
@@ -14,6 +14,10 @@ from __future__ import annotations
 
 import collections
 import itertools
+from typing import Any
+from typing import Dict
+from typing import Tuple
+from typing import TYPE_CHECKING
 
 from . import attributes
 from . import exc as orm_exc
@@ -28,7 +32,9 @@ from . import util as orm_util
 from .base import _DEFER_FOR_STATE
 from .base import _RAISE_FOR_STATE
 from .base import _SET_DEFERRED_EXPIRED
+from .base import LoaderCallableStatus
 from .base import PASSIVE_OFF
+from .base import PassiveFlag
 from .context import _column_descriptions
 from .context import ORMCompileState
 from .context import ORMSelectCompileState
@@ -50,6 +56,10 @@ from ..sql import visitors
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..sql.selectable import Select
 
+if TYPE_CHECKING:
+    from .relationships import Relationship
+    from ..sql.elements import ColumnElement
+
 
 def _register_attribute(
     prop,
@@ -486,10 +496,10 @@ class DeferredColumnLoader(LoaderStrategy):
 
     def _load_for_state(self, state, passive):
         if not state.key:
-            return attributes.ATTR_EMPTY
+            return LoaderCallableStatus.ATTR_EMPTY
 
-        if not passive & attributes.SQL_OK:
-            return attributes.PASSIVE_NO_RESULT
+        if not passive & PassiveFlag.SQL_OK:
+            return LoaderCallableStatus.PASSIVE_NO_RESULT
 
         localparent = state.manager.mapper
 
@@ -522,7 +532,7 @@ class DeferredColumnLoader(LoaderStrategy):
             state.mapper, state, set(group), PASSIVE_OFF
         )
 
-        return attributes.ATTR_WAS_SET
+        return LoaderCallableStatus.ATTR_WAS_SET
 
     def _invoke_raise_load(self, state, passive, lazy):
         raise sa_exc.InvalidRequestError(
@@ -626,7 +636,9 @@ class NoLoader(AbstractRelationshipLoader):
 @relationships.Relationship.strategy_for(lazy="raise")
 @relationships.Relationship.strategy_for(lazy="raise_on_sql")
 @relationships.Relationship.strategy_for(lazy="baked_select")
-class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
+class LazyLoader(
+    AbstractRelationshipLoader, util.MemoizedSlots, log.Identified
+):
     """Provide loading behavior for a :class:`.Relationship`
     with "lazy=True", that is loads when first accessed.
 
@@ -648,7 +660,16 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         "_raise_on_sql",
     )
 
-    def __init__(self, parent, strategy_key):
+    _lazywhere: ColumnElement[bool]
+    _bind_to_col: Dict[str, ColumnElement[Any]]
+    _rev_lazywhere: ColumnElement[bool]
+    _rev_bind_to_col: Dict[str, ColumnElement[Any]]
+
+    parent_property: Relationship[Any]
+
+    def __init__(
+        self, parent: Relationship[Any], strategy_key: Tuple[Any, ...]
+    ):
         super(LazyLoader, self).__init__(parent, strategy_key)
         self._raise_always = self.strategy_opts["lazy"] == "raise"
         self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql"
@@ -786,13 +807,13 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         o = state.obj()  # strong ref
         dict_ = attributes.instance_dict(o)
 
-        if passive & attributes.INIT_OK:
-            passive ^= attributes.INIT_OK
+        if passive & PassiveFlag.INIT_OK:
+            passive ^= PassiveFlag.INIT_OK
 
         params = {}
         for key, ident, value in param_keys:
             if ident is not None:
-                if passive and passive & attributes.LOAD_AGAINST_COMMITTED:
+                if passive and passive & PassiveFlag.LOAD_AGAINST_COMMITTED:
                     value = mapper._get_committed_state_attr_by_column(
                         state, dict_, ident, passive
                     )
@@ -818,23 +839,23 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
             )
             or not state.session_id
         ):
-            return attributes.ATTR_EMPTY
+            return LoaderCallableStatus.ATTR_EMPTY
 
         pending = not state.key
         primary_key_identity = None
 
         use_get = self.use_get and (not loadopt or not loadopt._extra_criteria)
 
-        if (not passive & attributes.SQL_OK and not use_get) or (
+        if (not passive & PassiveFlag.SQL_OK and not use_get) or (
             not passive & attributes.NON_PERSISTENT_OK and pending
         ):
-            return attributes.PASSIVE_NO_RESULT
+            return LoaderCallableStatus.PASSIVE_NO_RESULT
 
         if (
             # we were given lazy="raise"
             self._raise_always
             # the no_raise history-related flag was not passed
-            and not passive & attributes.NO_RAISE
+            and not passive & PassiveFlag.NO_RAISE
             and (
                 # if we are use_get and related_object_ok is disabled,
                 # which means we are at most looking in the identity map
@@ -842,7 +863,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
                 # PASSIVE_NO_RESULT, don't raise.  This is also a
                 # history-related flag
                 not use_get
-                or passive & attributes.RELATED_OBJECT_OK
+                or passive & PassiveFlag.RELATED_OBJECT_OK
             )
         ):
 
@@ -850,8 +871,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
 
         session = _state_session(state)
         if not session:
-            if passive & attributes.NO_RAISE:
-                return attributes.PASSIVE_NO_RESULT
+            if passive & PassiveFlag.NO_RAISE:
+                return LoaderCallableStatus.PASSIVE_NO_RESULT
 
             raise orm_exc.DetachedInstanceError(
                 "Parent instance %s is not bound to a Session; "
@@ -865,19 +886,19 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
             primary_key_identity = self._get_ident_for_use_get(
                 session, state, passive
             )
-            if attributes.PASSIVE_NO_RESULT in primary_key_identity:
-                return attributes.PASSIVE_NO_RESULT
-            elif attributes.NEVER_SET in primary_key_identity:
-                return attributes.NEVER_SET
+            if LoaderCallableStatus.PASSIVE_NO_RESULT in primary_key_identity:
+                return LoaderCallableStatus.PASSIVE_NO_RESULT
+            elif LoaderCallableStatus.NEVER_SET in primary_key_identity:
+                return LoaderCallableStatus.NEVER_SET
 
             if _none_set.issuperset(primary_key_identity):
                 return None
 
             if (
                 self.key in state.dict
-                and not passive & attributes.DEFERRED_HISTORY_LOAD
+                and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD
             ):
-                return attributes.ATTR_WAS_SET
+                return LoaderCallableStatus.ATTR_WAS_SET
 
             # look for this identity in the identity map.  Delegate to the
             # Query class in use, as it may have special rules for how it
@@ -892,15 +913,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
             )
 
             if instance is not None:
-                if instance is attributes.PASSIVE_CLASS_MISMATCH:
+                if instance is LoaderCallableStatus.PASSIVE_CLASS_MISMATCH:
                     return None
                 else:
                     return instance
             elif (
-                not passive & attributes.SQL_OK
-                or not passive & attributes.RELATED_OBJECT_OK
+                not passive & PassiveFlag.SQL_OK
+                or not passive & PassiveFlag.RELATED_OBJECT_OK
             ):
-                return attributes.PASSIVE_NO_RESULT
+                return LoaderCallableStatus.PASSIVE_NO_RESULT
 
         return self._emit_lazyload(
             session,
@@ -914,7 +935,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
     def _get_ident_for_use_get(self, session, state, passive):
         instance_mapper = state.manager.mapper
 
-        if passive & attributes.LOAD_AGAINST_COMMITTED:
+        if passive & PassiveFlag.LOAD_AGAINST_COMMITTED:
             get_attr = instance_mapper._get_committed_state_attr_by_column
         else:
             get_attr = instance_mapper._get_state_attr_by_column
@@ -985,7 +1006,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         stmt._compile_options += {"_current_path": effective_path}
 
         if use_get:
-            if self._raise_on_sql and not passive & attributes.NO_RAISE:
+            if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE:
                 self._invoke_raise_load(state, passive, "raise_on_sql")
 
             return loading.load_on_pk_identity(
@@ -1022,9 +1043,9 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
 
         if (
             self.key in state.dict
-            and not passive & attributes.DEFERRED_HISTORY_LOAD
+            and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD
         ):
-            return attributes.ATTR_WAS_SET
+            return LoaderCallableStatus.ATTR_WAS_SET
 
         if pending:
             if util.has_intersection(orm_util._none_set, params.values()):
@@ -1033,7 +1054,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         elif util.has_intersection(orm_util._never_set, params.values()):
             return None
 
-        if self._raise_on_sql and not passive & attributes.NO_RAISE:
+        if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE:
             self._invoke_raise_load(state, passive, "raise_on_sql")
 
         stmt._where_criteria = (lazy_clause,)
@@ -1246,9 +1267,9 @@ class ImmediateLoader(PostLoader):
             # "use get" load.   the "_RELATED" part means it may return
             # instance even if its expired, since this is a mutually-recursive
             # load operation.
-            flags = attributes.PASSIVE_NO_FETCH_RELATED | attributes.NO_RAISE
+            flags = attributes.PASSIVE_NO_FETCH_RELATED | PassiveFlag.NO_RAISE
         else:
-            flags = attributes.PASSIVE_OFF | attributes.NO_RAISE
+            flags = attributes.PASSIVE_OFF | PassiveFlag.NO_RAISE
 
         populators["delayed"].append((self.key, load_immediate))
 
@@ -2840,7 +2861,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
                 # if the loaded parent objects do not have the foreign key
                 # to the related item loaded, then degrade into the joined
                 # version of selectinload
-                if attributes.PASSIVE_NO_RESULT in related_ident:
+                if LoaderCallableStatus.PASSIVE_NO_RESULT in related_ident:
                     query_info = self._fallback_query_info
                     break
 
index 63679dd27528e891f447f4957fc9f6b9604c77cf..7aed6dd7bbfd6d587caf98dd2317ee812da270f6 100644 (file)
@@ -3,6 +3,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: allow-untyped-defs, allow-untyped-calls
 
 """
 
@@ -12,18 +13,30 @@ from __future__ import annotations
 
 import typing
 from typing import Any
+from typing import Callable
 from typing import cast
-from typing import Mapping
-from typing import NoReturn
+from typing import Dict
+from typing import Iterable
 from typing import Optional
+from typing import overload
+from typing import Sequence
 from typing import Tuple
+from typing import Type
+from typing import TypeVar
 from typing import Union
 
 from . import util as orm_util
+from ._typing import insp_is_aliased_class
+from ._typing import insp_is_attribute
+from ._typing import insp_is_mapper
+from ._typing import insp_is_mapper_property
+from .attributes import QueryableAttribute
 from .base import InspectionAttr
 from .interfaces import LoaderOption
 from .path_registry import _DEFAULT_TOKEN
 from .path_registry import _WILDCARD_TOKEN
+from .path_registry import AbstractEntityRegistry
+from .path_registry import path_is_property
 from .path_registry import PathRegistry
 from .path_registry import TokenRegistry
 from .util import _orm_full_deannotate
@@ -38,14 +51,37 @@ from ..sql import roles
 from ..sql import traversals
 from ..sql import visitors
 from ..sql.base import _generative
+from ..util.typing import Final
+from ..util.typing import Literal
 
-_RELATIONSHIP_TOKEN = "relationship"
-_COLUMN_TOKEN = "column"
+_RELATIONSHIP_TOKEN: Final[Literal["relationship"]] = "relationship"
+_COLUMN_TOKEN: Final[Literal["column"]] = "column"
+
+_FN = TypeVar("_FN", bound="Callable[..., Any]")
 
 if typing.TYPE_CHECKING:
+    from ._typing import _EntityType
+    from ._typing import _InternalEntityType
+    from .context import _MapperEntity
+    from .context import ORMCompileState
+    from .context import QueryContext
+    from .interfaces import _StrategyKey
+    from .interfaces import MapperProperty
     from .mapper import Mapper
+    from .path_registry import _PathRepresentation
+    from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _FromClauseArgument
+    from ..sql.cache_key import _CacheKeyTraversalType
+    from ..sql.cache_key import CacheKey
+
+Self_AbstractLoad = TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
+
+_AttrType = Union[str, "QueryableAttribute[Any]"]
 
-Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
+_WildcardKeyType = Literal["relationship", "column"]
+_StrategySpec = Dict[str, Any]
+_OptsType = Dict[str, Any]
+_AttrGroupType = Tuple[_AttrType, ...]
 
 
 class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
@@ -54,7 +90,12 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
     _is_strategy_option = True
     propagate_to_loaders: bool
 
-    def contains_eager(self, attr, alias=None, _is_chain=False):
+    def contains_eager(
+        self: Self_AbstractLoad,
+        attr: _AttrType,
+        alias: Optional[_FromClauseArgument] = None,
+        _is_chain: bool = False,
+    ) -> Self_AbstractLoad:
         r"""Indicate that the given attribute should be eagerly loaded from
         columns stated manually in the query.
 
@@ -94,9 +135,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         """
         if alias is not None:
             if not isinstance(alias, str):
-                info = inspect(alias)
-                alias = info.selectable
-
+                coerced_alias = coercions.expect(roles.FromClauseRole, alias)
             else:
                 util.warn_deprecated(
                     "Passing a string name for the 'alias' argument to "
@@ -105,21 +144,28 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
                     "sqlalchemy.orm.aliased() construct.",
                     version="1.4",
                 )
+                coerced_alias = alias
 
         elif getattr(attr, "_of_type", None):
-            ot = inspect(attr._of_type)
-            alias = ot.selectable
+            assert isinstance(attr, QueryableAttribute)
+            ot: Optional[_InternalEntityType[Any]] = inspect(attr._of_type)
+            assert ot is not None
+            coerced_alias = ot.selectable
+        else:
+            coerced_alias = None
 
         cloned = self._set_relationship_strategy(
             attr,
             {"lazy": "joined"},
             propagate_to_loaders=False,
-            opts={"eager_from_alias": alias},
+            opts={"eager_from_alias": coerced_alias},
             _reconcile_to_other=True if _is_chain else None,
         )
         return cloned
 
-    def load_only(self, *attrs):
+    def load_only(
+        self: Self_AbstractLoad, *attrs: _AttrType
+    ) -> Self_AbstractLoad:
         """Indicate that for a particular entity, only the given list
         of column-based attribute names should be loaded; all others will be
         deferred.
@@ -159,11 +205,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             {"deferred": False, "instrument": True},
         )
         cloned = cloned._set_column_strategy(
-            "*", {"deferred": True, "instrument": True}, {"undefer_pks": True}
+            ("*",),
+            {"deferred": True, "instrument": True},
+            {"undefer_pks": True},
         )
         return cloned
 
-    def joinedload(self, attr, innerjoin=None):
+    def joinedload(
+        self: Self_AbstractLoad,
+        attr: _AttrType,
+        innerjoin: Optional[bool] = None,
+    ) -> Self_AbstractLoad:
         """Indicate that the given attribute should be loaded using joined
         eager loading.
 
@@ -258,7 +310,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         )
         return loader
 
-    def subqueryload(self, attr):
+    def subqueryload(
+        self: Self_AbstractLoad, attr: _AttrType
+    ) -> Self_AbstractLoad:
         """Indicate that the given attribute should be loaded using
         subquery eager loading.
 
@@ -289,7 +343,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         """
         return self._set_relationship_strategy(attr, {"lazy": "subquery"})
 
-    def selectinload(self, attr):
+    def selectinload(
+        self: Self_AbstractLoad, attr: _AttrType
+    ) -> Self_AbstractLoad:
         """Indicate that the given attribute should be loaded using
         SELECT IN eager loading.
 
@@ -321,7 +377,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         """
         return self._set_relationship_strategy(attr, {"lazy": "selectin"})
 
-    def lazyload(self, attr):
+    def lazyload(
+        self: Self_AbstractLoad, attr: _AttrType
+    ) -> Self_AbstractLoad:
         """Indicate that the given attribute should be loaded using "lazy"
         loading.
 
@@ -337,7 +395,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         """
         return self._set_relationship_strategy(attr, {"lazy": "select"})
 
-    def immediateload(self, attr):
+    def immediateload(
+        self: Self_AbstractLoad, attr: _AttrType
+    ) -> Self_AbstractLoad:
         """Indicate that the given attribute should be loaded using
         an immediate load with a per-attribute SELECT statement.
 
@@ -361,7 +421,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         loader = self._set_relationship_strategy(attr, {"lazy": "immediate"})
         return loader
 
-    def noload(self, attr):
+    def noload(self: Self_AbstractLoad, attr: _AttrType) -> Self_AbstractLoad:
         """Indicate that the given relationship attribute should remain
         unloaded.
 
@@ -387,7 +447,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
         return self._set_relationship_strategy(attr, {"lazy": "noload"})
 
-    def raiseload(self, attr, sql_only=False):
+    def raiseload(
+        self: Self_AbstractLoad, attr: _AttrType, sql_only: bool = False
+    ) -> Self_AbstractLoad:
         """Indicate that the given attribute should raise an error if accessed.
 
         A relationship attribute configured with :func:`_orm.raiseload` will
@@ -428,7 +490,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             attr, {"lazy": "raise_on_sql" if sql_only else "raise"}
         )
 
-    def defaultload(self, attr):
+    def defaultload(
+        self: Self_AbstractLoad, attr: _AttrType
+    ) -> Self_AbstractLoad:
         """Indicate an attribute should load using its default loader style.
 
         This method is used to link to other loader options further into
@@ -463,7 +527,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         """
         return self._set_relationship_strategy(attr, None)
 
-    def defer(self, key, raiseload=False):
+    def defer(
+        self: Self_AbstractLoad, key: _AttrType, raiseload: bool = False
+    ) -> Self_AbstractLoad:
         r"""Indicate that the given column-oriented attribute should be
         deferred, e.g. not loaded until accessed.
 
@@ -524,7 +590,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             strategy["raiseload"] = True
         return self._set_column_strategy((key,), strategy)
 
-    def undefer(self, key):
+    def undefer(self: Self_AbstractLoad, key: _AttrType) -> Self_AbstractLoad:
         r"""Indicate that the given column-oriented attribute should be
         undeferred, e.g. specified within the SELECT statement of the entity
         as a whole.
@@ -538,7 +604,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         Examples::
 
             # undefer two columns
-            session.query(MyClass).options(undefer("col1"), undefer("col2"))
+            session.query(MyClass).options(
+                undefer(MyClass.col1), undefer(MyClass.col2)
+            )
 
             # undefer all columns specific to a single class using Load + *
             session.query(MyClass, MyOtherClass).options(
@@ -546,7 +614,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
             # undefer a column on a related object
             session.query(MyClass).options(
-                defaultload(MyClass.items).undefer('text'))
+                defaultload(MyClass.items).undefer(MyClass.text))
 
         :param key: Attribute to be undeferred.
 
@@ -563,7 +631,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             (key,), {"deferred": False, "instrument": True}
         )
 
-    def undefer_group(self, name):
+    def undefer_group(self: Self_AbstractLoad, name: str) -> Self_AbstractLoad:
         """Indicate that columns within the given deferred group name should be
         undeferred.
 
@@ -591,10 +659,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
         """
         return self._set_column_strategy(
-            _WILDCARD_TOKEN, None, {f"undefer_group_{name}": True}
+            (_WILDCARD_TOKEN,), None, {f"undefer_group_{name}": True}
         )
 
-    def with_expression(self, key, expression):
+    def with_expression(
+        self: Self_AbstractLoad,
+        key: _AttrType,
+        expression: _ColumnExpressionArgument[Any],
+    ) -> Self_AbstractLoad:
         r"""Apply an ad-hoc SQL expression to a "deferred expression"
         attribute.
 
@@ -626,15 +698,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
         """
 
-        expression = coercions.expect(
-            roles.LabeledColumnExprRole, _orm_full_deannotate(expression)
+        expression = _orm_full_deannotate(
+            coercions.expect(roles.LabeledColumnExprRole, expression)
         )
 
         return self._set_column_strategy(
             (key,), {"query_expression": True}, opts={"expression": expression}
         )
 
-    def selectin_polymorphic(self, classes):
+    def selectin_polymorphic(
+        self: Self_AbstractLoad, classes: Iterable[Type[Any]]
+    ) -> Self_AbstractLoad:
         """Indicate an eager load should take place for all attributes
         specific to a subclass.
 
@@ -659,25 +733,37 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         )
         return self
 
-    def _coerce_strat(self, strategy):
+    @overload
+    def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey:
+        ...
+
+    @overload
+    def _coerce_strat(self, strategy: Literal[None]) -> None:
+        ...
+
+    def _coerce_strat(
+        self, strategy: Optional[_StrategySpec]
+    ) -> Optional[_StrategyKey]:
         if strategy is not None:
-            strategy = tuple(sorted(strategy.items()))
-        return strategy
+            strategy_key = tuple(sorted(strategy.items()))
+        else:
+            strategy_key = None
+        return strategy_key
 
     @_generative
     def _set_relationship_strategy(
         self: Self_AbstractLoad,
-        attr,
-        strategy,
-        propagate_to_loaders=True,
-        opts=None,
-        _reconcile_to_other=None,
+        attr: _AttrType,
+        strategy: Optional[_StrategySpec],
+        propagate_to_loaders: bool = True,
+        opts: Optional[_OptsType] = None,
+        _reconcile_to_other: Optional[bool] = None,
     ) -> Self_AbstractLoad:
-        strategy = self._coerce_strat(strategy)
+        strategy_key = self._coerce_strat(strategy)
 
         self._clone_for_bind_strategy(
             (attr,),
-            strategy,
+            strategy_key,
             _RELATIONSHIP_TOKEN,
             opts=opts,
             propagate_to_loaders=propagate_to_loaders,
@@ -687,13 +773,16 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
     @_generative
     def _set_column_strategy(
-        self: Self_AbstractLoad, attrs, strategy, opts=None
+        self: Self_AbstractLoad,
+        attrs: Tuple[_AttrType, ...],
+        strategy: Optional[_StrategySpec],
+        opts: Optional[_OptsType] = None,
     ) -> Self_AbstractLoad:
-        strategy = self._coerce_strat(strategy)
+        strategy_key = self._coerce_strat(strategy)
 
         self._clone_for_bind_strategy(
             attrs,
-            strategy,
+            strategy_key,
             _COLUMN_TOKEN,
             opts=opts,
             attr_group=attrs,
@@ -702,12 +791,15 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
     @_generative
     def _set_generic_strategy(
-        self: Self_AbstractLoad, attrs, strategy, _reconcile_to_other=None
+        self: Self_AbstractLoad,
+        attrs: Tuple[_AttrType, ...],
+        strategy: _StrategySpec,
+        _reconcile_to_other: Optional[bool] = None,
     ) -> Self_AbstractLoad:
-        strategy = self._coerce_strat(strategy)
+        strategy_key = self._coerce_strat(strategy)
         self._clone_for_bind_strategy(
             attrs,
-            strategy,
+            strategy_key,
             None,
             propagate_to_loaders=True,
             reconcile_to_other=_reconcile_to_other,
@@ -716,14 +808,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
     @_generative
     def _set_class_strategy(
-        self: Self_AbstractLoad, strategy, opts
+        self: Self_AbstractLoad, strategy: _StrategySpec, opts: _OptsType
     ) -> Self_AbstractLoad:
-        strategy = self._coerce_strat(strategy)
+        strategy_key = self._coerce_strat(strategy)
 
-        self._clone_for_bind_strategy(None, strategy, None, opts=opts)
+        self._clone_for_bind_strategy(None, strategy_key, None, opts=opts)
         return self
 
-    def _apply_to_parent(self, parent):
+    def _apply_to_parent(self, parent: Load) -> None:
         """apply this :class:`_orm._AbstractLoad` object as a sub-option o
         a :class:`_orm.Load` object.
 
@@ -732,7 +824,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         """
         raise NotImplementedError()
 
-    def options(self: Self_AbstractLoad, *opts) -> NoReturn:
+    def options(
+        self: Self_AbstractLoad, *opts: _AbstractLoad
+    ) -> Self_AbstractLoad:
         r"""Apply a series of options as sub-options to this
         :class:`_orm._AbstractLoad` object.
 
@@ -742,20 +836,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         raise NotImplementedError()
 
     def _clone_for_bind_strategy(
-        self,
-        attrs,
-        strategy,
-        wildcard_key,
-        opts=None,
-        attr_group=None,
-        propagate_to_loaders=True,
-        reconcile_to_other=None,
-    ):
+        self: Self_AbstractLoad,
+        attrs: Optional[Tuple[_AttrType, ...]],
+        strategy: Optional[_StrategyKey],
+        wildcard_key: Optional[_WildcardKeyType],
+        opts: Optional[_OptsType] = None,
+        attr_group: Optional[_AttrGroupType] = None,
+        propagate_to_loaders: bool = True,
+        reconcile_to_other: Optional[bool] = None,
+    ) -> Self_AbstractLoad:
         raise NotImplementedError()
 
     def process_compile_state_replaced_entities(
-        self, compile_state, mapper_entities
-    ):
+        self,
+        compile_state: ORMCompileState,
+        mapper_entities: Sequence[_MapperEntity],
+    ) -> None:
         if not compile_state.compile_options._enable_eagerloads:
             return
 
@@ -768,7 +864,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             not bool(compile_state.current_path),
         )
 
-    def process_compile_state(self, compile_state):
+    def process_compile_state(self, compile_state: ORMCompileState) -> None:
         if not compile_state.compile_options._enable_eagerloads:
             return
 
@@ -779,12 +875,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             and not compile_state.compile_options._for_refresh_state,
         )
 
-    def _process(self, compile_state, mapper_entities, raiseerr):
+    def _process(
+        self,
+        compile_state: ORMCompileState,
+        mapper_entities: Sequence[_MapperEntity],
+        raiseerr: bool,
+    ) -> None:
         """implemented by subclasses"""
         raise NotImplementedError()
 
     @classmethod
-    def _chop_path(cls, to_chop, path, debug=False):
+    def _chop_path(
+        cls,
+        to_chop: _PathRepresentation,
+        path: PathRegistry,
+        debug: bool = False,
+    ) -> Optional[_PathRepresentation]:
         i = -1
 
         for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)):
@@ -793,7 +899,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
                     return to_chop
                 elif (
                     c_token != f"{_RELATIONSHIP_TOKEN}:{_WILDCARD_TOKEN}"
-                    and c_token != p_token.key
+                    and c_token != p_token.key  # type: ignore
                 ):
                     return None
 
@@ -801,9 +907,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
                 continue
             elif (
                 isinstance(c_token, InspectionAttr)
-                and c_token.is_mapper
+                and insp_is_mapper(c_token)
                 and (
-                    (p_token.is_mapper and c_token.isa(p_token))
+                    (insp_is_mapper(p_token) and c_token.isa(p_token))
                     or (
                         # a too-liberal check here to allow a path like
                         # A->A.bs->B->B.cs->C->C.ds, natural path, to chop
@@ -827,10 +933,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
                         # test_of_type.py->test_all_subq_query
                         #
                         i >= 2
-                        and p_token.is_aliased_class
+                        and insp_is_aliased_class(p_token)
                         and p_token._is_with_polymorphic
                         and c_token in p_token.with_polymorphic_mappers
-                        # and (breakpoint() or True)
                     )
                 )
             ):
@@ -841,7 +946,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         return to_chop[i + 1 :]
 
 
-SelfLoad = typing.TypeVar("SelfLoad", bound="Load")
+SelfLoad = TypeVar("SelfLoad", bound="Load")
 
 
 class Load(_AbstractLoad):
@@ -903,28 +1008,28 @@ class Load(_AbstractLoad):
     _cache_key_traversal = None
 
     path: PathRegistry
-    context: Tuple["_LoadElement", ...]
+    context: Tuple[_LoadElement, ...]
 
-    def __init__(self, entity):
-        insp = cast(Union["Mapper", AliasedInsp], inspect(entity))
+    def __init__(self, entity: _EntityType[Any]):
+        insp = cast("Union[Mapper[Any], AliasedInsp[Any]]", inspect(entity))
         insp._post_inspect
 
         self.path = insp._path_registry
         self.context = ()
         self.propagate_to_loaders = False
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"Load({self.path[0]})"
 
     @classmethod
-    def _construct_for_existing_path(cls, path):
+    def _construct_for_existing_path(cls, path: PathRegistry) -> Load:
         load = cls.__new__(cls)
         load.path = path
         load.context = ()
         load.propagate_to_loaders = False
         return load
 
-    def _adjust_for_extra_criteria(self, context):
+    def _adjust_for_extra_criteria(self, context: QueryContext) -> Load:
         """Apply the current bound parameters in a QueryContext to all
         occurrences "extra_criteria" stored within this ``Load`` object,
         returning a new instance of this ``Load`` object.
@@ -932,10 +1037,10 @@ class Load(_AbstractLoad):
         """
         orig_query = context.compile_state.select_statement
 
-        orig_cache_key = None
-        replacement_cache_key = None
+        orig_cache_key: Optional[CacheKey] = None
+        replacement_cache_key: Optional[CacheKey] = None
 
-        def process(opt):
+        def process(opt: _LoadElement) -> _LoadElement:
             if not opt._extra_criteria:
                 return opt
 
@@ -948,6 +1053,9 @@ class Load(_AbstractLoad):
                 orig_cache_key = orig_query._generate_cache_key()
                 replacement_cache_key = context.query._generate_cache_key()
 
+                assert orig_cache_key is not None
+                assert replacement_cache_key is not None
+
             opt._extra_criteria = tuple(
                 replacement_cache_key._apply_params_to_element(
                     orig_cache_key, crit
@@ -975,12 +1083,22 @@ class Load(_AbstractLoad):
         ezero = None
         for ent in mapper_entities:
             ezero = ent.entity_zero
-            if ezero and orm_util._entity_corresponds_to(ezero, path[0]):
+            if ezero and orm_util._entity_corresponds_to(
+                # technically this can be a token also, but this is
+                # safe to pass to _entity_corresponds_to()
+                ezero,
+                cast("_InternalEntityType[Any]", path[0]),
+            ):
                 return ezero
 
         return None
 
-    def _process(self, compile_state, mapper_entities, raiseerr):
+    def _process(
+        self,
+        compile_state: ORMCompileState,
+        mapper_entities: Sequence[_MapperEntity],
+        raiseerr: bool,
+    ) -> None:
 
         reconciled_lead_entity = self._reconcile_query_entities_with_us(
             mapper_entities, raiseerr
@@ -995,7 +1113,7 @@ class Load(_AbstractLoad):
                 raiseerr,
             )
 
-    def _apply_to_parent(self, parent):
+    def _apply_to_parent(self, parent: Load) -> None:
         """apply this :class:`_orm.Load` object as a sub-option of another
         :class:`_orm.Load` object.
 
@@ -1007,7 +1125,8 @@ class Load(_AbstractLoad):
         assert cloned.propagate_to_loaders == self.propagate_to_loaders
 
         if not orm_util._entity_corresponds_to_use_path_impl(
-            parent.path[-1], cloned.path[0]
+            cast("_InternalEntityType[Any]", parent.path[-1]),
+            cast("_InternalEntityType[Any]", cloned.path[0]),
         ):
             raise sa_exc.ArgumentError(
                 f'Attribute "{cloned.path[1]}" does not link '
@@ -1025,7 +1144,7 @@ class Load(_AbstractLoad):
             parent.context += cloned.context
 
     @_generative
-    def options(self: SelfLoad, *opts) -> SelfLoad:
+    def options(self: SelfLoad, *opts: _AbstractLoad) -> SelfLoad:
         r"""Apply a series of options as sub-options to this
         :class:`_orm.Load`
         object.
@@ -1062,38 +1181,36 @@ class Load(_AbstractLoad):
         return self
 
     def _clone_for_bind_strategy(
-        self,
-        attrs,
-        strategy,
-        wildcard_key,
-        opts=None,
-        attr_group=None,
-        propagate_to_loaders=True,
-        reconcile_to_other=None,
-    ) -> None:
+        self: SelfLoad,
+        attrs: Optional[Tuple[_AttrType, ...]],
+        strategy: Optional[_StrategyKey],
+        wildcard_key: Optional[_WildcardKeyType],
+        opts: Optional[_OptsType] = None,
+        attr_group: Optional[_AttrGroupType] = None,
+        propagate_to_loaders: bool = True,
+        reconcile_to_other: Optional[bool] = None,
+    ) -> SelfLoad:
         # for individual strategy that needs to propagate, set the whole
         # Load container to also propagate, so that it shows up in
         # InstanceState.load_options
         if propagate_to_loaders:
             self.propagate_to_loaders = True
 
-        if not self.path.has_entity:
-            if self.path.is_token:
+        if self.path.is_token:
+            raise sa_exc.ArgumentError(
+                "Wildcard token cannot be followed by another entity"
+            )
+
+        elif path_is_property(self.path):
+            # re-use the lookup which will raise a nicely formatted
+            # LoaderStrategyException
+            if strategy:
+                self.path.prop._strategy_lookup(self.path.prop, strategy[0])
+            else:
                 raise sa_exc.ArgumentError(
-                    "Wildcard token cannot be followed by another entity"
+                    f"Mapped attribute '{self.path.prop}' does not "
+                    "refer to a mapped entity"
                 )
-            else:
-                # re-use the lookup which will raise a nicely formatted
-                # LoaderStrategyException
-                if strategy:
-                    self.path.prop._strategy_lookup(
-                        self.path.prop, strategy[0]
-                    )
-                else:
-                    raise sa_exc.ArgumentError(
-                        f"Mapped attribute '{self.path.prop}' does not "
-                        "refer to a mapped entity"
-                    )
 
         if attrs is None:
             load_element = _ClassStrategyLoad.create(
@@ -1140,6 +1257,7 @@ class Load(_AbstractLoad):
                     if wildcard_key is _RELATIONSHIP_TOKEN:
                         self.path = load_element.path
                     self.context += (load_element,)
+        return self
 
     def __getstate__(self):
         d = self._shallow_to_dict()
@@ -1151,7 +1269,7 @@ class Load(_AbstractLoad):
         self._shallow_from_dict(state)
 
 
-SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
+SelfWildcardLoad = TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
 
 
 class _WildcardLoad(_AbstractLoad):
@@ -1167,14 +1285,14 @@ class _WildcardLoad(_AbstractLoad):
             visitors.ExtendedInternalTraversal.dp_string_multi_dict,
         ),
     ]
-    cache_key_traversal = None
+    cache_key_traversal: _CacheKeyTraversalType = None
 
     strategy: Optional[Tuple[Any, ...]]
-    local_opts: Mapping[str, Any]
+    local_opts: _OptsType
     path: Tuple[str, ...]
     propagate_to_loaders = False
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.path = ()
         self.strategy = None
         self.local_opts = util.EMPTY_DICT
@@ -1189,6 +1307,7 @@ class _WildcardLoad(_AbstractLoad):
         propagate_to_loaders=True,
         reconcile_to_other=None,
     ):
+        assert attrs is not None
         attr = attrs[0]
         assert (
             wildcard_key
@@ -1203,10 +1322,12 @@ class _WildcardLoad(_AbstractLoad):
         if opts:
             self.local_opts = util.immutabledict(opts)
 
-    def options(self: SelfWildcardLoad, *opts) -> SelfWildcardLoad:
+    def options(
+        self: SelfWildcardLoad, *opts: _AbstractLoad
+    ) -> SelfWildcardLoad:
         raise NotImplementedError("Star option does not support sub-options")
 
-    def _apply_to_parent(self, parent):
+    def _apply_to_parent(self, parent: Load) -> None:
         """apply this :class:`_orm._WildcardLoad` object as a sub-option of
         a :class:`_orm.Load` object.
 
@@ -1215,12 +1336,11 @@ class _WildcardLoad(_AbstractLoad):
         it may be used as the sub-option of a :class:`_orm.Load` object.
 
         """
-
         attr = self.path[0]
         if attr.endswith(_DEFAULT_TOKEN):
             attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}"
 
-        effective_path = parent.path.token(attr)
+        effective_path = cast(AbstractEntityRegistry, parent.path).token(attr)
 
         assert effective_path.is_token
 
@@ -1244,20 +1364,21 @@ class _WildcardLoad(_AbstractLoad):
         entities = [ent.entity_zero for ent in mapper_entities]
         current_path = compile_state.current_path
 
-        start_path = self.path
+        start_path: _PathRepresentation = self.path
 
         # TODO: chop_path already occurs in loader.process_compile_state()
         # so we will seek to simplify this
         if current_path:
-            start_path = self._chop_path(start_path, current_path)
-            if not start_path:
+            new_path = self._chop_path(start_path, current_path)
+            if not new_path:
                 return
+            start_path = new_path
 
         # start_path is a single-token tuple
         assert start_path and len(start_path) == 1
 
         token = start_path[0]
-
+        assert isinstance(token, str)
         entity = self._find_entity_basestring(entities, token, raiseerr)
 
         if not entity:
@@ -1270,6 +1391,7 @@ class _WildcardLoad(_AbstractLoad):
         # we just located, then go through the rest of our path
         # tokens and populate into the Load().
 
+        assert isinstance(token, str)
         loader = _TokenStrategyLoad.create(
             path_element._path_registry,
             token,
@@ -1291,7 +1413,12 @@ class _WildcardLoad(_AbstractLoad):
 
         return loader
 
-    def _find_entity_basestring(self, entities, token, raiseerr):
+    def _find_entity_basestring(
+        self,
+        entities: Iterable[_InternalEntityType[Any]],
+        token: str,
+        raiseerr: bool,
+    ) -> Optional[_InternalEntityType[Any]]:
         if token.endswith(f":{_WILDCARD_TOKEN}"):
             if len(list(entities)) != 1:
                 if raiseerr:
@@ -1324,11 +1451,11 @@ class _WildcardLoad(_AbstractLoad):
             else:
                 return None
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         d = self._shallow_to_dict()
         return d
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self._shallow_from_dict(state)
 
 
@@ -1372,38 +1499,38 @@ class _LoadElement(
     _extra_criteria: Tuple[Any, ...]
 
     _reconcile_to_other: Optional[bool]
-    strategy: Tuple[Any, ...]
+    strategy: Optional[_StrategyKey]
     path: PathRegistry
     propagate_to_loaders: bool
 
-    local_opts: Mapping[str, Any]
+    local_opts: util.immutabledict[str, Any]
 
     is_token_strategy: bool
     is_class_strategy: bool
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return id(self)
 
     def __eq__(self, other):
         return traversals.compare(self, other)
 
     @property
-    def is_opts_only(self):
+    def is_opts_only(self) -> bool:
         return bool(self.local_opts and self.strategy is None)
 
-    def _clone(self):
+    def _clone(self, **kw: Any) -> _LoadElement:
         cls = self.__class__
         s = cls.__new__(cls)
 
         self._shallow_copy_to(s)
         return s
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         d = self._shallow_to_dict()
         d["path"] = self.path.serialize()
         return d
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         state["path"] = PathRegistry.deserialize(state["path"])
         self._shallow_from_dict(state)
 
@@ -1437,8 +1564,8 @@ class _LoadElement(
             )
 
     def _adjust_effective_path_for_current_path(
-        self, effective_path, current_path
-    ):
+        self, effective_path: PathRegistry, current_path: PathRegistry
+    ) -> Optional[PathRegistry]:
         """receives the 'current_path' entry from an :class:`.ORMCompileState`
         instance, which is set during lazy loads and secondary loader strategy
         loads, and adjusts the given path to be relative to the
@@ -1456,7 +1583,7 @@ class _LoadElement(
 
 
         """
-        chopped_start_path = Load._chop_path(effective_path, current_path)
+        chopped_start_path = Load._chop_path(effective_path.path, current_path)
         if not chopped_start_path:
             return None
 
@@ -1523,16 +1650,16 @@ class _LoadElement(
     @classmethod
     def create(
         cls,
-        path,
-        attr,
-        strategy,
-        wildcard_key,
-        local_opts,
-        propagate_to_loaders,
-        raiseerr=True,
-        attr_group=None,
-        reconcile_to_other=None,
-    ):
+        path: PathRegistry,
+        attr: Optional[_AttrType],
+        strategy: Optional[_StrategyKey],
+        wildcard_key: Optional[_WildcardKeyType],
+        local_opts: Optional[_OptsType],
+        propagate_to_loaders: bool,
+        raiseerr: bool = True,
+        attr_group: Optional[_AttrGroupType] = None,
+        reconcile_to_other: Optional[bool] = None,
+    ) -> _LoadElement:
         """Create a new :class:`._LoadElement` object."""
 
         opt = cls.__new__(cls)
@@ -1554,14 +1681,14 @@ class _LoadElement(
         path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr)
 
         if not path:
-            return None
+            return None  # type: ignore
 
         assert opt.is_token_strategy == path.is_token
 
         opt.path = path
         return opt
 
-    def __init__(self, path, strategy, local_opts, propagate_to_loaders):
+    def __init__(self) -> None:
         raise NotImplementedError()
 
     def _prepend_path_from(self, parent):
@@ -1580,7 +1707,8 @@ class _LoadElement(
         assert cloned.is_class_strategy == self.is_class_strategy
 
         if not orm_util._entity_corresponds_to_use_path_impl(
-            parent.path[-1], cloned.path[0]
+            cast("_InternalEntityType[Any]", parent.path[-1]),
+            cast("_InternalEntityType[Any]", cloned.path[0]),
         ):
             raise sa_exc.ArgumentError(
                 f'Attribute "{cloned.path[1]}" does not link '
@@ -1592,7 +1720,9 @@ class _LoadElement(
         return cloned
 
     @staticmethod
-    def _reconcile(replacement, existing):
+    def _reconcile(
+        replacement: _LoadElement, existing: _LoadElement
+    ) -> _LoadElement:
         """define behavior for when two Load objects are to be put into
         the context.attributes under the same key.
 
@@ -1670,7 +1800,7 @@ class _AttributeStrategyLoad(_LoadElement):
         ),
     ]
 
-    _of_type: Union["Mapper", AliasedInsp, None]
+    _of_type: Union["Mapper[Any]", "AliasedInsp[Any]", None]
     _path_with_polymorphic_path: Optional[PathRegistry]
 
     is_class_strategy = False
@@ -1812,7 +1942,7 @@ class _AttributeStrategyLoad(_LoadElement):
             pwpi = inspect(
                 orm_util.AliasedInsp._with_polymorphic_factory(
                     pwpi.mapper.base_mapper,
-                    pwpi.mapper,
+                    (pwpi.mapper,),
                     aliased=True,
                     _use_mapper_path=True,
                 )
@@ -1820,11 +1950,12 @@ class _AttributeStrategyLoad(_LoadElement):
         start_path = self._path_with_polymorphic_path
         if current_path:
 
-            start_path = self._adjust_effective_path_for_current_path(
+            new_path = self._adjust_effective_path_for_current_path(
                 start_path, current_path
             )
-            if start_path is None:
+            if new_path is None:
                 return
+            start_path = new_path
 
         key = ("path_with_polymorphic", start_path.natural_path)
         if key in context:
@@ -1872,6 +2003,7 @@ class _AttributeStrategyLoad(_LoadElement):
             effective_path = self.path
 
         if current_path:
+            assert effective_path is not None
             effective_path = self._adjust_effective_path_for_current_path(
                 effective_path, current_path
             )
@@ -1985,11 +2117,12 @@ class _TokenStrategyLoad(_LoadElement):
             )
 
         if current_path:
-            effective_path = self._adjust_effective_path_for_current_path(
+            new_effective_path = self._adjust_effective_path_for_current_path(
                 effective_path, current_path
             )
-            if effective_path is None:
+            if new_effective_path is None:
                 return []
+            effective_path = new_effective_path
 
         # for a wildcard token, expand out the path we set
         # to encompass everything from the query entity on
@@ -2048,19 +2181,25 @@ class _ClassStrategyLoad(_LoadElement):
         effective_path = self.path
 
         if current_path:
-            effective_path = self._adjust_effective_path_for_current_path(
+            new_effective_path = self._adjust_effective_path_for_current_path(
                 effective_path, current_path
             )
-            if effective_path is None:
+            if new_effective_path is None:
                 return []
+            effective_path = new_effective_path
 
-        return [("loader", cast(PathRegistry, effective_path).natural_path)]
+        return [("loader", effective_path.natural_path)]
 
 
-def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad:
-
-    lead_element = None
+def _generate_from_keys(
+    meth: Callable[..., _AbstractLoad],
+    keys: Tuple[_AttrType, ...],
+    chained: bool,
+    kw: Any,
+) -> _AbstractLoad:
+    lead_element: Optional[_AbstractLoad] = None
 
+    attr: Any
     for is_default, _keys in (True, keys[0:-1]), (False, keys[-1:]):
         for attr in _keys:
             if isinstance(attr, str):
@@ -2116,7 +2255,9 @@ def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad:
     return lead_element
 
 
-def _parse_attr_argument(attr):
+def _parse_attr_argument(
+    attr: _AttrType,
+) -> Tuple[InspectionAttr, _InternalEntityType[Any], MapperProperty[Any]]:
     """parse an attribute or wildcard argument to produce an
     :class:`._AbstractLoad` instance.
 
@@ -2126,16 +2267,21 @@ def _parse_attr_argument(attr):
 
     """
     try:
-        insp = inspect(attr)
+        # TODO: need to figure out this None thing being returned by
+        # inspect(), it should not have None as an option in most cases
+        # if at all
+        insp: InspectionAttr = inspect(attr)  # type: ignore
     except sa_exc.NoInspectionAvailable as err:
         raise sa_exc.ArgumentError(
             "expected ORM mapped attribute for loader strategy argument"
         ) from err
 
-    if insp.is_property:
+    lead_entity: _InternalEntityType[Any]
+
+    if insp_is_mapper_property(insp):
         lead_entity = insp.parent
         prop = insp
-    elif insp.is_attribute:
+    elif insp_is_attribute(insp):
         lead_entity = insp.parent
         prop = insp.prop
     else:
@@ -2146,7 +2292,7 @@ def _parse_attr_argument(attr):
     return insp, lead_entity, prop
 
 
-def loader_unbound_fn(fn):
+def loader_unbound_fn(fn: _FN) -> _FN:
     """decorator that applies docstrings between standalone loader functions
     and the loader methods on :class:`._AbstractLoad`.
 
@@ -2169,12 +2315,12 @@ See :func:`_orm.{fn.__name__}` for usage examples.
 
 
 @loader_unbound_fn
-def contains_eager(*keys, **kw) -> _AbstractLoad:
+def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
     return _generate_from_keys(Load.contains_eager, keys, True, kw)
 
 
 @loader_unbound_fn
-def load_only(*attrs) -> _AbstractLoad:
+def load_only(*attrs: _AttrType) -> _AbstractLoad:
     # TODO: attrs against different classes.  we likely have to
     # add some extra state to Load of some kind
     _, lead_element, _ = _parse_attr_argument(attrs[0])
@@ -2182,47 +2328,47 @@ def load_only(*attrs) -> _AbstractLoad:
 
 
 @loader_unbound_fn
-def joinedload(*keys, **kw) -> _AbstractLoad:
+def joinedload(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
     return _generate_from_keys(Load.joinedload, keys, False, kw)
 
 
 @loader_unbound_fn
-def subqueryload(*keys) -> _AbstractLoad:
+def subqueryload(*keys: _AttrType) -> _AbstractLoad:
     return _generate_from_keys(Load.subqueryload, keys, False, {})
 
 
 @loader_unbound_fn
-def selectinload(*keys) -> _AbstractLoad:
+def selectinload(*keys: _AttrType) -> _AbstractLoad:
     return _generate_from_keys(Load.selectinload, keys, False, {})
 
 
 @loader_unbound_fn
-def lazyload(*keys) -> _AbstractLoad:
+def lazyload(*keys: _AttrType) -> _AbstractLoad:
     return _generate_from_keys(Load.lazyload, keys, False, {})
 
 
 @loader_unbound_fn
-def immediateload(*keys) -> _AbstractLoad:
+def immediateload(*keys: _AttrType) -> _AbstractLoad:
     return _generate_from_keys(Load.immediateload, keys, False, {})
 
 
 @loader_unbound_fn
-def noload(*keys) -> _AbstractLoad:
+def noload(*keys: _AttrType) -> _AbstractLoad:
     return _generate_from_keys(Load.noload, keys, False, {})
 
 
 @loader_unbound_fn
-def raiseload(*keys, **kw) -> _AbstractLoad:
+def raiseload(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
     return _generate_from_keys(Load.raiseload, keys, False, kw)
 
 
 @loader_unbound_fn
-def defaultload(*keys) -> _AbstractLoad:
+def defaultload(*keys: _AttrType) -> _AbstractLoad:
     return _generate_from_keys(Load.defaultload, keys, False, {})
 
 
 @loader_unbound_fn
-def defer(key, *addl_attrs, **kw) -> _AbstractLoad:
+def defer(key: _AttrType, *addl_attrs: _AttrType, **kw: Any) -> _AbstractLoad:
     if addl_attrs:
         util.warn_deprecated(
             "The *addl_attrs on orm.defer is deprecated.  Please use "
@@ -2234,7 +2380,7 @@ def defer(key, *addl_attrs, **kw) -> _AbstractLoad:
 
 
 @loader_unbound_fn
-def undefer(key, *addl_attrs) -> _AbstractLoad:
+def undefer(key: _AttrType, *addl_attrs: _AttrType) -> _AbstractLoad:
     if addl_attrs:
         util.warn_deprecated(
             "The *addl_attrs on orm.undefer is deprecated.  Please use "
@@ -2246,19 +2392,23 @@ def undefer(key, *addl_attrs) -> _AbstractLoad:
 
 
 @loader_unbound_fn
-def undefer_group(name) -> _AbstractLoad:
+def undefer_group(name: str) -> _AbstractLoad:
     element = _WildcardLoad()
     return element.undefer_group(name)
 
 
 @loader_unbound_fn
-def with_expression(key, expression) -> _AbstractLoad:
+def with_expression(
+    key: _AttrType, expression: _ColumnExpressionArgument[Any]
+) -> _AbstractLoad:
     return _generate_from_keys(
         Load.with_expression, (key,), False, {"expression": expression}
     )
 
 
 @loader_unbound_fn
-def selectin_polymorphic(base_cls, classes) -> _AbstractLoad:
+def selectin_polymorphic(
+    base_cls: _EntityType[Any], classes: Iterable[Type[Any]]
+) -> _AbstractLoad:
     ul = Load(base_cls)
     return ul.selectin_polymorphic(classes)
index 4f63e241baefd4b681456546e0ff209b0c43bdcb..4f1eeb39b6c264d390a0aef78777743198eb2d3f 100644 (file)
@@ -4,7 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
 
 
 """private module containing functions used for copying data
@@ -14,9 +14,9 @@ between instances based on join conditions.
 
 from __future__ import annotations
 
-from . import attributes
 from . import exc
 from . import util as orm_util
+from .base import PassiveFlag
 
 
 def populate(
@@ -36,7 +36,7 @@ def populate(
             # inline of source_mapper._get_state_attr_by_column
             prop = source_mapper._columntoproperty[l]
             value = source.manager[prop.key].impl.get(
-                source, source_dict, attributes.PASSIVE_OFF
+                source, source_dict, PassiveFlag.PASSIVE_OFF
             )
         except exc.UnmappedColumnError as err:
             _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err)
@@ -74,8 +74,8 @@ def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs):
         try:
             prop = source_mapper._columntoproperty[r]
             source_dict[prop.key] = value
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(True, source_mapper, l, source_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err)
 
 
 def clear(dest, dest_mapper, synchronize_pairs):
@@ -103,7 +103,7 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
                 source.obj(), l
             )
             value = source_mapper._get_state_attr_by_column(
-                source, source.dict, l, passive=attributes.PASSIVE_OFF
+                source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF
             )
         except exc.UnmappedColumnError as err:
             _raise_col_to_prop(False, source_mapper, l, None, r, err)
@@ -115,7 +115,7 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
     for l, r in synchronize_pairs:
         try:
             value = source_mapper._get_state_attr_by_column(
-                source, source.dict, l, passive=attributes.PASSIVE_OFF
+                source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF
             )
         except exc.UnmappedColumnError as err:
             _raise_col_to_prop(False, source_mapper, l, None, r, err)
@@ -134,7 +134,7 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
         except exc.UnmappedColumnError as err:
             _raise_col_to_prop(False, source_mapper, l, None, r, err)
         history = uowcommit.get_attribute_history(
-            source, prop.key, attributes.PASSIVE_NO_INITIALIZE
+            source, prop.key, PassiveFlag.PASSIVE_NO_INITIALIZE
         )
         if bool(history.deleted):
             return True
index 4da0b777376b65997ed355ee92c0f2e22a548b64..c50cc5bac84f6bb995b49c0eede3c8accb9872db 100644 (file)
@@ -12,6 +12,7 @@ import re
 import types
 import typing
 from typing import Any
+from typing import Callable
 from typing import cast
 from typing import Dict
 from typing import FrozenSet
@@ -82,24 +83,29 @@ if typing.TYPE_CHECKING:
     from ._typing import _EntityType
     from ._typing import _IdentityKeyType
     from ._typing import _InternalEntityType
-    from ._typing import _ORMColumnExprArgument
+    from ._typing import _ORMCOLEXPR
     from .context import _MapperEntity
     from .context import ORMCompileState
     from .mapper import Mapper
+    from .query import Query
     from .relationships import Relationship
     from ..engine import Row
     from ..engine import RowMapping
+    from ..sql._typing import _CE
     from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _EquivalentColumnMap
     from ..sql._typing import _FromClauseArgument
     from ..sql._typing import _OnClauseArgument
     from ..sql._typing import _PropagateAttrsType
+    from ..sql.annotation import _SA
     from ..sql.base import ReadOnlyColumnCollection
     from ..sql.elements import BindParameter
     from ..sql.selectable import _ColumnsClauseElement
     from ..sql.selectable import Alias
+    from ..sql.selectable import Select
     from ..sql.selectable import Subquery
     from ..sql.visitors import anon_map
+    from ..util.typing import _AnnotationScanType
 
 _T = TypeVar("_T", bound=Any)
 
@@ -144,9 +150,11 @@ class CascadeOptions(FrozenSet[str]):
     expunge: bool
     delete_orphan: bool
 
-    def __new__(cls, value_list):
+    def __new__(
+        cls, value_list: Optional[Union[Iterable[str], str]]
+    ) -> CascadeOptions:
         if isinstance(value_list, str) or value_list is None:
-            return cls.from_string(value_list)
+            return cls.from_string(value_list)  # type: ignore
         values = set(value_list)
         if values.difference(cls._allowed_cascades):
             raise sa_exc.ArgumentError(
@@ -864,7 +872,7 @@ class AliasedInsp(
     def _with_polymorphic_factory(
         cls,
         base: Union[_O, Mapper[_O]],
-        classes: Iterable[Type[Any]],
+        classes: Iterable[_EntityType[Any]],
         selectable: Union[Literal[False, None], FromClause] = False,
         flat: bool = False,
         polymorphic_on: Optional[ColumnElement[Any]] = None,
@@ -1011,23 +1019,40 @@ class AliasedInsp(
         )._aliased_insp
 
     def _adapt_element(
-        self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None
-    ) -> _ORMColumnExprArgument[_T]:
-        assert isinstance(elem, ColumnElement)
+        self, expr: _ORMCOLEXPR, key: Optional[str] = None
+    ) -> _ORMCOLEXPR:
+        assert isinstance(expr, ColumnElement)
         d: Dict[str, Any] = {
             "parententity": self,
             "parentmapper": self.mapper,
         }
         if key:
             d["proxy_key"] = key
+
+        # IMO mypy should see this one also as returning the same type
+        # we put into it, but it's not
         return (
-            self._adapter.traverse(elem)
+            self._adapter.traverse(expr)  # type: ignore
             ._annotate(d)
             ._set_propagate_attrs(
                 {"compile_state_plugin": "orm", "plugin_subject": self}
             )
         )
 
+    if TYPE_CHECKING:
+        # establish compatibility with the _ORMAdapterProto protocol,
+        # which in turn is compatible with _CoreAdapterProto.
+
+        def _orm_adapt_element(
+            self,
+            obj: _CE,
+            key: Optional[str] = None,
+        ) -> _CE:
+            ...
+
+    else:
+        _orm_adapt_element = _adapt_element
+
     def _entity_for_mapper(self, mapper):
         self_poly = self.with_polymorphic_mappers
         if mapper in self_poly:
@@ -1469,7 +1494,12 @@ class Bundle(
         cloned.name = name
         return cloned
 
-    def create_row_processor(self, query, procs, labels):
+    def create_row_processor(
+        self,
+        query: Select[Any],
+        procs: Sequence[Callable[[Row[Any]], Any]],
+        labels: Sequence[str],
+    ) -> Callable[[Row[Any]], Any]:
         """Produce the "row processing" function for this :class:`.Bundle`.
 
         May be overridden by subclasses.
@@ -1481,13 +1511,13 @@ class Bundle(
         """
         keyed_tuple = result_tuple(labels, [() for l in labels])
 
-        def proc(row):
+        def proc(row: Row[Any]) -> Any:
             return keyed_tuple([proc(row) for proc in procs])
 
         return proc
 
 
-def _orm_annotate(element, exclude=None):
+def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA:
     """Deep copy the given ClauseElement, annotating each element with the
     "_orm_adapt" flag.
 
@@ -1497,7 +1527,7 @@ def _orm_annotate(element, exclude=None):
     return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
 
 
-def _orm_deannotate(element):
+def _orm_deannotate(element: _SA) -> _SA:
     """Remove annotations that link a column to a particular mapping.
 
     Note this doesn't affect "remote" and "foreign" annotations
@@ -1511,7 +1541,7 @@ def _orm_deannotate(element):
     )
 
 
-def _orm_full_deannotate(element):
+def _orm_full_deannotate(element: _SA) -> _SA:
     return sql_util._deep_deannotate(element)
 
 
@@ -1560,13 +1590,15 @@ class _ORMJoin(expression.Join):
             on_selectable = prop.parent.selectable
         else:
             prop = None
+            on_selectable = None
 
         if prop:
             left_selectable = left_info.selectable
-
+            adapt_from: Optional[FromClause]
             if sql_util.clause_is_present(on_selectable, left_selectable):
                 adapt_from = on_selectable
             else:
+                assert isinstance(left_selectable, FromClause)
                 adapt_from = left_selectable
 
             (
@@ -1855,7 +1887,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool:
         return given.isa(mapper)
 
 
-def _getitem(iterable_query, item):
+def _getitem(iterable_query: Query[Any], item: Any) -> Any:
     """calculate __getitem__ in terms of an iterable query object
     that also has a slice() method.
 
@@ -1881,17 +1913,15 @@ def _getitem(iterable_query, item):
             isinstance(stop, int) and stop < 0
         ):
             _no_negative_indexes()
-            return list(iterable_query)[item]
 
         res = iterable_query.slice(start, stop)
         if step is not None:
-            return list(res)[None : None : item.step]
+            return list(res)[None : None : item.step]  # type: ignore
         else:
-            return list(res)
+            return list(res)  # type: ignore
     else:
         if item == -1:
             _no_negative_indexes()
-            return list(iterable_query)[-1]
         else:
             return list(iterable_query[item : item + 1])[0]
 
@@ -1933,7 +1963,7 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
 
 
 def _extract_mapped_subtype(
-    raw_annotation: Union[type, str],
+    raw_annotation: Optional[_AnnotationScanType],
     cls: type,
     key: str,
     attr_cls: Type[Any],
index f49a6d3ec5b37bce8260d8296d1e208f02ac90c4..ed1bd283224d36b913ec672a27321bd92d5608de 100644 (file)
@@ -61,6 +61,9 @@ if TYPE_CHECKING:
 _T = TypeVar("_T", bound=Any)
 
 
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
 class _HasClauseElement(Protocol):
     """indicates a class that has a __clause_element__() method"""
 
@@ -68,6 +71,13 @@ class _HasClauseElement(Protocol):
         ...
 
 
+class _CoreAdapterProto(Protocol):
+    """protocol for the ClauseAdapter/ColumnAdapter.traverse() method."""
+
+    def __call__(self, obj: _CE) -> _CE:
+        ...
+
+
 # match column types that are not ORM entities
 _NOT_ENTITY = TypeVar(
     "_NOT_ENTITY",
index fa36c09fcffd5a66df1d3f694be06d6ad5705a2c..56d88bc2fb8881fba9c127b8818ff813bfd05a7b 100644 (file)
@@ -454,9 +454,23 @@ def _deep_annotate(
     return element
 
 
+@overload
+def _deep_deannotate(
+    element: Literal[None], values: Optional[Sequence[str]] = None
+) -> Literal[None]:
+    ...
+
+
+@overload
 def _deep_deannotate(
     element: _SA, values: Optional[Sequence[str]] = None
 ) -> _SA:
+    ...
+
+
+def _deep_deannotate(
+    element: Optional[_SA], values: Optional[Sequence[str]] = None
+) -> Optional[_SA]:
     """Deep copy the given element, removing annotations."""
 
     cloned: Dict[Any, SupportsAnnotations] = {}
@@ -482,9 +496,7 @@ def _deep_deannotate(
     return element
 
 
-def _shallow_annotate(
-    element: SupportsAnnotations, annotations: _AnnotationDict
-) -> SupportsAnnotations:
+def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA:
     """Annotate the given ClauseElement and copy its internals so that
     internal objects refer to the new annotated object.
 
index 248b48a250b33195c5b8558e6e35cccd7a516235..f5a9c10c0cdc40482be70a5e2b9a490610bbec00 100644 (file)
@@ -750,6 +750,17 @@ class _MetaOptions(type):
         o1.__dict__.update(other)
         return o1
 
+    if TYPE_CHECKING:
+
+        def __getattr__(self, key: str) -> Any:
+            ...
+
+        def __setattr__(self, key: str, value: Any) -> None:
+            ...
+
+        def __delattr__(self, key: str) -> None:
+            ...
+
 
 class Options(metaclass=_MetaOptions):
     """A cacheable option dictionary with defaults."""
@@ -904,6 +915,17 @@ class Options(metaclass=_MetaOptions):
         else:
             return existing_options, exec_options
 
+    if TYPE_CHECKING:
+
+        def __getattr__(self, key: str) -> Any:
+            ...
+
+        def __setattr__(self, key: str, value: Any) -> None:
+            ...
+
+        def __delattr__(self, key: str) -> None:
+            ...
+
 
 class CacheableOptions(Options, HasCacheKey):
     __slots__ = ()
index eef5cf211ea1d214fe8d2bd3366d32c5ab04b9b2..501188b1277bcfcc47ec3f74c776cbe0972b520e 100644 (file)
@@ -56,6 +56,7 @@ if typing.TYPE_CHECKING:
     from .elements import ColumnClause
     from .elements import ColumnElement
     from .elements import DQLDMLClauseElement
+    from .elements import NamedColumn
     from .elements import SQLCoreOperations
     from .schema import Column
     from .selectable import _ColumnsClauseElement
@@ -197,6 +198,15 @@ def expect(
     ...
 
 
+@overload
+def expect(
+    role: Type[roles.LabeledColumnExprRole[Any]],
+    element: _ColumnExpressionArgument[_T],
+    **kw: Any,
+) -> NamedColumn[_T]:
+    ...
+
+
 @overload
 def expect(
     role: Union[
@@ -217,6 +227,7 @@ def expect(
         Type[roles.LimitOffsetRole],
         Type[roles.WhereHavingRole],
         Type[roles.OnClauseRole],
+        Type[roles.ColumnArgumentRole],
     ],
     element: Any,
     **kw: Any,
index 41b7f6392ef4c45d3dab9118c4f39649ad678bc5..61c5379d8acafea5cad81fd3dd4a2beb73821a7e 100644 (file)
@@ -503,7 +503,7 @@ class ClauseElement(
 
     def params(
         self: SelfClauseElement,
-        __optionaldict: Optional[Dict[str, Any]] = None,
+        __optionaldict: Optional[Mapping[str, Any]] = None,
         **kwargs: Any,
     ) -> SelfClauseElement:
         """Return a copy with :func:`_expression.bindparam` elements
@@ -525,7 +525,7 @@ class ClauseElement(
     def _replace_params(
         self: SelfClauseElement,
         unique: bool,
-        optionaldict: Optional[Dict[str, Any]],
+        optionaldict: Optional[Mapping[str, Any]],
         kwargs: Dict[str, Any],
     ) -> SelfClauseElement:
 
@@ -545,7 +545,7 @@ class ClauseElement(
             {"bindparam": visit_bindparam},
         )
 
-    def compare(self, other, **kw):
+    def compare(self, other: ClauseElement, **kw: Any) -> bool:
         r"""Compare this :class:`_expression.ClauseElement` to
         the given :class:`_expression.ClauseElement`.
 
@@ -2516,7 +2516,9 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]):
         return False_._singleton
 
     @classmethod
-    def _ifnone(cls, other):
+    def _ifnone(
+        cls, other: Optional[ColumnElement[Any]]
+    ) -> ColumnElement[Any]:
         if other is None:
             return cls._instance()
         else:
@@ -4226,7 +4228,13 @@ class NamedColumn(KeyedColumnElement[_T]):
     ) -> Optional[str]:
         return name
 
-    def _bind_param(self, operator, obj, type_=None, expanding=False):
+    def _bind_param(
+        self,
+        operator: OperatorType,
+        obj: Any,
+        type_: Optional[TypeEngine[_T]] = None,
+        expanding: bool = False,
+    ) -> BindParameter[_T]:
         return BindParameter(
             self.key,
             obj,
index d0b0f147615435f959e6a602da52cb9bf2084e8c..fd98f17e32fdda7dcb3af2166bec9d1711fe85e7 100644 (file)
@@ -64,6 +64,7 @@ from .base import _EntityNamespace
 from .base import _expand_cloned
 from .base import _from_objects
 from .base import _generative
+from .base import _NoArg
 from .base import _select_iterables
 from .base import CacheableOptions
 from .base import ColumnCollection
@@ -131,6 +132,7 @@ if TYPE_CHECKING:
     from .dml import Insert
     from .dml import Update
     from .elements import KeyedColumnElement
+    from .elements import Label
     from .elements import NamedColumn
     from .elements import TextClause
     from .functions import Function
@@ -212,7 +214,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
         """
         raise NotImplementedError()
 
-    def is_derived_from(self, fromclause: FromClause) -> bool:
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         """Return ``True`` if this :class:`.ReturnsRows` is
         'derived' from the given :class:`.FromClause`.
 
@@ -778,7 +780,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         """
         return TableSample._construct(self, sampling, name, seed)
 
-    def is_derived_from(self, fromclause: FromClause) -> bool:
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         """Return ``True`` if this :class:`_expression.FromClause` is
         'derived' from the given ``FromClause``.
 
@@ -1128,11 +1130,14 @@ class SelectLabelStyle(Enum):
 
     """
 
+    LABEL_STYLE_LEGACY_ORM = 3
+
 
 (
     LABEL_STYLE_NONE,
     LABEL_STYLE_TABLENAME_PLUS_COL,
     LABEL_STYLE_DISAMBIGUATE_ONLY,
+    _,
 ) = list(SelectLabelStyle)
 
 LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY
@@ -1231,7 +1236,7 @@ class Join(roles.DMLTableRole, FromClause):
             id(self.right),
         )
 
-    def is_derived_from(self, fromclause: FromClause) -> bool:
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         return (
             # use hash() to ensure direct comparison to annotated works
             # as well
@@ -1635,7 +1640,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
         """Legacy for dialects that are referring to Alias.original."""
         return self.element
 
-    def is_derived_from(self, fromclause: FromClause) -> bool:
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         if fromclause in self._cloned_set:
             return True
         return self.element.is_derived_from(fromclause)
@@ -2840,7 +2845,7 @@ class FromGrouping(GroupedElement, FromClause):
     def foreign_keys(self):
         return self.element.foreign_keys
 
-    def is_derived_from(self, fromclause: FromClause) -> bool:
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         return self.element.is_derived_from(fromclause)
 
     def alias(
@@ -3080,11 +3085,17 @@ class ForUpdateArg(ClauseElement):
 
     def __init__(
         self,
-        nowait=False,
-        read=False,
-        of=None,
-        skip_locked=False,
-        key_share=False,
+        *,
+        nowait: bool = False,
+        read: bool = False,
+        of: Optional[
+            Union[
+                _ColumnExpressionArgument[Any],
+                Sequence[_ColumnExpressionArgument[Any]],
+            ]
+        ] = None,
+        skip_locked: bool = False,
+        key_share: bool = False,
     ):
         """Represents arguments specified to
         :meth:`_expression.Select.for_update`.
@@ -3455,7 +3466,7 @@ class SelectBase(
 
         return ScalarSelect(self)
 
-    def label(self, name):
+    def label(self, name: Optional[str]) -> Label[Any]:
         """Return a 'scalar' representation of this selectable, embedded as a
         subquery with a label.
 
@@ -3667,6 +3678,7 @@ class GenerativeSelect(SelectBase, Generative):
     @_generative
     def with_for_update(
         self: SelfGenerativeSelect,
+        *,
         nowait: bool = False,
         read: bool = False,
         of: Optional[
@@ -4064,7 +4076,11 @@ class GenerativeSelect(SelectBase, Generative):
 
     @_generative
     def order_by(
-        self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+        self: SelfGenerativeSelect,
+        __first: Union[
+            Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+        ] = _NoArg.NO_ARG,
+        *clauses: _ColumnExpressionArgument[Any],
     ) -> SelfGenerativeSelect:
         r"""Return a new selectable with the given list of ORDER BY
         criteria applied.
@@ -4092,18 +4108,22 @@ class GenerativeSelect(SelectBase, Generative):
 
         """
 
-        if len(clauses) == 1 and clauses[0] is None:
+        if not clauses and __first is None:
             self._order_by_clauses = ()
-        else:
+        elif __first is not _NoArg.NO_ARG:
             self._order_by_clauses += tuple(
                 coercions.expect(roles.OrderByRole, clause)
-                for clause in clauses
+                for clause in (__first,) + clauses
             )
         return self
 
     @_generative
     def group_by(
-        self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+        self: SelfGenerativeSelect,
+        __first: Union[
+            Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+        ] = _NoArg.NO_ARG,
+        *clauses: _ColumnExpressionArgument[Any],
     ) -> SelfGenerativeSelect:
         r"""Return a new selectable with the given list of GROUP BY
         criterion applied.
@@ -4128,12 +4148,12 @@ class GenerativeSelect(SelectBase, Generative):
 
         """
 
-        if len(clauses) == 1 and clauses[0] is None:
+        if not clauses and __first is None:
             self._group_by_clauses = ()
-        else:
+        elif __first is not _NoArg.NO_ARG:
             self._group_by_clauses += tuple(
                 coercions.expect(roles.GroupByRole, clause)
-                for clause in clauses
+                for clause in (__first,) + clauses
             )
         return self
 
@@ -4257,7 +4277,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
     ) -> GroupedElement:
         return SelectStatementGrouping(self)
 
-    def is_derived_from(self, fromclause: FromClause) -> bool:
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         for s in self.selects:
             if s.is_derived_from(fromclause):
                 return True
@@ -4959,7 +4979,7 @@ class Select(
 
     _raw_columns: List[_ColumnsClauseElement]
 
-    _distinct = False
+    _distinct: bool = False
     _distinct_on: Tuple[ColumnElement[Any], ...] = ()
     _correlate: Tuple[FromClause, ...] = ()
     _correlate_except: Optional[Tuple[FromClause, ...]] = None
@@ -5478,8 +5498,8 @@ class Select(
 
         return iter(self._all_selected_columns)
 
-    def is_derived_from(self, fromclause: FromClause) -> bool:
-        if self in fromclause._cloned_set:
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
+        if fromclause is not None and self in fromclause._cloned_set:
             return True
 
         for f in self._iterate_from_elements():
index aceed99a5d9369d429c1290dfef98ca94cabbb65..94e635740e7f391ed04405abf5a4ef54c891df29 100644 (file)
@@ -19,6 +19,7 @@ from typing import Callable
 from typing import Deque
 from typing import Dict
 from typing import Iterable
+from typing import Optional
 from typing import Set
 from typing import Tuple
 from typing import Type
@@ -39,7 +40,7 @@ COMPARE_FAILED = False
 COMPARE_SUCCEEDED = True
 
 
-def compare(obj1, obj2, **kw):
+def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
     strategy: TraversalComparatorStrategy
     if kw.get("use_proxies", False):
         strategy = ColIdentityComparatorStrategy()
@@ -49,7 +50,7 @@ def compare(obj1, obj2, **kw):
     return strategy.compare(obj1, obj2, **kw)
 
 
-def _preconfigure_traversals(target_hierarchy):
+def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
     for cls in util.walk_subclasses(target_hierarchy):
         if hasattr(cls, "_generate_cache_attrs") and hasattr(
             cls, "_traverse_internals"
@@ -482,14 +483,22 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
 
     def __init__(self):
         self.stack: Deque[
-            Tuple[ExternallyTraversible, ExternallyTraversible]
+            Tuple[
+                Optional[ExternallyTraversible],
+                Optional[ExternallyTraversible],
+            ]
         ] = deque()
         self.cache = set()
 
     def _memoized_attr_anon_map(self):
         return (anon_map(), anon_map())
 
-    def compare(self, obj1, obj2, **kw):
+    def compare(
+        self,
+        obj1: ExternallyTraversible,
+        obj2: ExternallyTraversible,
+        **kw: Any,
+    ) -> bool:
         stack = self.stack
         cache = self.cache
 
@@ -551,6 +560,10 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
                 elif left_attrname in attributes_compared:
                     continue
 
+                assert left_visit_sym is not None
+                assert left_attrname is not None
+                assert right_attrname is not None
+
                 dispatch = self.dispatch(left_visit_sym)
                 assert dispatch, (
                     f"{self.__class__} has no dispatch for "
@@ -595,6 +608,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
         for l, r in zip_longest(left, right, fillvalue=None):
+            if l is None:
+                if r is not None:
+                    return COMPARE_FAILED
+                else:
+                    continue
+            elif r is None:
+                return COMPARE_FAILED
+
             if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
                 self.anon_map[1], []
             ):
@@ -604,6 +625,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
         for l, r in zip_longest(left, right, fillvalue=None):
+            if l is None:
+                if r is not None:
+                    return COMPARE_FAILED
+                else:
+                    continue
+            elif r is None:
+                return COMPARE_FAILED
+
             if (
                 l._gen_cache_key(self.anon_map[0], [])
                 if l._is_has_cache_key
index 262689128d7719ac6e5f358361f1089a52ba4c97..390e23952fab6731a93d8b8844c4e09d9ea468bf 100644 (file)
@@ -73,6 +73,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
     from ._typing import _EquivalentColumnMap
     from ._typing import _TypeEngineArgument
+    from .elements import BinaryExpression
     from .elements import TextClause
     from .selectable import _JoinTargetElement
     from .selectable import _SelectIterable
@@ -86,8 +87,15 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.row import Row
 
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
 
-def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
+
+def join_condition(
+    a: FromClause,
+    b: FromClause,
+    a_subset: Optional[FromClause] = None,
+    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
+) -> ColumnElement[bool]:
     """Create a join condition between two tables or selectables.
 
     e.g.::
@@ -118,7 +126,9 @@ def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
     )
 
 
-def find_join_source(clauses, join_to):
+def find_join_source(
+    clauses: List[FromClause], join_to: FromClause
+) -> List[int]:
     """Given a list of FROM clauses and a selectable,
     return the first index and element from the list of
     clauses which can be joined against the selectable.  returns
@@ -144,7 +154,9 @@ def find_join_source(clauses, join_to):
     return idx
 
 
-def find_left_clause_that_matches_given(clauses, join_from):
+def find_left_clause_that_matches_given(
+    clauses: Sequence[FromClause], join_from: FromClause
+) -> List[int]:
     """Given a list of FROM clauses and a selectable,
     return the indexes from the list of
     clauses which is derived from the selectable.
@@ -243,7 +255,12 @@ def find_left_clause_to_join_from(
         return idx
 
 
-def visit_binary_product(fn, expr):
+def visit_binary_product(
+    fn: Callable[
+        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
+    ],
+    expr: ColumnElement[Any],
+) -> None:
     """Produce a traversal of the given expression, delivering
     column comparisons to the given function.
 
@@ -278,19 +295,19 @@ def visit_binary_product(fn, expr):
     a binary comparison is passed as pairs.
 
     """
-    stack: List[ClauseElement] = []
+    stack: List[BinaryExpression[Any]] = []
 
-    def visit(element):
+    def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
         if isinstance(element, ScalarSelect):
             # we don't want to dig into correlated subqueries,
             # those are just column elements by themselves
             yield element
         elif element.__visit_name__ == "binary" and operators.is_comparison(
-            element.operator
+            element.operator  # type: ignore
         ):
-            stack.insert(0, element)
-            for l in visit(element.left):
-                for r in visit(element.right):
+            stack.insert(0, element)  # type: ignore
+            for l in visit(element.left):  # type: ignore
+                for r in visit(element.right):  # type: ignore
                     fn(stack[0], l, r)
             stack.pop(0)
             for elem in element.get_children():
@@ -502,7 +519,7 @@ def extract_first_column_annotation(column, annotation_name):
     return None
 
 
-def selectables_overlap(left, right):
+def selectables_overlap(left: FromClause, right: FromClause) -> bool:
     """Return True if left/right have some overlapping selectable"""
 
     return bool(
@@ -701,7 +718,7 @@ class _repr_params(_repr_base):
             return "[%s]" % (", ".join(trunc(value) for value in params))
 
 
-def adapt_criterion_to_null(crit, nulls):
+def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
     """given criterion containing bind params, convert selected elements
     to IS NULL.
 
@@ -922,9 +939,6 @@ def criterion_as_pairs(
     return pairs
 
 
-_CE = TypeVar("_CE", bound="ClauseElement")
-
-
 class ClauseAdapter(visitors.ReplacingExternalTraversal):
     """Clones and modifies clauses based on column correspondence.
 
index 217e2d2ab41212c211ccec221fb39ea337062512..b550f8f28674c546ee6e2400fa685d6f3aa9aa92 100644 (file)
@@ -21,7 +21,6 @@ from typing import Any
 from typing import Callable
 from typing import cast
 from typing import ClassVar
-from typing import Collection
 from typing import Dict
 from typing import Iterable
 from typing import Iterator
@@ -31,6 +30,7 @@ from typing import Optional
 from typing import overload
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -42,6 +42,10 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 from ..util.typing import Self
 
+if TYPE_CHECKING:
+    from .annotation import _AnnotationDict
+    from .elements import ColumnElement
+
 if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
     from ._py_util import prefix_anon_map as prefix_anon_map
     from ._py_util import cache_anon_map as anon_map
@@ -590,13 +594,23 @@ _dispatch_lookup = HasTraversalDispatch._dispatch_lookup
 _generate_traversal_dispatch()
 
 
+SelfExternallyTraversible = TypeVar(
+    "SelfExternallyTraversible", bound="ExternallyTraversible"
+)
+
+
 class ExternallyTraversible(HasTraverseInternals, Visitable):
     __slots__ = ()
 
-    _annotations: Collection[Any] = ()
+    _annotations: Mapping[Any, Any] = util.EMPTY_DICT
 
     if typing.TYPE_CHECKING:
 
+        def _annotate(
+            self: SelfExternallyTraversible, values: _AnnotationDict
+        ) -> SelfExternallyTraversible:
+            ...
+
         def get_children(
             self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
         ) -> Iterable[ExternallyTraversible]:
@@ -624,6 +638,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
 
 _ET = TypeVar("_ET", bound=ExternallyTraversible)
 
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
 
 _TraverseCallableType = Callable[[_ET], None]
 
@@ -633,10 +648,8 @@ class _CloneCallableType(Protocol):
         ...
 
 
-class _TraverseTransformCallableType(Protocol):
-    def __call__(
-        self, element: ExternallyTraversible, **kw: Any
-    ) -> Optional[ExternallyTraversible]:
+class _TraverseTransformCallableType(Protocol[_ET]):
+    def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]:
         ...
 
 
@@ -1074,16 +1087,25 @@ def cloned_traverse(
 def replacement_traverse(
     obj: Literal[None],
     opts: Mapping[str, Any],
-    replace: _TraverseTransformCallableType,
+    replace: _TraverseTransformCallableType[Any],
 ) -> None:
     ...
 
 
+@overload
+def replacement_traverse(
+    obj: _CE,
+    opts: Mapping[str, Any],
+    replace: _TraverseTransformCallableType[Any],
+) -> _CE:
+    ...
+
+
 @overload
 def replacement_traverse(
     obj: ExternallyTraversible,
     opts: Mapping[str, Any],
-    replace: _TraverseTransformCallableType,
+    replace: _TraverseTransformCallableType[Any],
 ) -> ExternallyTraversible:
     ...
 
@@ -1091,7 +1113,7 @@ def replacement_traverse(
 def replacement_traverse(
     obj: Optional[ExternallyTraversible],
     opts: Mapping[str, Any],
-    replace: _TraverseTransformCallableType,
+    replace: _TraverseTransformCallableType[Any],
 ) -> Optional[ExternallyTraversible]:
     """Clone the given expression structure, allowing element
     replacement by a given replacement function.
@@ -1134,7 +1156,7 @@ def replacement_traverse(
             newelem = replace(elem)
             if newelem is not None:
                 stop_on.add(id(newelem))
-                return newelem
+                return newelem  # type: ignore
             else:
                 # base "already seen" on id(), not hash, so that we don't
                 # replace an Annotated element with its non-annotated one, and
@@ -1145,11 +1167,11 @@ def replacement_traverse(
                         newelem = kw["replace"](elem)
                         if newelem is not None:
                             cloned[id_elem] = newelem
-                            return newelem
+                            return newelem  # type: ignore
 
                     cloned[id_elem] = newelem = elem._clone(**kw)
                     newelem._copy_internals(clone=clone, **kw)
-                return cloned[id_elem]
+                return cloned[id_elem]  # type: ignore
 
     if obj is not None:
         obj = clone(
index 7150dedcf89ac61c6207563f78d3042a22d7522c..54be2e4e5b59fc05b4cd35b51e9740a1c9634632 100644 (file)
@@ -71,7 +71,7 @@ _T_co = TypeVar("_T_co", covariant=True)
 EMPTY_SET: FrozenSet[Any] = frozenset()
 
 
-def merge_lists_w_ordering(a, b):
+def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
     """merge two lists, maintaining ordering as much as possible.
 
     this is to reconcile vars(cls) with cls.__annotations__.
@@ -450,7 +450,7 @@ def to_set(x):
         return x
 
 
-def to_column_set(x):
+def to_column_set(x: Any) -> Set[Any]:
     if x is None:
         return column_set()
     if not isinstance(x, column_set):
index 24fa0f3e38e30cfe894376a09dbb20ddea6cf9b7..adbbf143f9d7e9cbdba4696501a59d4994428d73 100644 (file)
@@ -20,11 +20,14 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Iterable
 from typing import List
 from typing import Mapping
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import Tuple
+from typing import Type
 
 
 py311 = sys.version_info >= (3, 11)
@@ -225,7 +228,7 @@ def inspect_formatargspec(
     return result
 
 
-def dataclass_fields(cls):
+def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
     """Return a sequence of all dataclasses.Field objects associated
     with a class."""
 
@@ -235,12 +238,12 @@ def dataclass_fields(cls):
         return []
 
 
-def local_dataclass_fields(cls):
+def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
     """Return a sequence of all dataclasses.Field objects associated with
     a class, excluding those that originate from a superclass."""
 
     if dataclasses.is_dataclass(cls):
-        super_fields = set()
+        super_fields: Set[dataclasses.Field[Any]] = set()
         for sup in cls.__bases__:
             super_fields.update(dataclass_fields(sup))
         return [f for f in dataclasses.fields(cls) if f not in super_fields]
index 24c66bfa4e4373d53cf284c009b2f25d37b0114e..e54f3347582769da5fc445729e12cc5a65d52e96 100644 (file)
@@ -266,13 +266,31 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
         metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
         metadata.update(format_argspec_plus(spec, grouped=False))
         metadata["name"] = fn.__name__
-        code = (
-            """\
+
+        # look for __ positional arguments.  This is a convention in
+        # SQLAlchemy that arguments should be passed positionally
+        # rather than as keyword
+        # arguments.  note that apply_pos doesn't currently work in all cases
+        # such as when a kw-only indicator "*" is present, which is why
+        # we limit the use of this to just that case we can detect.  As we add
+        # more kinds of methods that use @decorator, things may have to
+        # be further improved in this area
+        if "__" in repr(spec[0]):
+            code = (
+                """\
+def %(name)s%(grouped_args)s:
+    return %(target)s(%(fn)s, %(apply_pos)s)
+"""
+                % metadata
+            )
+        else:
+            code = (
+                """\
 def %(name)s%(grouped_args)s:
     return %(target)s(%(fn)s, %(apply_kw)s)
 """
-            % metadata
-        )
+                % metadata
+            )
         env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
 
         decorated = cast(
@@ -1235,10 +1253,10 @@ class HasMemoized:
             return result
 
     @classmethod
-    def memoized_instancemethod(cls, fn: Any) -> Any:
+    def memoized_instancemethod(cls, fn: _F) -> _F:
         """Decorate a method memoize its return value."""
 
-        def oneshot(self, *args, **kw):
+        def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
             result = fn(self, *args, **kw)
 
             def memo(*a, **kw):
@@ -1250,7 +1268,7 @@ class HasMemoized:
             self._memoized_keys |= {fn.__name__}
             return result
 
-        return update_wrapper(oneshot, fn)
+        return update_wrapper(oneshot, fn)  # type: ignore
 
 
 if TYPE_CHECKING:
index fce3cd3b0bf561c447b52ce40213217bd3473dd1..67394c9a3a2b4d391fad44eaf1e2992a19f52275 100644 (file)
@@ -25,8 +25,12 @@ _FN = TypeVar("_FN", bound=Callable[..., Any])
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import default as engine_default  # noqa
+    from sqlalchemy.orm import clsregistry as orm_clsregistry  # noqa
+    from sqlalchemy.orm import decl_api as orm_decl_api  # noqa
+    from sqlalchemy.orm import properties as orm_properties  # noqa
     from sqlalchemy.orm import relationships as orm_relationships  # noqa
     from sqlalchemy.orm import session as orm_session  # noqa
+    from sqlalchemy.orm import state as orm_state  # noqa
     from sqlalchemy.orm import util as orm_util  # noqa
     from sqlalchemy.sql import dml as sql_dml  # noqa
     from sqlalchemy.sql import functions as sql_functions  # noqa
index 37297103efd0d2ccea43e1e3d71048a7e6e617fe..24e478b573f1a7f6c42c4a9597c4c5c9c3243b7c 100644 (file)
@@ -4,21 +4,33 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
 
 """Topological sorting algorithms."""
 
 from __future__ import annotations
 
+from typing import Any
+from typing import DefaultDict
+from typing import Iterable
+from typing import Iterator
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+
 from .. import util
 from ..exc import CircularDependencyError
 
+_T = TypeVar("_T", bound=Any)
+
 __all__ = ["sort", "sort_as_subsets", "find_cycles"]
 
 
-def sort_as_subsets(tuples, allitems):
+def sort_as_subsets(
+    tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
+) -> Iterator[Sequence[_T]]:
 
-    edges = util.defaultdict(set)
+    edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
     for parent, child in tuples:
         edges[child].add(parent)
 
@@ -43,7 +55,11 @@ def sort_as_subsets(tuples, allitems):
         yield output
 
 
-def sort(tuples, allitems, deterministic_order=True):
+def sort(
+    tuples: Iterable[Tuple[_T, _T]],
+    allitems: Iterable[_T],
+    deterministic_order: bool = True,
+) -> Iterator[_T]:
     """sort the given list of items by dependency.
 
     'tuples' is a list of tuples representing a partial ordering.
@@ -59,11 +75,14 @@ def sort(tuples, allitems, deterministic_order=True):
             yield s
 
 
-def find_cycles(tuples, allitems):
+def find_cycles(
+    tuples: Iterable[Tuple[_T, _T]],
+    allitems: Iterable[_T],
+) -> Set[_T]:
     # adapted from:
     # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
 
-    edges = util.defaultdict(set)
+    edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
     for parent, child in tuples:
         edges[parent].add(child)
     nodes_to_test = set(edges)
@@ -99,5 +118,5 @@ def find_cycles(tuples, allitems):
     return output
 
 
-def _gen_edges(edges):
-    return set([(right, left) for left in edges for right in edges[left]])
+def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]:
+    return {(right, left) for left in edges for right in edges[left]}
index ebcae28a7a37b88b4b25e1524057777d33fc0c36..44e26f60940cdcd3c9bc33af24545f692ad9ad20 100644 (file)
@@ -11,7 +11,9 @@ from typing import Dict
 from typing import ForwardRef
 from typing import Generic
 from typing import Iterable
+from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Tuple
 from typing import Type
 from typing import TypeVar
@@ -33,7 +35,7 @@ Self = TypeVar("Self", bound=Any)
 if compat.py310:
     # why they took until py310 to put this in stdlib is beyond me,
     # I've been wanting it since py27
-    from types import NoneType
+    from types import NoneType as NoneType
 else:
     NoneType = type(None)  # type: ignore
 
@@ -68,6 +70,8 @@ else:
 # copied from TypeShed, required in order to implement
 # MutableMapping.update()
 
+_AnnotationScanType = Union[Type[Any], str]
+
 
 class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
     def keys(self) -> Iterable[_KT]:
@@ -90,9 +94,9 @@ else:
 
 def de_stringify_annotation(
     cls: Type[Any],
-    annotation: Union[str, Type[Any]],
+    annotation: _AnnotationScanType,
     str_cleanup_fn: Optional[Callable[[str], str]] = None,
-) -> Union[str, Type[Any]]:
+) -> Type[Any]:
     """Resolve annotations that may be string based into real objects.
 
     This is particularly important if a module defines "from __future__ import
@@ -125,20 +129,32 @@ def de_stringify_annotation(
             annotation = eval(annotation, base_globals, None)
         except NameError:
             pass
-    return annotation
+    return annotation  # type: ignore
 
 
-def is_fwd_ref(type_):
+def is_fwd_ref(type_: _AnnotationScanType) -> bool:
     return isinstance(type_, ForwardRef)
 
 
-def de_optionalize_union_types(type_):
+@overload
+def de_optionalize_union_types(type_: str) -> str:
+    ...
+
+
+@overload
+def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]:
+    ...
+
+
+def de_optionalize_union_types(
+    type_: _AnnotationScanType,
+) -> _AnnotationScanType:
     """Given a type, filter out ``Union`` types that include ``NoneType``
     to not include the ``NoneType``.
 
     """
     if is_optional(type_):
-        typ = set(type_.__args__)
+        typ = set(type_.__args__)  # type: ignore
 
         typ.discard(NoneType)
 
@@ -148,14 +164,14 @@ def de_optionalize_union_types(type_):
         return type_
 
 
-def make_union_type(*types):
+def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
     """Make a Union type.
 
     This is needed by :func:`.de_optionalize_union_types` which removes
     ``NoneType`` from a ``Union``.
 
     """
-    return cast(Any, Union).__getitem__(types)
+    return cast(Any, Union).__getitem__(types)  # type: ignore
 
 
 def expand_unions(
@@ -251,4 +267,47 @@ class DescriptorReference(Generic[_DESC]):
         ...
 
 
+_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
+
+
+class RODescriptorReference(Generic[_DESC_co]):
+    """a descriptor that refers to a descriptor.
+
+    same as :class:`.DescriptorReference` but is read-only, so that subclasses
+    can define a subtype as the generically contained element
+
+    """
+
+    def __get__(self, instance: object, owner: Any) -> _DESC_co:
+        ...
+
+    def __set__(self, instance: Any, value: Any) -> NoReturn:
+        ...
+
+    def __delete__(self, instance: Any) -> NoReturn:
+        ...
+
+
+_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
+
+
+class CallableReference(Generic[_FN]):
+    """a descriptor that refers to a callable.
+
+    works around mypy's limitation of not allowing callables assigned
+    as instance variables
+
+
+    """
+
+    def __get__(self, instance: object, owner: Any) -> _FN:
+        ...
+
+    def __set__(self, instance: Any, value: _FN) -> None:
+        ...
+
+    def __delete__(self, instance: Any) -> None:
+        ...
+
+
 # $def ro_descriptor_reference(fn: Callable[])
index f8498fde94e7f8d072179a782908da31582dcd21..29d59ea698bd38560e50006cfc5d598a04d6842a 100644 (file)
@@ -68,21 +68,6 @@ strict = true
 # pass
 module = [
 
-    # TODO for ORM, non-strict
-    "sqlalchemy.orm.base",
-    "sqlalchemy.orm.decl_base",
-    "sqlalchemy.orm.descriptor_props",
-    "sqlalchemy.orm.identity",
-    "sqlalchemy.orm.mapped_collection",
-    "sqlalchemy.orm.properties",
-    "sqlalchemy.orm.relationships",
-    "sqlalchemy.orm.strategy_options",
-    "sqlalchemy.orm.state_changes",
-
-    # would ideally be strict
-    "sqlalchemy.orm.decl_api",
-    "sqlalchemy.orm.events",
-    "sqlalchemy.orm.query",
     "sqlalchemy.engine.reflection",
 
 ]
diff --git a/test/ext/mypy/plain_files/composite.py b/test/ext/mypy/plain_files/composite.py
new file mode 100644 (file)
index 0000000..c699633
--- /dev/null
@@ -0,0 +1,67 @@
+from typing import Any
+from typing import Tuple
+
+from sqlalchemy import select
+from sqlalchemy.orm import composite
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Point:
+    def __init__(self, x: int, y: int):
+        self.x = x
+        self.y = y
+
+    def __composite_values__(self) -> Tuple[int, int]:
+        return self.x, self.y
+
+    def __repr__(self) -> str:
+        return "Point(x=%r, y=%r)" % (self.x, self.y)
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, Point)
+            and other.x == self.x
+            and other.y == self.y
+        )
+
+    def __ne__(self, other: Any) -> bool:
+        return not self.__eq__(other)
+
+
+class Vertex(Base):
+    __tablename__ = "vertices"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    x1: Mapped[int]
+    y1: Mapped[int]
+    x2: Mapped[int]
+    y2: Mapped[int]
+
+    # inferred from right hand side
+    start = composite(Point, "x1", "y1")
+
+    # taken from left hand side
+    end: Mapped[Point] = composite(Point, "x2", "y2")
+
+
+v1 = Vertex(start=Point(3, 4), end=Point(5, 6))
+
+stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)]))
+
+# EXPECTED_TYPE: Select[Tuple[Vertex]]
+reveal_type(stmt)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.start)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.end)
+
+# EXPECTED_TYPE: int
+reveal_type(v1.end.y)
index 1086f187af451ad8843c5b3d6287879967e2ae97..37f99502dbabe9e0c594bf0124164da87bdff5da 100644 (file)
@@ -258,7 +258,7 @@ class MypyPluginTest(fixtures.TestBase):
                             )
 
                             expected_msg = re.sub(
-                                r"(int|str|float|bool)",
+                                r"\b(int|str|float|bool)\b",
                                 lambda m: rf"builtins.{m.group(0)}\*?",
                                 expected_msg,
                             )
index 9f9f8e601e23ccf081d78b384b0bcc261bbee2b0..4990056c3e708eb1653ee6f30d40700c8c2a2ac4 100644 (file)
@@ -133,7 +133,9 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
 
         reg = registry(metadata=metadata)
 
-        reg.map_declaratively(User)
+        mp = reg.map_declaratively(User)
+        assert mp is inspect(User)
+        assert mp is User.__mapper__
 
     def test_undefer_column_name(self):
         # TODO: not sure if there was an explicit
@@ -186,6 +188,53 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
             class_mapper(User).get_property("props").secondary is user_to_prop
         )
 
+    def test_string_dependency_resolution_schemas_no_base(self):
+        """
+
+        found_during_type_annotation
+
+        """
+
+        reg = registry()
+
+        @reg.mapped
+        class User:
+
+            __tablename__ = "users"
+            __table_args__ = {"schema": "fooschema"}
+
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            props = relationship(
+                "Prop",
+                secondary="fooschema.user_to_prop",
+                primaryjoin="User.id==fooschema.user_to_prop.c.user_id",
+                secondaryjoin="fooschema.user_to_prop.c.prop_id==Prop.id",
+                backref="users",
+            )
+
+        @reg.mapped
+        class Prop:
+
+            __tablename__ = "props"
+            __table_args__ = {"schema": "fooschema"}
+
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+
+        user_to_prop = Table(
+            "user_to_prop",
+            reg.metadata,
+            Column("user_id", Integer, ForeignKey("fooschema.users.id")),
+            Column("prop_id", Integer, ForeignKey("fooschema.props.id")),
+            schema="fooschema",
+        )
+        configure_mappers()
+
+        assert (
+            class_mapper(User).get_property("props").secondary is user_to_prop
+        )
+
     def test_string_dependency_resolution_annotations(self):
         Base = declarative_base()
 
@@ -290,6 +339,51 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
         reg = registry(metadata=metadata)
         reg.mapped(User)
         reg.mapped(Address)
+
+        reg.metadata.create_all(testing.db)
+        u1 = User(
+            name="u1", addresses=[Address(email="one"), Address(email="two")]
+        )
+        with Session(testing.db) as sess:
+            sess.add(u1)
+            sess.commit()
+        with Session(testing.db) as sess:
+            eq_(
+                sess.query(User).all(),
+                [
+                    User(
+                        name="u1",
+                        addresses=[Address(email="one"), Address(email="two")],
+                    )
+                ],
+            )
+
+    def test_map_declaratively(self, metadata):
+        class User(fixtures.ComparableEntity):
+
+            __tablename__ = "users"
+            id = Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            )
+            name = Column("name", String(50))
+            addresses = relationship("Address", backref="user")
+
+        class Address(fixtures.ComparableEntity):
+
+            __tablename__ = "addresses"
+            id = Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            )
+            email = Column("email", String(50))
+            user_id = Column("user_id", Integer, ForeignKey("users.id"))
+
+        reg = registry(metadata=metadata)
+        um = reg.map_declaratively(User)
+        am = reg.map_declaratively(Address)
+
+        is_(User.__mapper__, um)
+        is_(Address.__mapper__, am)
+
         reg.metadata.create_all(testing.db)
         u1 = User(
             name="u1", addresses=[Address(email="one"), Address(email="two")]
index ca3a6e6089fcc9182232a3b1f9c1983e3645bfed..fb27c910ffc6cd8b55e79c6b201d656f487649ca 100644 (file)
@@ -1,6 +1,7 @@
 import sqlalchemy as sa
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import class_mapper
@@ -14,6 +15,7 @@ from sqlalchemy.orm.decl_api import registry
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
@@ -653,6 +655,25 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
             Engineer(name="vlad", primary_language="cobol"),
         )
 
+    def test_single_cols_on_sub_base_of_subquery(self):
+        """
+        found_during_type_annotation
+
+        """
+        t = Table("t", Base.metadata, Column("id", Integer, primary_key=True))
+
+        class Person(Base):
+            __table__ = select(t).subquery()
+
+        with expect_raises_message(
+            sa.exc.ArgumentError,
+            r"Can't declare columns on single-table-inherited subclass "
+            r".*Contractor.*; superclass .*Person.* is not mapped to a Table",
+        ):
+
+            class Contractor(Person):
+                contractor_field = Column(String)
+
     def test_single_cols_on_sub_base_of_joined(self):
         """test [ticket:3895]"""
 
index 2c9cc3b21bc2b63f7c1c8c9a953a7c256281593d..e72c181100f5ac5c17625ae8283b4529220a0829 100644 (file)
@@ -965,6 +965,45 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             '"user".state, "user".zip FROM "user"',
         )
 
+    def test_name_cols_by_str(self, decl_base):
+        @dataclasses.dataclass
+        class Address:
+            street: str
+            state: str
+            zip_: str
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str]
+            street: Mapped[str]
+            state: Mapped[str]
+
+            # TODO: this needs to be improved, we should be able to say:
+            # zip_: Mapped[str] = mapped_column("zip")
+            # and it should assign to "zip_" for the attribute. not working
+
+            zip_: Mapped[str] = mapped_column(name="zip", key="zip_")
+
+            address: Mapped["Address"] = composite(
+                Address, "street", "state", "zip_"
+            )
+
+        eq_(
+            User.__mapper__.attrs["address"].props,
+            [
+                User.__mapper__.attrs["street"],
+                User.__mapper__.attrs["state"],
+                User.__mapper__.attrs["zip_"],
+            ],
+        )
+        self.assert_compile(
+            select(User),
+            'SELECT "user".id, "user".name, "user".street, '
+            '"user".state, "user".zip FROM "user"',
+        )
+
     def test_cls_annotated_setup(self, decl_base):
         @dataclasses.dataclass
         class Address:
index ab6d79383c8b0a348ea5c5bac7d300319c08bfdc..78b503873f8599cd4ebf4f2e02cd8de423b26b14 100644 (file)
@@ -24,6 +24,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
@@ -35,7 +36,9 @@ from sqlalchemy.testing.schema import Table
 from test.orm.test_events import _RemoveListeners
 
 
-class ConcreteTest(fixtures.MappedTest):
+class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest):
+    __dialect__ = "default"
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -265,6 +268,10 @@ class ConcreteTest(fixtures.MappedTest):
             "sometype",
         )
 
+        # found_during_type_annotation
+        # test the comparator returned by ConcreteInheritedProperty
+        self.assert_compile(Manager.type == "x", "pjoin.type = :type_1")
+
         jenn = Engineer("Jenn", "knows how to program")
         hacker = Hacker("Karina", "Badass", "knows how to hack")
 
index e3732bef50b85166126ee7107a61927cd97d942c..d9013b2c4a186767ead1a925e6148f4fa7834bc4 100644 (file)
@@ -2259,6 +2259,22 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
         )
         assert a1.c.users_id is not None
 
+    def test_no_subquery_for_from_statement(self):
+        """
+        found_during_typing
+
+        """
+        User = self.classes.User
+
+        session = fixture_session()
+        q = session.query(User.id).from_statement(text("select * from user"))
+
+        with expect_raises_message(
+            sa.exc.InvalidRequestError,
+            r"Can't call this method on a Query that uses from_statement\(\)",
+        ):
+            q.subquery()
+
     def test_reduced_subquery(self):
         User = self.classes.User
         ua = aliased(User)
@@ -6183,12 +6199,6 @@ class TextTest(QueryTest, AssertsCompiledSQL):
             "FROM users GROUP BY name",
         )
 
-    def test_orm_columns_accepts_text(self):
-        from sqlalchemy.orm.base import _orm_columns
-
-        t = text("x")
-        eq_(_orm_columns(t), [t])
-
     def test_order_by_w_eager_one(self):
         User = self.classes.User
         s = fixture_session()
index fd192ab51907f1e562cc5eb1ba7253909a99d78b..22bc549085ecd454ec0b6963e90ad715f8c9dbc6 100644 (file)
@@ -4324,7 +4324,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest):
         assert_raises_message(
             sa.exc.ArgumentError,
             "T1.t1s and back-reference T1.parent are "
-            r"both of the same direction symbol\('ONETOMANY'\).  Did you "
+            r"both of the same "
+            r"direction .*RelationshipDirection.ONETOMANY.*.  Did you "
             "mean to set remote_side on the many-to-one side ?",
             configure_mappers,
         )
@@ -4347,7 +4348,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest):
         assert_raises_message(
             sa.exc.ArgumentError,
             "T1.t1s and back-reference T1.parent are "
-            r"both of the same direction symbol\('MANYTOONE'\).  Did you "
+            r"both of the same direction .*RelationshipDirection.MANYTOONE.*."
+            "Did you "
             "mean to set remote_side on the many-to-one side ?",
             configure_mappers,
         )
@@ -4367,7 +4369,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest):
         # can't be sure of ordering here
         assert_raises_message(
             sa.exc.ArgumentError,
-            r"both of the same direction symbol\('ONETOMANY'\).  Did you "
+            r"both of the same direction "
+            r".*RelationshipDirection.ONETOMANY.*.  Did you "
             "mean to set remote_side on the many-to-one side ?",
             configure_mappers,
         )
@@ -4391,7 +4394,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest):
         # can't be sure of ordering here
         assert_raises_message(
             sa.exc.ArgumentError,
-            r"both of the same direction symbol\('MANYTOONE'\).  Did you "
+            r"both of the same direction "
+            r".*RelationshipDirection.MANYTOONE.*.  Did you "
             "mean to set remote_side on the many-to-one side ?",
             configure_mappers,
         )