]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep-484: ORM public API, constructors
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Apr 2022 15:05:36 +0000 (11:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Apr 2022 19:14:09 +0000 (15:14 -0400)
for the moment, abandoning using @overload with
relationship() and mapped_column().  The overloads
are very difficult to get working at all, and
the overloads that were there all wouldn't pass on
mypy.  various techniques of getting them to
"work", meaning having right hand side dictate
what's legal on the left, have mixed success
and wont give consistent results; additionally,
it's legal to have Optional / non-optional
independent of nullable in any case for columns.
relationship cases are less ambiguous but mypy
was not going along with things.

we have a comprehensive system of allowing
left side annotations to drive the right side,
in the absense of explicit settings on the right.
so type-centric SQLAlchemy will be left-side
driven just like dataclasses, and the various flags
and switches on the right side will just not be
needed very much.

in other matters, one surprise, forgot to remove string support
from orm.join(A, B, "somename") or do deprecations
for it in 1.4.   This is a really not-directly-used
structure barely
mentioned in the docs for many years, the example
shows a relationship being used, not a string, so
we will just change it to raise the usual error here.

Change-Id: Iefbbb8d34548b538023890ab8b7c9a5d9496ec6e

68 files changed:
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/util.py
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/ext/instrumentation.py
lib/sqlalchemy/inspection.py
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/_typing.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/exc.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/lambdas.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/_py_collections.py
lib/sqlalchemy/util/deprecations.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/preloaded.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/ext/mypy/plain_files/association_proxy_one.py
test/ext/mypy/plain_files/experimental_relationship.py
test/ext/mypy/plain_files/hybrid_one.py
test/ext/mypy/plain_files/hybrid_two.py
test/ext/mypy/plain_files/mapped_column.py
test/ext/mypy/plain_files/sql_operations.py
test/ext/mypy/plain_files/trad_relationship_uselist.py
test/ext/mypy/plain_files/traditional_relationship.py
test/ext/mypy/plugin_files/relationship_6255_one.py
test/ext/mypy/plugin_files/typing_err3.py
test/ext/test_extendedattr.py
test/orm/inheritance/test_basic.py
test/orm/test_cascade.py
test/orm/test_instrumentation.py
test/orm/test_joins.py
test/orm/test_mapper.py
test/orm/test_options.py
test/orm/test_query.py
test/orm/test_utils.py
test/sql/test_selectable.py

index 72102ac264fe77e3acffcfb8acb229a4a1b493a5..ccf5736756d7c289e065852386fd5cf701d9d60d 100644 (file)
@@ -1638,7 +1638,6 @@ class CursorResult(Result):
             :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial`
 
         """  # noqa: E501
-
         try:
             return self.context.rowcount
         except BaseException as e:
index 529b2ca73b8ea9431ceb9b003adbc19fbd338e0c..45f6bf20bf3fcd51c3da7beb11903a0e186f743f 100644 (file)
@@ -48,7 +48,7 @@ def connection_memoize(key: str) -> Callable[[_C], _C]:
             connection.info[key] = val = fn(self, connection)
             return val
 
-    return decorated  # type: ignore[return-value]
+    return decorated  # type: ignore
 
 
 class _TConsSubject(Protocol):
index 33cf3f745a445e01812406d669402190ae2b0696..c7a6e2ca0157866c5afcd2c3789695acd9139a2e 100644 (file)
@@ -76,7 +76,6 @@ if TYPE_CHECKING:
         "expunge",
         "expunge_all",
         "flush",
-        "get",
         "get_bind",
         "is_modified",
         "invalidate",
@@ -204,6 +203,49 @@ class async_scoped_session:
             await self.registry().close()
         self.registry.clear()
 
+    async def get(
+        self,
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: Optional[ForUpdateArg] = None,
+        identity_token: Optional[Any] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> Optional[_O]:
+        r"""Return an instance based on the given primary key identifier,
+        or ``None`` if not found.
+
+        .. container:: class_bases
+
+            Proxied for the :class:`_asyncio.AsyncSession` class on
+            behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
+
+        .. seealso::
+
+            :meth:`_orm.Session.get` - main documentation for get
+
+
+
+        """  # noqa: E501
+
+        # this was proxied but Mypy is requiring the return type to be
+        # clarified
+
+        # work around:
+        # https://github.com/python/typing/discussions/1143
+        return_value = await self._proxied.get(
+            entity,
+            ident,
+            options=options,
+            populate_existing=populate_existing,
+            with_for_update=with_for_update,
+            identity_token=identity_token,
+            execution_options=execution_options,
+        )
+        return return_value
+
     # START PROXY METHODS async_scoped_session
 
     # code within this block is **programmatically,
@@ -632,43 +674,6 @@ class async_scoped_session:
 
         return await self._proxied.flush(objects=objects)
 
-    async def get(
-        self,
-        entity: _EntityBindKey[_O],
-        ident: _PKIdentityArgument,
-        *,
-        options: Optional[Sequence[ORMOption]] = None,
-        populate_existing: bool = False,
-        with_for_update: Optional[ForUpdateArg] = None,
-        identity_token: Optional[Any] = None,
-        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
-    ) -> Optional[_O]:
-        r"""Return an instance based on the given primary key identifier,
-        or ``None`` if not found.
-
-        .. container:: class_bases
-
-            Proxied for the :class:`_asyncio.AsyncSession` class on
-            behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
-
-        .. seealso::
-
-            :meth:`_orm.Session.get` - main documentation for get
-
-
-
-        """  # noqa: E501
-
-        return await self._proxied.get(
-            entity,
-            ident,
-            options=options,
-            populate_existing=populate_existing,
-            with_for_update=with_for_update,
-            identity_token=identity_token,
-            execution_options=execution_options,
-        )
-
     def get_bind(
         self,
         mapper: Optional[_EntityBindKey[_O]] = None,
index be872804e22132fd817a6e3a9a42f74af476cdc9..7200414a183116fd001e1b6edb1e40e8ca507b7c 100644 (file)
@@ -832,6 +832,8 @@ from ..util.typing import Protocol
 
 
 if TYPE_CHECKING:
+    from ..orm._typing import _ORMColumnExprArgument
+    from ..orm.interfaces import MapperProperty
     from ..orm.util import AliasedInsp
     from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _DMLColumnArgument
@@ -840,7 +842,6 @@ if TYPE_CHECKING:
     from ..sql.operators import OperatorType
     from ..sql.roles import ColumnsClauseRole
 
-
 _T = TypeVar("_T", bound=Any)
 _T_co = TypeVar("_T_co", bound=Any, covariant=True)
 _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
@@ -1289,7 +1290,7 @@ class Comparator(interfaces.PropComparator[_T]):
     ):
         self.expression = expression
 
-    def __clause_element__(self) -> ColumnsClauseRole:
+    def __clause_element__(self) -> _ORMColumnExprArgument[_T]:
         expr = self.expression
         if is_has_clause_element(expr):
             ret_expr = expr.__clause_element__()
@@ -1298,10 +1299,15 @@ class Comparator(interfaces.PropComparator[_T]):
                 assert isinstance(expr, ColumnElement)
             ret_expr = expr
 
+        if TYPE_CHECKING:
+            # see test_hybrid->test_expression_isnt_clause_element
+            # that exercises the usual place this is caught if not
+            # true
+            assert isinstance(ret_expr, ColumnElement)
         return ret_expr
 
-    @util.non_memoized_property
-    def property(self) -> Any:
+    @util.ro_non_memoized_property
+    def property(self) -> Optional[interfaces.MapperProperty[_T]]:
         return None
 
     def adapt_to_entity(
@@ -1325,7 +1331,7 @@ class ExprComparator(Comparator[_T]):
     def __getattr__(self, key: str) -> Any:
         return getattr(self.expression, key)
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def info(self) -> _InfoType:
         return self.hybrid.info
 
@@ -1339,8 +1345,8 @@ class ExprComparator(Comparator[_T]):
         else:
             return [(self.expression, value)]
 
-    @util.non_memoized_property
-    def property(self) -> Any:
+    @util.ro_non_memoized_property
+    def property(self) -> Optional[MapperProperty[_T]]:
         return self.expression.property  # type: ignore
 
     def operate(
index 72448fbdc8da3950f335b2fd820c45f1771b185b..b1138a4ad8ec8930b454dea63bc12f42a388c9f4 100644 (file)
@@ -25,6 +25,7 @@ from ..orm import exc as orm_exc
 from ..orm import instrumentation as orm_instrumentation
 from ..orm.instrumentation import _default_dict_getter
 from ..orm.instrumentation import _default_manager_getter
+from ..orm.instrumentation import _default_opt_manager_getter
 from ..orm.instrumentation import _default_state_getter
 from ..orm.instrumentation import ClassManager
 from ..orm.instrumentation import InstrumentationFactory
@@ -140,7 +141,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
         hierarchy = util.class_hierarchy(cls)
         factories = set()
         for member in hierarchy:
-            manager = self.manager_of_class(member)
+            manager = self.opt_manager_of_class(member)
             if manager is not None:
                 factories.add(manager.factory)
             else:
@@ -161,17 +162,34 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
             del self._state_finders[class_]
             del self._dict_finders[class_]
 
-    def manager_of_class(self, cls):
-        if cls is None:
-            return None
+    def opt_manager_of_class(self, cls):
         try:
-            finder = self._manager_finders.get(cls, _default_manager_getter)
+            finder = self._manager_finders.get(
+                cls, _default_opt_manager_getter
+            )
         except TypeError:
             # due to weakref lookup on invalid object
             return None
         else:
             return finder(cls)
 
+    def manager_of_class(self, cls):
+        try:
+            finder = self._manager_finders.get(cls, _default_manager_getter)
+        except TypeError:
+            # due to weakref lookup on invalid object
+            raise orm_exc.UnmappedClassError(
+                cls, f"Can't locate an instrumentation manager for class {cls}"
+            )
+        else:
+            manager = finder(cls)
+            if manager is None:
+                raise orm_exc.UnmappedClassError(
+                    cls,
+                    f"Can't locate an instrumentation manager for class {cls}",
+                )
+            return manager
+
     def state_of(self, instance):
         if instance is None:
             raise AttributeError("None has no persistent state.")
@@ -384,6 +402,7 @@ def _install_instrumented_lookups():
             instance_state=_instrumentation_factory.state_of,
             instance_dict=_instrumentation_factory.dict_of,
             manager_of_class=_instrumentation_factory.manager_of_class,
+            opt_manager_of_class=_instrumentation_factory.opt_manager_of_class,
         )
     )
 
@@ -395,16 +414,19 @@ def _reinstall_default_lookups():
             instance_state=_default_state_getter,
             instance_dict=_default_dict_getter,
             manager_of_class=_default_manager_getter,
+            opt_manager_of_class=_default_opt_manager_getter,
         )
     )
     _instrumentation_factory._extended = False
 
 
 def _install_lookups(lookups):
-    global instance_state, instance_dict, manager_of_class
+    global instance_state, instance_dict
+    global manager_of_class, opt_manager_of_class
     instance_state = lookups["instance_state"]
     instance_dict = lookups["instance_dict"]
     manager_of_class = lookups["manager_of_class"]
+    opt_manager_of_class = lookups["opt_manager_of_class"]
     orm_base.instance_state = (
         attributes.instance_state
     ) = orm_instrumentation.instance_state = instance_state
@@ -414,3 +436,6 @@ def _install_lookups(lookups):
     orm_base.manager_of_class = (
         attributes.manager_of_class
     ) = orm_instrumentation.manager_of_class = manager_of_class
+    orm_base.opt_manager_of_class = (
+        attributes.opt_manager_of_class
+    ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class
index 6b06c0d6b64b2b153efe973cdeb599e11c3a9891..01c740fe414cfa8d1052581888a41a5f4948c7da 100644 (file)
@@ -34,6 +34,7 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Generic
+from typing import Optional
 from typing import overload
 from typing import Type
 from typing import TypeVar
@@ -43,6 +44,9 @@ from . import exc
 from .util.typing import Literal
 
 _T = TypeVar("_T", bound=Any)
+_F = TypeVar("_F", bound=Callable[..., Any])
+
+_IN = TypeVar("_IN", bound="Inspectable[Any]")
 
 _registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {}
 
@@ -53,11 +57,22 @@ class Inspectable(Generic[_T]):
     This allows typing to set up a linkage between an object that
     can be inspected and the type of inspection it returns.
 
+    Unfortunately we cannot at the moment get all classes that are
+    returned by inspection to suit this interface as we get into
+    MRO issues.
+
     """
 
+    __slots__ = ()
+
 
 @overload
-def inspect(subject: Inspectable[_T], raiseerr: bool = True) -> _T:
+def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN:
+    ...
+
+
+@overload
+def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]:
     ...
 
 
@@ -108,9 +123,9 @@ def inspect(subject: Any, raiseerr: bool = True) -> Any:
 
 
 def _inspects(
-    *types: type,
-) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
-    def decorate(fn_or_cls: Callable[[Any], Any]) -> Callable[[Any], Any]:
+    *types: Type[Any],
+) -> Callable[[_F], _F]:
+    def decorate(fn_or_cls: _F) -> _F:
         for type_ in types:
             if type_ in _registrars:
                 raise AssertionError(
@@ -122,7 +137,10 @@ def _inspects(
     return decorate
 
 
-def _self_inspects(cls: Type[_T]) -> Type[_T]:
+_TT = TypeVar("_TT", bound="Type[Any]")
+
+
+def _self_inspects(cls: _TT) -> _TT:
     if cls in _registrars:
         raise AssertionError("Type %s is already " "registered" % cls)
     _registrars[cls] = True
index 7690c05dec80593f4ef0b72931b6975938a9e3e6..457ad5c5a6b1354cc1136d8e974980b3c98ecdcd 100644 (file)
@@ -9,20 +9,19 @@ from __future__ import annotations
 
 import typing
 from typing import Any
+from typing import Callable
 from typing import Collection
-from typing import Dict
-from typing import List
 from typing import Optional
 from typing import overload
-from typing import Set
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import Union
 
-from . import mapper as mapperlib
+from . import mapperlib as mapperlib
+from ._typing import _O
 from .base import Mapped
 from .descriptor_props import Composite
 from .descriptor_props import Synonym
-from .mapper import Mapper
 from .properties import ColumnProperty
 from .properties import MappedColumn
 from .query import AliasOption
@@ -37,11 +36,29 @@ from .. import sql
 from .. import util
 from ..exc import InvalidRequestError
 from ..sql.base import SchemaEventTarget
-from ..sql.selectable import Alias
+from ..sql.schema import SchemaConst
 from ..sql.selectable import FromClause
-from ..sql.type_api import TypeEngine
 from ..util.typing import Literal
 
+if TYPE_CHECKING:
+    from ._typing import _EntityType
+    from ._typing import _ORMColumnExprArgument
+    from .descriptor_props import _CompositeAttrType
+    from .interfaces import PropComparator
+    from .query import Query
+    from .relationships import _LazyLoadArgumentType
+    from .relationships import _ORMBackrefArgument
+    from .relationships import _ORMColCollectionArgument
+    from .relationships import _ORMOrderByArgument
+    from .relationships import _RelationshipJoinConditionArgument
+    from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _InfoType
+    from ..sql._typing import _TypeEngineArgument
+    from ..sql.schema import _ServerDefaultType
+    from ..sql.schema import FetchedValue
+    from ..sql.selectable import Alias
+    from ..sql.selectable import Subquery
+
 _T = typing.TypeVar("_T")
 
 
@@ -61,7 +78,7 @@ SynonymProperty = Synonym
     "for entities to be matched up to a query that is established "
     "via :meth:`.Query.from_statement` and now does nothing.",
 )
-def contains_alias(alias) -> AliasOption:
+def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption:
     r"""Return a :class:`.MapperOption` that will indicate to the
     :class:`_query.Query`
     that the main table has been aliased.
@@ -70,134 +87,36 @@ def contains_alias(alias) -> AliasOption:
     return AliasOption(alias)
 
 
-# see test/ext/mypy/plain_files/mapped_column.py for mapped column
-# typing tests
-
-
-@overload
-def mapped_column(
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: Literal[None] = ...,
-    primary_key: Literal[None] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[Any]":
-    ...
-
-
-@overload
-def mapped_column(
-    __name: str,
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: Literal[None] = ...,
-    primary_key: Literal[None] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[Any]":
-    ...
-
-
-@overload
-def mapped_column(
-    __name: str,
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: Literal[True] = ...,
-    primary_key: Literal[None] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[Optional[_T]]":
-    ...
-
-
-@overload
-def mapped_column(
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: Literal[True] = ...,
-    primary_key: Literal[None] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[Optional[_T]]":
-    ...
-
-
-@overload
-def mapped_column(
-    __name: str,
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: Literal[False] = ...,
-    primary_key: Literal[None] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[_T]":
-    ...
-
-
-@overload
-def mapped_column(
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: Literal[False] = ...,
-    primary_key: Literal[None] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[_T]":
-    ...
-
-
-@overload
-def mapped_column(
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: bool = ...,
-    primary_key: Literal[True] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[_T]":
-    ...
-
-
-@overload
-def mapped_column(
-    __name: str,
-    __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
-    *args: SchemaEventTarget,
-    nullable: bool = ...,
-    primary_key: Literal[True] = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[_T]":
-    ...
-
-
-@overload
-def mapped_column(
-    __name: str,
-    *args: SchemaEventTarget,
-    nullable: bool = ...,
-    primary_key: bool = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[Any]":
-    ...
-
-
-@overload
 def mapped_column(
+    __name_pos: Optional[
+        Union[str, _TypeEngineArgument[Any], SchemaEventTarget]
+    ] = None,
+    __type_pos: Optional[
+        Union[_TypeEngineArgument[Any], SchemaEventTarget]
+    ] = None,
     *args: SchemaEventTarget,
-    nullable: bool = ...,
-    primary_key: bool = ...,
-    deferred: bool = ...,
-    **kw: Any,
-) -> "MappedColumn[Any]":
-    ...
-
-
-def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]":
+    nullable: Optional[
+        Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
+    ] = SchemaConst.NULL_UNSPECIFIED,
+    primary_key: Optional[bool] = False,
+    deferred: bool = False,
+    name: Optional[str] = None,
+    type_: Optional[_TypeEngineArgument[Any]] = None,
+    autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
+    default: Optional[Any] = None,
+    doc: Optional[str] = None,
+    key: Optional[str] = None,
+    index: Optional[bool] = None,
+    unique: Optional[bool] = None,
+    info: Optional[_InfoType] = None,
+    onupdate: Optional[Any] = None,
+    server_default: Optional[_ServerDefaultType] = None,
+    server_onupdate: Optional[FetchedValue] = None,
+    quote: Optional[bool] = None,
+    system: bool = False,
+    comment: Optional[str] = None,
+    **dialect_kwargs: Any,
+) -> MappedColumn[Any]:
     r"""construct a new ORM-mapped :class:`_schema.Column` construct.
 
     The :func:`_orm.mapped_column` function provides an ORM-aware and
@@ -363,12 +282,45 @@ def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]":
 
     """
 
-    return MappedColumn(*args, **kw)
+    return MappedColumn(
+        __name_pos,
+        __type_pos,
+        *args,
+        name=name,
+        type_=type_,
+        autoincrement=autoincrement,
+        default=default,
+        doc=doc,
+        key=key,
+        index=index,
+        unique=unique,
+        info=info,
+        nullable=nullable,
+        onupdate=onupdate,
+        primary_key=primary_key,
+        server_default=server_default,
+        server_onupdate=server_onupdate,
+        quote=quote,
+        comment=comment,
+        system=system,
+        deferred=deferred,
+        **dialect_kwargs,
+    )
 
 
 def column_property(
-    column: sql.ColumnElement[_T], *additional_columns, **kwargs
-) -> "ColumnProperty[_T]":
+    column: _ORMColumnExprArgument[_T],
+    *additional_columns: _ORMColumnExprArgument[Any],
+    group: Optional[str] = None,
+    deferred: bool = False,
+    raiseload: bool = False,
+    comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+    descriptor: Optional[Any] = None,
+    active_history: bool = False,
+    expire_on_flush: bool = True,
+    info: Optional[_InfoType] = None,
+    doc: Optional[str] = None,
+) -> ColumnProperty[_T]:
     r"""Provide a column-level property for use with a mapping.
 
     Column-based properties can normally be applied to the mapper's
@@ -452,13 +404,25 @@ def column_property(
         expressions
 
     """
-    return ColumnProperty(column, *additional_columns, **kwargs)
+    return ColumnProperty(
+        column,
+        *additional_columns,
+        group=group,
+        deferred=deferred,
+        raiseload=raiseload,
+        comparator_factory=comparator_factory,
+        descriptor=descriptor,
+        active_history=active_history,
+        expire_on_flush=expire_on_flush,
+        info=info,
+        doc=doc,
+    )
 
 
 @overload
 def composite(
     class_: Type[_T],
-    *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+    *attrs: _CompositeAttrType[Any],
     **kwargs: Any,
 ) -> Composite[_T]:
     ...
@@ -466,7 +430,7 @@ def composite(
 
 @overload
 def composite(
-    *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+    *attrs: _CompositeAttrType[Any],
     **kwargs: Any,
 ) -> Composite[Any]:
     ...
@@ -474,7 +438,7 @@ def composite(
 
 def composite(
     class_: Any = None,
-    *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+    *attrs: _CompositeAttrType[Any],
     **kwargs: Any,
 ) -> Composite[Any]:
     r"""Return a composite column-based property for use with a Mapper.
@@ -529,13 +493,13 @@ def composite(
 
 
 def with_loader_criteria(
-    entity_or_base,
-    where_criteria,
-    loader_only=False,
-    include_aliases=False,
-    propagate_to_loaders=True,
-    track_closure_variables=True,
-) -> "LoaderCriteriaOption":
+    entity_or_base: _EntityType[Any],
+    where_criteria: _ColumnExpressionArgument[bool],
+    loader_only: bool = False,
+    include_aliases: bool = False,
+    propagate_to_loaders: bool = True,
+    track_closure_variables: bool = True,
+) -> LoaderCriteriaOption:
     """Add additional WHERE criteria to the load for all occurrences of
     a particular entity.
 
@@ -711,180 +675,40 @@ def with_loader_criteria(
     )
 
 
-@overload
-def relationship(
-    argument: str,
-    secondary=...,
-    *,
-    uselist: bool = ...,
-    collection_class: Literal[None] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[Any]:
-    ...
-
-
-@overload
-def relationship(
-    argument: str,
-    secondary=...,
-    *,
-    uselist: bool = ...,
-    collection_class: Type[Set] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[Set[Any]]:
-    ...
-
-
-@overload
-def relationship(
-    argument: str,
-    secondary=...,
-    *,
-    uselist: bool = ...,
-    collection_class: Type[List] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[List[Any]]:
-    ...
-
-
-@overload
-def relationship(
-    argument: Optional[_RelationshipArgumentType[_T]],
-    secondary=...,
-    *,
-    uselist: Literal[False] = ...,
-    collection_class: Literal[None] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[_T]:
-    ...
-
-
-@overload
-def relationship(
-    argument: Optional[_RelationshipArgumentType[_T]],
-    secondary=...,
-    *,
-    uselist: Literal[True] = ...,
-    collection_class: Literal[None] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[List[_T]]:
-    ...
-
-
-@overload
-def relationship(
-    argument: Optional[_RelationshipArgumentType[_T]],
-    secondary=...,
-    *,
-    uselist: Union[Literal[None], Literal[True]] = ...,
-    collection_class: Type[List] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[List[_T]]:
-    ...
-
-
-@overload
-def relationship(
-    argument: Optional[_RelationshipArgumentType[_T]],
-    secondary=...,
-    *,
-    uselist: Union[Literal[None], Literal[True]] = ...,
-    collection_class: Type[Set] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[Set[_T]]:
-    ...
-
-
-@overload
-def relationship(
-    argument: Optional[_RelationshipArgumentType[_T]],
-    secondary=...,
-    *,
-    uselist: Union[Literal[None], Literal[True]] = ...,
-    collection_class: Type[Dict[Any, Any]] = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[Dict[Any, _T]]:
-    ...
-
-
-@overload
-def relationship(
-    argument: _RelationshipArgumentType[_T],
-    secondary=...,
-    *,
-    uselist: Literal[None] = ...,
-    collection_class: Literal[None] = ...,
-    primaryjoin=...,
-    secondaryjoin=None,
-    back_populates=None,
-    **kw: Any,
-) -> Relationship[Any]:
-    ...
-
-
-@overload
-def relationship(
-    argument: Optional[_RelationshipArgumentType[_T]] = ...,
-    secondary=...,
-    *,
-    uselist: Literal[True] = ...,
-    collection_class: Any = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[Any]:
-    ...
-
-
-@overload
 def relationship(
-    argument: Literal[None] = ...,
-    secondary=...,
-    *,
-    uselist: Optional[bool] = ...,
-    collection_class: Any = ...,
-    primaryjoin=...,
-    secondaryjoin=...,
-    back_populates=...,
-    **kw: Any,
-) -> Relationship[Any]:
-    ...
-
-
-def relationship(
-    argument: Optional[_RelationshipArgumentType[_T]] = None,
-    secondary=None,
+    argument: Optional[_RelationshipArgumentType[Any]] = None,
+    secondary: Optional[FromClause] = None,
     *,
     uselist: Optional[bool] = None,
-    collection_class: Optional[Type[Collection]] = None,
-    primaryjoin=None,
-    secondaryjoin=None,
-    back_populates=None,
+    collection_class: Optional[
+        Union[Type[Collection[Any]], Callable[[], Collection[Any]]]
+    ] = None,
+    primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+    secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+    back_populates: Optional[str] = None,
+    order_by: _ORMOrderByArgument = False,
+    backref: Optional[_ORMBackrefArgument] = None,
+    overlaps: Optional[str] = None,
+    post_update: bool = False,
+    cascade: str = "save-update, merge",
+    viewonly: bool = False,
+    lazy: _LazyLoadArgumentType = "select",
+    passive_deletes: bool = False,
+    passive_updates: bool = True,
+    active_history: bool = False,
+    enable_typechecks: bool = True,
+    foreign_keys: Optional[_ORMColCollectionArgument] = None,
+    remote_side: Optional[_ORMColCollectionArgument] = None,
+    join_depth: Optional[int] = None,
+    comparator_factory: Optional[Type[PropComparator[Any]]] = None,
+    single_parent: bool = False,
+    innerjoin: bool = False,
+    distinct_target_key: Optional[bool] = None,
+    load_on_pending: bool = False,
+    query_class: Optional[Type[Query[Any]]] = None,
+    info: Optional[_InfoType] = None,
+    omit_join: Literal[None, False] = None,
+    sync_backref: Optional[bool] = None,
     **kw: Any,
 ) -> Relationship[Any]:
     """Provide a relationship between two mapped classes.
@@ -1098,13 +922,6 @@ def relationship(
 
             :ref:`error_qzyx` - usage example
 
-    :param bake_queries=True:
-        Legacy parameter, not used.
-
-          .. versionchanged:: 1.4.23 the "lambda caching" system is no longer
-             used by loader strategies and the ``bake_queries`` parameter
-             has no effect.
-
     :param cascade:
       A comma-separated list of cascade rules which determines how
       Session operations should be "cascaded" from parent to child.
@@ -1701,18 +1518,42 @@ def relationship(
         primaryjoin=primaryjoin,
         secondaryjoin=secondaryjoin,
         back_populates=back_populates,
+        order_by=order_by,
+        backref=backref,
+        overlaps=overlaps,
+        post_update=post_update,
+        cascade=cascade,
+        viewonly=viewonly,
+        lazy=lazy,
+        passive_deletes=passive_deletes,
+        passive_updates=passive_updates,
+        active_history=active_history,
+        enable_typechecks=enable_typechecks,
+        foreign_keys=foreign_keys,
+        remote_side=remote_side,
+        join_depth=join_depth,
+        comparator_factory=comparator_factory,
+        single_parent=single_parent,
+        innerjoin=innerjoin,
+        distinct_target_key=distinct_target_key,
+        load_on_pending=load_on_pending,
+        query_class=query_class,
+        info=info,
+        omit_join=omit_join,
+        sync_backref=sync_backref,
         **kw,
     )
 
 
 def synonym(
-    name,
-    map_column=None,
-    descriptor=None,
-    comparator_factory=None,
-    doc=None,
-    info=None,
-) -> "Synonym[Any]":
+    name: str,
+    *,
+    map_column: Optional[bool] = None,
+    descriptor: Optional[Any] = None,
+    comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+    info: Optional[_InfoType] = None,
+    doc: Optional[str] = None,
+) -> Synonym[Any]:
     """Denote an attribute name as a synonym to a mapped property,
     in that the attribute will mirror the value and expression behavior
     of another attribute.
@@ -1951,8 +1792,8 @@ def deferred(*columns, **kw):
 
 
 def query_expression(
-    default_expr: sql.ColumnElement[_T] = sql.null(),
-) -> "Mapped[_T]":
+    default_expr: _ORMColumnExprArgument[_T] = sql.null(),
+) -> Mapped[_T]:
     """Indicate an attribute that populates from a query-time SQL expression.
 
     :param default_expr: Optional SQL expression object that will be used in
@@ -2010,33 +1851,33 @@ def clear_mappers():
 
 @overload
 def aliased(
-    element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"],
-    alias=None,
-    name=None,
-    flat=False,
-    adapt_on_names=False,
-) -> "AliasedClass[_T]":
+    element: _EntityType[_O],
+    alias: Optional[Union[Alias, Subquery]] = None,
+    name: Optional[str] = None,
+    flat: bool = False,
+    adapt_on_names: bool = False,
+) -> AliasedClass[_O]:
     ...
 
 
 @overload
 def aliased(
-    element: "FromClause",
-    alias=None,
-    name=None,
-    flat=False,
-    adapt_on_names=False,
-) -> "Alias":
+    element: FromClause,
+    alias: Optional[Union[Alias, Subquery]] = None,
+    name: Optional[str] = None,
+    flat: bool = False,
+    adapt_on_names: bool = False,
+) -> FromClause:
     ...
 
 
 def aliased(
-    element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"],
-    alias=None,
-    name=None,
-    flat=False,
-    adapt_on_names=False,
-) -> Union["AliasedClass[_T]", "Alias"]:
+    element: Union[_EntityType[_O], FromClause],
+    alias: Optional[Union[Alias, Subquery]] = None,
+    name: Optional[str] = None,
+    flat: bool = False,
+    adapt_on_names: bool = False,
+) -> Union[AliasedClass[_O], FromClause]:
     """Produce an alias of the given element, usually an :class:`.AliasedClass`
     instance.
 
@@ -2233,9 +2074,7 @@ def with_polymorphic(
     )
 
 
-def join(
-    left, right, onclause=None, isouter=False, full=False, join_to_left=None
-):
+def join(left, right, onclause=None, isouter=False, full=False):
     r"""Produce an inner join between left and right clauses.
 
     :func:`_orm.join` is an extension to the core join interface
@@ -2270,16 +2109,11 @@ def join(
     See :ref:`orm_queryguide_joins` for information on modern usage
     of ORM level joins.
 
-    .. deprecated:: 0.8
-
-        the ``join_to_left`` parameter is deprecated, and will be removed
-        in a future release.  The parameter has no effect.
-
     """
     return _ORMJoin(left, right, onclause, isouter, full)
 
 
-def outerjoin(left, right, onclause=None, full=False, join_to_left=None):
+def outerjoin(left, right, onclause=None, full=False):
     """Produce a left outer join between left and right clauses.
 
     This is the "outer join" version of the :func:`_orm.join` function,
index 4250cdbe1f6c532ba74648313af73f874237ec38..339844f14756da63f55c3529c4864e3dad7838d8 100644 (file)
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import operator
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -10,7 +11,9 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
-from sqlalchemy.orm.interfaces import UserDefinedOption
+from ..sql import roles
+from ..sql._typing import _HasClauseElement
+from ..sql.elements import ColumnElement
 from ..util.typing import Protocol
 from ..util.typing import TypeGuard
 
@@ -18,8 +21,12 @@ if TYPE_CHECKING:
     from .attributes import AttributeImpl
     from .attributes import CollectionAttributeImpl
     from .base import PassiveFlag
+    from .decl_api import registry as _registry_type
     from .descriptor_props import _CompositeClassProto
+    from .interfaces import MapperProperty
+    from .interfaces import UserDefinedOption
     from .mapper import Mapper
+    from .relationships import Relationship
     from .state import InstanceState
     from .util import AliasedClass
     from .util import AliasedInsp
@@ -27,21 +34,39 @@ if TYPE_CHECKING:
 
 _T = TypeVar("_T", bound=Any)
 
+
+# I would have preferred this were bound=object however it seems
+# to not travel in all situations when defined in that way.
 _O = TypeVar("_O", bound=Any)
 """The 'ORM mapped object' type.
-I would have preferred this were bound=object however it seems
-to not travel in all situations when defined in that way.
+
 """
 
+if TYPE_CHECKING:
+    _RegistryType = _registry_type
+
 _InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"]
 
-_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"]
+_EntityType = Union[
+    Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"
+]
 
 
 _InstanceDict = Dict[str, Any]
 
 _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
 
+_ORMColumnExprArgument = Union[
+    ColumnElement[_T],
+    _HasClauseElement,
+    roles.ExpressionElementRole[_T],
+]
+
+# somehow Protocol didn't want to work for this one
+_ORMAdapterProto = Callable[
+    [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T]
+]
+
 
 class _LoaderCallable(Protocol):
     def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any:
@@ -60,10 +85,28 @@ def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]:
 
 if TYPE_CHECKING:
 
+    def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]:
+        ...
+
+    def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]:
+        ...
+
+    def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]:
+        ...
+
+    def prop_is_relationship(
+        prop: MapperProperty[Any],
+    ) -> TypeGuard[Relationship[Any]]:
+        ...
+
     def is_collection_impl(
         impl: AttributeImpl,
     ) -> TypeGuard[CollectionAttributeImpl]:
         ...
 
 else:
+    insp_is_mapper_property = operator.attrgetter("is_property")
+    insp_is_mapper = operator.attrgetter("is_mapper")
+    insp_is_aliased_class = operator.attrgetter("is_aliased_class")
     is_collection_impl = operator.attrgetter("collection")
+    prop_is_relationship = operator.attrgetter("_is_relationship")
index 33ce96a192b1bf803b2e739ea8457ce66e7581df..41d944c57d3bdd0ee3d3c628a0846ed267124661 100644 (file)
@@ -44,7 +44,7 @@ from .base import instance_dict as instance_dict
 from .base import instance_state as instance_state
 from .base import instance_str
 from .base import LOAD_AGAINST_COMMITTED
-from .base import manager_of_class
+from .base import manager_of_class as manager_of_class
 from .base import Mapped as Mapped  # noqa
 from .base import NEVER_SET  # noqa
 from .base import NO_AUTOFLUSH
@@ -52,6 +52,7 @@ from .base import NO_CHANGE  # noqa
 from .base import NO_RAISE
 from .base import NO_VALUE
 from .base import NON_PERSISTENT_OK  # noqa
+from .base import opt_manager_of_class as opt_manager_of_class
 from .base import PASSIVE_CLASS_MISMATCH  # noqa
 from .base import PASSIVE_NO_FETCH
 from .base import PASSIVE_NO_FETCH_RELATED  # noqa
@@ -74,6 +75,7 @@ from ..sql import traversals
 from ..sql import visitors
 
 if TYPE_CHECKING:
+    from .interfaces import MapperProperty
     from .state import InstanceState
     from ..sql.dml import _DMLColumnElement
     from ..sql.elements import ColumnElement
@@ -146,7 +148,7 @@ class QueryableAttribute(
         self._of_type = of_type
         self._extra_criteria = extra_criteria
 
-        manager = manager_of_class(class_)
+        manager = opt_manager_of_class(class_)
         # manager is None in the case of AliasedClass
         if manager:
             # propagate existing event listeners from
@@ -370,7 +372,7 @@ class QueryableAttribute(
         return "%s.%s" % (self.class_.__name__, self.key)
 
     @util.memoized_property
-    def property(self):
+    def property(self) -> MapperProperty[_T]:
         """Return the :class:`.MapperProperty` associated with this
         :class:`.QueryableAttribute`.
 
index 3fa855a4bd38f03187265363a05c3c6bb8e42d9f..054d52d83bd8f5a5f59c1db6c968c81792da8568 100644 (file)
@@ -26,24 +26,25 @@ from typing import TypeVar
 from typing import Union
 
 from . import exc
+from ._typing import insp_is_mapper
 from .. import exc as sa_exc
 from .. import inspection
 from .. import util
 from ..sql.elements import SQLCoreOperations
 from ..util import FastIntFlag
 from ..util.langhelpers import TypingOnly
-from ..util.typing import Concatenate
 from ..util.typing import Literal
-from ..util.typing import ParamSpec
 from ..util.typing import Self
 
 if typing.TYPE_CHECKING:
     from ._typing import _InternalEntityType
     from .attributes import InstrumentedAttribute
+    from .instrumentation import ClassManager
     from .mapper import Mapper
     from .state import InstanceState
     from ..sql._typing import _InfoType
 
+
 _T = TypeVar("_T", bound=Any)
 
 _O = TypeVar("_O", bound=object)
@@ -246,21 +247,15 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
 _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
 
 
-_Fn = TypeVar("_Fn", bound=Callable)
-_Args = ParamSpec("_Args")
+_F = TypeVar("_F", bound=Callable)
 _Self = TypeVar("_Self")
 
 
 def _assertions(
     *assertions: Any,
-) -> Callable[
-    [Callable[Concatenate[_Self, _Fn, _Args], _Self]],
-    Callable[Concatenate[_Self, _Fn, _Args], _Self],
-]:
+) -> Callable[[_F], _F]:
     @util.decorator
-    def generate(
-        fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
-    ) -> _Self:
+    def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self:
         for assertion in assertions:
             assertion(self, fn.__name__)
         fn(self, *args, **kw)
@@ -269,13 +264,13 @@ def _assertions(
     return generate
 
 
-# these can be replaced by sqlalchemy.ext.instrumentation
-# if augmented class instrumentation is enabled.
-def manager_of_class(cls):
-    return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
+if TYPE_CHECKING:
 
+    def manager_of_class(cls: Type[Any]) -> ClassManager:
+        ...
 
-if TYPE_CHECKING:
+    def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]:
+        ...
 
     def instance_state(instance: _O) -> InstanceState[_O]:
         ...
@@ -284,6 +279,20 @@ if TYPE_CHECKING:
         ...
 
 else:
+    # these can be replaced by sqlalchemy.ext.instrumentation
+    # if augmented class instrumentation is enabled.
+
+    def manager_of_class(cls):
+        try:
+            return cls.__dict__[DEFAULT_MANAGER_ATTR]
+        except KeyError as ke:
+            raise exc.UnmappedClassError(
+                cls, f"Can't locate an instrumentation manager for class {cls}"
+            ) from ke
+
+    def opt_manager_of_class(cls):
+        return cls.__dict__.get(DEFAULT_MANAGER_ATTR)
+
     instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
 
     instance_dict = operator.attrgetter("__dict__")
@@ -458,11 +467,12 @@ else:
     _state_mapper = util.dottedgetter("manager.mapper")
 
 
-@inspection._inspects(type)
-def _inspect_mapped_class(class_, configure=False):
+def _inspect_mapped_class(
+    class_: Type[_O], configure: bool = False
+) -> Optional[Mapper[_O]]:
     try:
-        class_manager = manager_of_class(class_)
-        if not class_manager.is_mapped:
+        class_manager = opt_manager_of_class(class_)
+        if class_manager is None or not class_manager.is_mapped:
             return None
         mapper = class_manager.mapper
     except exc.NO_STATE:
@@ -473,7 +483,28 @@ def _inspect_mapped_class(class_, configure=False):
         return mapper
 
 
-def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]:
+@inspection._inspects(type)
+def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]:
+    try:
+        class_manager = opt_manager_of_class(class_)
+        if class_manager is None or not class_manager.is_mapped:
+            return None
+        mapper = class_manager.mapper
+    except exc.NO_STATE:
+        return None
+    else:
+        return mapper
+
+
+def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]:
+    insp = inspection.inspect(arg, raiseerr=False)
+    if insp_is_mapper(insp):
+        return insp
+
+    raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}")
+
+
+def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]:
     """Given a class, return the primary :class:`_orm.Mapper` associated
     with the key.
 
@@ -502,8 +533,8 @@ def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]:
 
 
 class InspectionAttr:
-    """A base class applied to all ORM objects that can be returned
-    by the :func:`_sa.inspect` function.
+    """A base class applied to all ORM objects and attributes that are
+    related to things that can be returned by the :func:`_sa.inspect` function.
 
     The attributes defined here allow the usage of simple boolean
     checks to test basic facts about the object returned.
index 419da65f7fc2829369a46ee47bb0f801c38d1604..4fee2d383d6558bffbaef58425950935948a7202 100644 (file)
@@ -63,11 +63,14 @@ from ..sql.visitors import InternalTraversal
 
 if TYPE_CHECKING:
     from ._typing import _InternalEntityType
+    from .mapper import Mapper
+    from .query import Query
     from ..sql.compiler import _CompilerStackEntry
     from ..sql.dml import _DMLTableElement
     from ..sql.elements import ColumnElement
     from ..sql.selectable import _LabelConventionCallable
     from ..sql.selectable import SelectBase
+    from ..sql.type_api import TypeEngine
 
 _path_registry = PathRegistry.root
 
@@ -211,6 +214,9 @@ class ORMCompileState(CompileState):
         _for_refresh_state = False
         _render_for_subquery = False
 
+    attributes: Dict[Any, Any]
+    global_attributes: Dict[Any, Any]
+
     statement: Union[Select, FromStatement]
     select_statement: Union[Select, FromStatement]
     _entities: List[_QueryEntity]
@@ -1930,7 +1936,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             assert right_mapper
 
             adapter = ORMAdapter(
-                right, equivalents=right_mapper._equivalent_columns
+                inspect(right), equivalents=right_mapper._equivalent_columns
             )
 
             # if an alias() on the right side was generated,
@@ -2075,14 +2081,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
 
 def _column_descriptions(
-    query_or_select_stmt, compile_state=None, legacy=False
+    query_or_select_stmt: Union[Query, Select, FromStatement],
+    compile_state: Optional[ORMSelectCompileState] = None,
+    legacy: bool = False,
 ) -> List[ORMColumnDescription]:
     if compile_state is None:
         compile_state = ORMSelectCompileState._create_entities_collection(
             query_or_select_stmt, legacy=legacy
         )
     ctx = compile_state
-    return [
+    d = [
         {
             "name": ent._label_name,
             "type": ent.type,
@@ -2093,17 +2101,10 @@ def _column_descriptions(
             else None,
         }
         for ent, insp_ent in [
-            (
-                _ent,
-                (
-                    inspect(_ent.entity_zero)
-                    if _ent.entity_zero is not None
-                    else None
-                ),
-            )
-            for _ent in ctx._entities
+            (_ent, _ent.entity_zero) for _ent in ctx._entities
         ]
     ]
+    return d
 
 
 def _legacy_filter_by_entity_zero(query_or_augmented_select):
@@ -2157,6 +2158,11 @@ class _QueryEntity:
     _null_column_type = False
     use_id_for_hash = False
 
+    _label_name: Optional[str]
+    type: Union[Type[Any], TypeEngine[Any]]
+    expr: Union[_InternalEntityType, ColumnElement[Any]]
+    entity_zero: Optional[_InternalEntityType]
+
     def setup_compile_state(self, compile_state: ORMCompileState) -> None:
         raise NotImplementedError()
 
@@ -2234,6 +2240,13 @@ class _MapperEntity(_QueryEntity):
         "_polymorphic_discriminator",
     )
 
+    expr: _InternalEntityType
+    mapper: Mapper[Any]
+    entity_zero: _InternalEntityType
+    is_aliased_class: bool
+    path: PathRegistry
+    _label_name: str
+
     def __init__(
         self, compile_state, entity, entities_collection, is_current_entities
     ):
@@ -2389,6 +2402,13 @@ class _BundleEntity(_QueryEntity):
         "supports_single_entity",
     )
 
+    _entities: List[_QueryEntity]
+    bundle: Bundle
+    type: Type[Any]
+    _label_name: str
+    supports_single_entity: bool
+    expr: Bundle
+
     def __init__(
         self,
         compile_state,
index 70507015bce5048cb2d6cf062a0e686e6f8b5d2e..0c990f80998de4c86d396027b534f54c9dc65d90 100644 (file)
@@ -50,7 +50,7 @@ from ..util import hybridproperty
 from ..util import typing as compat_typing
 
 if typing.TYPE_CHECKING:
-    from .state import InstanceState  # noqa
+    from .state import InstanceState
 
 _T = TypeVar("_T", bound=Any)
 
@@ -280,7 +280,7 @@ class declared_attr(interfaces._MappedAttribute[_T]):
         # for the span of the declarative scan_attributes() phase.
         # to achieve this we look at the class manager that's configured.
         cls = owner
-        manager = attributes.manager_of_class(cls)
+        manager = attributes.opt_manager_of_class(cls)
         if manager is None:
             if not re.match(r"^__.+__$", self.fget.__name__):
                 # if there is no manager at all, then this class hasn't been
@@ -1294,8 +1294,8 @@ def as_declarative(**kw):
 @inspection._inspects(
     DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept
 )
-def _inspect_decl_meta(cls):
-    mp = _inspect_mapped_class(cls)
+def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]:
+    mp: Mapper[Any] = _inspect_mapped_class(cls)
     if mp is None:
         if _DeferredMapperConfig.has_cls(cls):
             _DeferredMapperConfig.raise_unmapped_for_cls(cls)
index 804d05ce1930b4f412cfc34830e1b14fa8fe151b..9c79a4172b1ce0c6df161420d2541d3cb6ef3923 100644 (file)
@@ -12,6 +12,8 @@ import collections
 from typing import Any
 from typing import Dict
 from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
 import weakref
 
 from . import attributes
@@ -42,6 +44,10 @@ from ..sql.schema import Column
 from ..sql.schema import Table
 from ..util import topological
 
+if TYPE_CHECKING:
+    from ._typing import _O
+    from ._typing import _RegistryType
+
 
 def _declared_mapping_info(cls):
     # deferred mapping
@@ -121,7 +127,7 @@ def _dive_for_cls_manager(cls):
         return None
 
     for base in cls.__mro__:
-        manager = attributes.manager_of_class(base)
+        manager = attributes.opt_manager_of_class(base)
         if manager:
             return manager
     return None
@@ -171,7 +177,7 @@ class _MapperConfig:
 
     @classmethod
     def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
-        manager = attributes.manager_of_class(cls)
+        manager = attributes.opt_manager_of_class(cls)
         if manager and manager.class_ is cls_:
             raise exc.InvalidRequestError(
                 "Class %r already has been " "instrumented declaratively" % cls
@@ -191,7 +197,12 @@ class _MapperConfig:
 
         return cfg_cls(registry, cls_, dict_, table, mapper_kw)
 
-    def __init__(self, registry, cls_, mapper_kw):
+    def __init__(
+        self,
+        registry: _RegistryType,
+        cls_: Type[Any],
+        mapper_kw: Dict[str, Any],
+    ):
         self.cls = util.assert_arg_type(cls_, type, "cls_")
         self.classname = cls_.__name__
         self.properties = util.OrderedDict()
@@ -206,7 +217,7 @@ class _MapperConfig:
                 init_method=registry.constructor,
             )
         else:
-            manager = attributes.manager_of_class(self.cls)
+            manager = attributes.opt_manager_of_class(self.cls)
             if not manager or not manager.is_mapped:
                 raise exc.InvalidRequestError(
                     "Class %s has no primary mapper configured.  Configure "
index 8beac472e371962e04337d7933eebf5f5304319e..4738d8c2c9ea7ebdd998b48da55166da0e0184f5 100644 (file)
@@ -122,7 +122,11 @@ class DescriptorProperty(MapperProperty[_T]):
 
 
 _CompositeAttrType = Union[
-    str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]"
+    str,
+    "Column[_T]",
+    "MappedColumn[_T]",
+    "InstrumentedAttribute[_T]",
+    "Mapped[_T]",
 ]
 
 
index c531e7cf19a6081ada3a17e1aa2f973a3fdd364e..331c224eef83b5cbfc1a5fa2929658f3a9805785 100644 (file)
@@ -11,6 +11,9 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
 import weakref
 
 from . import instrumentation
@@ -27,6 +30,10 @@ from .. import exc
 from .. import util
 from ..util.compat import inspect_getfullargspec
 
+if TYPE_CHECKING:
+    from ._typing import _O
+    from .instrumentation import ClassManager
+
 
 class InstrumentationEvents(event.Events):
     """Events related to class instrumentation events.
@@ -214,7 +221,7 @@ class InstanceEvents(event.Events):
             if issubclass(target, mapperlib.Mapper):
                 return instrumentation.ClassManager
             else:
-                manager = instrumentation.manager_of_class(target)
+                manager = instrumentation.opt_manager_of_class(target)
                 if manager:
                     return manager
                 else:
@@ -613,8 +620,8 @@ class _EventsHold(event.RefCollection):
 class _InstanceEventsHold(_EventsHold):
     all_holds = weakref.WeakKeyDictionary()
 
-    def resolve(self, class_):
-        return instrumentation.manager_of_class(class_)
+    def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]:
+        return instrumentation.opt_manager_of_class(class_)
 
     class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents):
         pass
index 00829ecbb7c9870ba0ea1abc5a5a59d56471e183..529a7cd01f0add0a07b461eab9721a446c368627 100644 (file)
@@ -203,7 +203,10 @@ def _default_unmapped(cls) -> Optional[str]:
 
     try:
         mappers = base.manager_of_class(cls).mappers
-    except (TypeError,) + NO_STATE:
+    except (
+        UnmappedClassError,
+        TypeError,
+    ) + NO_STATE:
         mappers = {}
     name = _safe_cls_name(cls)
 
index 0d4b630dad80c895b728784e578383838db09514..88ceacd076a84caad9f10be56ae78b96cba00fa8 100644 (file)
@@ -33,10 +33,13 @@ alternate instrumentation forms.
 from __future__ import annotations
 
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import Generic
 from typing import Optional
 from typing import Set
+from typing import Tuple
+from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 import weakref
@@ -53,7 +56,9 @@ from ..util import HasMemoized
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
+    from ._typing import _RegistryType
     from .attributes import InstrumentedAttribute
+    from .decl_base import _MapperConfig
     from .mapper import Mapper
     from .state import InstanceState
     from ..event import dispatcher
@@ -72,6 +77,11 @@ class _ExpiredAttributeLoaderProto(Protocol):
         ...
 
 
+class _ManagerFactory(Protocol):
+    def __call__(self, class_: Type[_O]) -> ClassManager[_O]:
+        ...
+
+
 class ClassManager(
     HasMemoized,
     Dict[str, "InstrumentedAttribute[Any]"],
@@ -90,12 +100,12 @@ class ClassManager(
     expired_attribute_loader: _ExpiredAttributeLoaderProto
     "previously known as deferred_scalar_loader"
 
-    init_method = None
+    init_method: Optional[Callable[..., None]]
 
-    factory = None
+    factory: Optional[_ManagerFactory]
 
-    declarative_scan = None
-    registry = None
+    declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
+    registry: Optional[_RegistryType] = None
 
     @property
     @util.deprecated(
@@ -122,11 +132,13 @@ class ClassManager(
         self.local_attrs = {}
         self.originals = {}
         self._finalized = False
+        self.factory = None
+        self.init_method = None
 
         self._bases = [
             mgr
             for mgr in [
-                manager_of_class(base)
+                opt_manager_of_class(base)
                 for base in self.class_.__bases__
                 if isinstance(base, type)
             ]
@@ -139,7 +151,7 @@ class ClassManager(
         self.dispatch._events._new_classmanager_instance(class_, self)
 
         for basecls in class_.__mro__:
-            mgr = manager_of_class(basecls)
+            mgr = opt_manager_of_class(basecls)
             if mgr is not None:
                 self.dispatch._update(mgr.dispatch)
 
@@ -155,16 +167,18 @@ class ClassManager(
 
     def _update_state(
         self,
-        finalize=False,
-        mapper=None,
-        registry=None,
-        declarative_scan=None,
-        expired_attribute_loader=None,
-        init_method=None,
-    ):
+        finalize: bool = False,
+        mapper: Optional[Mapper[_O]] = None,
+        registry: Optional[_RegistryType] = None,
+        declarative_scan: Optional[_MapperConfig] = None,
+        expired_attribute_loader: Optional[
+            _ExpiredAttributeLoaderProto
+        ] = None,
+        init_method: Optional[Callable[..., None]] = None,
+    ) -> None:
 
         if mapper:
-            self.mapper = mapper
+            self.mapper = mapper  # type: ignore[assignment]
         if registry:
             registry._add_manager(self)
         if declarative_scan:
@@ -350,7 +364,7 @@ class ClassManager(
 
     def subclass_managers(self, recursive):
         for cls in self.class_.__subclasses__():
-            mgr = manager_of_class(cls)
+            mgr = opt_manager_of_class(cls)
             if mgr is not None and mgr is not self:
                 yield mgr
                 if recursive:
@@ -374,7 +388,7 @@ class ClassManager(
         self._reset_memoizations()
         del self[key]
         for cls in self.class_.__subclasses__():
-            manager = manager_of_class(cls)
+            manager = opt_manager_of_class(cls)
             if manager:
                 manager.uninstrument_attribute(key, True)
 
@@ -523,7 +537,7 @@ class _SerializeManager:
         manager.dispatch.pickle(state, d)
 
     def __call__(self, state, inst, state_dict):
-        state.manager = manager = manager_of_class(self.class_)
+        state.manager = manager = opt_manager_of_class(self.class_)
         if manager is None:
             raise exc.UnmappedInstanceError(
                 inst,
@@ -546,9 +560,9 @@ class _SerializeManager:
 class InstrumentationFactory:
     """Factory for new ClassManager instances."""
 
-    def create_manager_for_cls(self, class_):
+    def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]:
         assert class_ is not None
-        assert manager_of_class(class_) is None
+        assert opt_manager_of_class(class_) is None
 
         # give a more complicated subclass
         # a chance to do what it wants here
@@ -557,6 +571,8 @@ class InstrumentationFactory:
         if factory is None:
             factory = ClassManager
             manager = factory(class_)
+        else:
+            assert manager is not None
 
         self._check_conflicts(class_, factory)
 
@@ -564,11 +580,15 @@ class InstrumentationFactory:
 
         return manager
 
-    def _locate_extended_factory(self, class_):
+    def _locate_extended_factory(
+        self, class_: Type[_O]
+    ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]:
         """Overridden by a subclass to do an extended lookup."""
         return None, None
 
-    def _check_conflicts(self, class_, factory):
+    def _check_conflicts(
+        self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]]
+    ):
         """Overridden by a subclass to test for conflicting factories."""
         return
 
@@ -590,24 +610,25 @@ instance_state = _default_state_getter = base.instance_state
 instance_dict = _default_dict_getter = base.instance_dict
 
 manager_of_class = _default_manager_getter = base.manager_of_class
+opt_manager_of_class = _default_opt_manager_getter = base.opt_manager_of_class
 
 
 def register_class(
-    class_,
-    finalize=True,
-    mapper=None,
-    registry=None,
-    declarative_scan=None,
-    expired_attribute_loader=None,
-    init_method=None,
-):
+    class_: Type[_O],
+    finalize: bool = True,
+    mapper: Optional[Mapper[_O]] = None,
+    registry: Optional[_RegistryType] = None,
+    declarative_scan: Optional[_MapperConfig] = None,
+    expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None,
+    init_method: Optional[Callable[..., None]] = None,
+) -> ClassManager[_O]:
     """Register class instrumentation.
 
     Returns the existing or newly created class manager.
 
     """
 
-    manager = manager_of_class(class_)
+    manager = opt_manager_of_class(class_)
     if manager is None:
         manager = _instrumentation_factory.create_manager_for_cls(class_)
     manager._update_state(
index abc1300d8dc80f254cff3c505aaf3b1fc3aaf71d..0ca62b7e3584def7c6d96c6bdcea6d253e230558 100644 (file)
@@ -21,10 +21,15 @@ from __future__ import annotations
 import collections
 import typing
 from typing import Any
+from typing import Callable
 from typing import cast
+from typing import ClassVar
+from typing import Dict
+from typing import Iterator
 from typing import List
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TypeVar
@@ -45,7 +50,6 @@ from .base import NotExtension as NotExtension
 from .base import ONETOMANY as ONETOMANY
 from .base import SQLORMOperations
 from .. import ColumnElement
-from .. import inspect
 from .. import inspection
 from .. import util
 from ..sql import operators
@@ -53,19 +57,47 @@ from ..sql import roles
 from ..sql import visitors
 from ..sql.base import ExecutableOption
 from ..sql.cache_key import HasCacheKey
-from ..sql.elements import SQLCoreOperations
 from ..sql.schema import Column
 from ..sql.type_api import TypeEngine
 from ..util.typing import TypedDict
 
+
 if typing.TYPE_CHECKING:
+    from ._typing import _EntityType
+    from ._typing import _IdentityKeyType
+    from ._typing import _InstanceDict
+    from ._typing import _InternalEntityType
+    from ._typing import _ORMAdapterProto
+    from ._typing import _ORMColumnExprArgument
+    from .attributes import InstrumentedAttribute
+    from .context import _MapperEntity
+    from .context import ORMCompileState
     from .decl_api import RegistryType
+    from .loading import _PopulatorDict
+    from .mapper import Mapper
+    from .path_registry import AbstractEntityRegistry
+    from .path_registry import PathRegistry
+    from .query import Query
+    from .session import Session
+    from .state import InstanceState
+    from .strategy_options import _LoadElement
+    from .util import AliasedInsp
+    from .util import CascadeOptions
+    from .util import ORMAdapter
+    from ..engine.result import Result
+    from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _ColumnsClauseArgument
     from ..sql._typing import _DMLColumnArgument
     from ..sql._typing import _InfoType
+    from ..sql._typing import _PropagateAttrsType
+    from ..sql.operators import OperatorType
+    from ..sql.util import ColumnAdapter
+    from ..sql.visitors import _TraverseInternalsType
 
 _T = TypeVar("_T", bound=Any)
 
+_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]")
+
 
 class ORMStatementRole(roles.StatementRole):
     __slots__ = ()
@@ -91,7 +123,9 @@ class ORMFromClauseRole(roles.StrictFromClauseRole):
 
 class ORMColumnDescription(TypedDict):
     name: str
-    type: Union[Type, TypeEngine]
+    # TODO: add python_type and sql_type here; combining them
+    # into "type" is a bad idea
+    type: Union[Type[Any], TypeEngine[Any]]
     aliased: bool
     expr: _ColumnsClauseArgument
     entity: Optional[_ColumnsClauseArgument]
@@ -102,10 +136,10 @@ class _IntrospectsAnnotations:
 
     def declarative_scan(
         self,
-        registry: "RegistryType",
-        cls: type,
+        registry: RegistryType,
+        cls: Type[Any],
         key: str,
-        annotation: Optional[type],
+        annotation: Optional[Type[Any]],
         is_dataclass_field: Optional[bool],
     ) -> None:
         """Perform class-specific initializaton at early declarative scanning
@@ -124,12 +158,12 @@ class _MapsColumns(_MappedAttribute[_T]):
     __slots__ = ()
 
     @property
-    def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
+    def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]:
         """return a MapperProperty to be assigned to the declarative mapping"""
         raise NotImplementedError()
 
     @property
-    def columns_to_assign(self) -> List[Column]:
+    def columns_to_assign(self) -> List[Column[_T]]:
         """A list of Column objects that should be declaratively added to the
         new Table object.
 
@@ -139,7 +173,10 @@ class _MapsColumns(_MappedAttribute[_T]):
 
 @inspection._self_inspects
 class MapperProperty(
-    HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots
+    HasCacheKey,
+    _MappedAttribute[_T],
+    InspectionAttrInfo,
+    util.MemoizedSlots,
 ):
     """Represent a particular class attribute mapped by :class:`_orm.Mapper`.
 
@@ -160,12 +197,12 @@ class MapperProperty(
         "info",
     )
 
-    _cache_key_traversal = [
+    _cache_key_traversal: _TraverseInternalsType = [
         ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
         ("key", visitors.ExtendedInternalTraversal.dp_string),
     ]
 
-    cascade = frozenset()
+    cascade: Optional[CascadeOptions] = None
     """The set of 'cascade' attribute names.
 
     This collection is checked before the 'cascade_iterator' method is called.
@@ -184,14 +221,20 @@ class MapperProperty(
     """The :class:`_orm.PropComparator` instance that implements SQL
     expression construction on behalf of this mapped attribute."""
 
-    @property
-    def _links_to_entity(self):
-        """True if this MapperProperty refers to a mapped entity.
+    key: str
+    """name of class attribute"""
 
-        Should only be True for Relationship, False for all others.
+    parent: Mapper[Any]
+    """the :class:`.Mapper` managing this property."""
 
-        """
-        raise NotImplementedError()
+    _is_relationship = False
+
+    _links_to_entity: bool
+    """True if this MapperProperty refers to a mapped entity.
+
+    Should only be True for Relationship, False for all others.
+
+    """
 
     def _memoized_attr_info(self) -> _InfoType:
         """Info dictionary associated with the object, allowing user-defined
@@ -217,7 +260,14 @@ class MapperProperty(
         """
         return {}
 
-    def setup(self, context, query_entity, path, adapter, **kwargs):
+    def setup(
+        self,
+        context: ORMCompileState,
+        query_entity: _MapperEntity,
+        path: PathRegistry,
+        adapter: Optional[ColumnAdapter],
+        **kwargs: Any,
+    ) -> None:
         """Called by Query for the purposes of constructing a SQL statement.
 
         Each MapperProperty associated with the target mapper processes the
@@ -227,16 +277,30 @@ class MapperProperty(
         """
 
     def create_row_processor(
-        self, context, query_entity, path, mapper, result, adapter, populators
-    ):
+        self,
+        context: ORMCompileState,
+        query_entity: _MapperEntity,
+        path: PathRegistry,
+        mapper: Mapper[Any],
+        result: Result,
+        adapter: Optional[ColumnAdapter],
+        populators: _PopulatorDict,
+    ) -> None:
         """Produce row processing functions and append to the given
         set of populators lists.
 
         """
 
     def cascade_iterator(
-        self, type_, state, dict_, visited_states, halt_on=None
-    ):
+        self,
+        type_: str,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        visited_states: Set[InstanceState[Any]],
+        halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+    ) -> Iterator[
+        Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict]
+    ]:
         """Iterate through instances related to the given instance for
         a particular 'cascade', starting with this MapperProperty.
 
@@ -251,7 +315,7 @@ class MapperProperty(
 
         return iter(())
 
-    def set_parent(self, parent, init):
+    def set_parent(self, parent: Mapper[Any], init: bool) -> None:
         """Set the parent mapper that references this MapperProperty.
 
         This method is overridden by some subclasses to perform extra
@@ -260,7 +324,7 @@ class MapperProperty(
         """
         self.parent = parent
 
-    def instrument_class(self, mapper):
+    def instrument_class(self, mapper: Mapper[Any]) -> None:
         """Hook called by the Mapper to the property to initiate
         instrumentation of the class attribute managed by this
         MapperProperty.
@@ -280,11 +344,11 @@ class MapperProperty(
 
         """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._configure_started = False
         self._configure_finished = False
 
-    def init(self):
+    def init(self) -> None:
         """Called after all mappers are created to assemble
         relationships between mappers and perform other post-mapper-creation
         initialization steps.
@@ -296,7 +360,7 @@ class MapperProperty(
         self._configure_finished = True
 
     @property
-    def class_attribute(self):
+    def class_attribute(self) -> InstrumentedAttribute[_T]:
         """Return the class-bound descriptor corresponding to this
         :class:`.MapperProperty`.
 
@@ -319,9 +383,9 @@ class MapperProperty(
 
         """
 
-        return getattr(self.parent.class_, self.key)
+        return getattr(self.parent.class_, self.key)  # type: ignore
 
-    def do_init(self):
+    def do_init(self) -> None:
         """Perform subclass-specific initialization post-mapper-creation
         steps.
 
@@ -330,7 +394,7 @@ class MapperProperty(
 
         """
 
-    def post_instrument_class(self, mapper):
+    def post_instrument_class(self, mapper: Mapper[Any]) -> None:
         """Perform instrumentation adjustments that need to occur
         after init() has completed.
 
@@ -347,21 +411,21 @@ class MapperProperty(
 
     def merge(
         self,
-        session,
-        source_state,
-        source_dict,
-        dest_state,
-        dest_dict,
-        load,
-        _recursive,
-        _resolve_conflict_map,
-    ):
+        session: Session,
+        source_state: InstanceState[Any],
+        source_dict: _InstanceDict,
+        dest_state: InstanceState[Any],
+        dest_dict: _InstanceDict,
+        load: bool,
+        _recursive: Set[InstanceState[Any]],
+        _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+    ) -> None:
         """Merge the attribute represented by this ``MapperProperty``
         from source to destination object.
 
         """
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<%s at 0x%x; %s>" % (
             self.__class__.__name__,
             id(self),
@@ -452,21 +516,28 @@ class PropComparator(SQLORMOperations[_T]):
 
     """
 
-    __slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
+    __slots__ = "prop", "_parententity", "_adapt_to_entity"
 
     __visit_name__ = "orm_prop_comparator"
 
+    _parententity: _InternalEntityType[Any]
+    _adapt_to_entity: Optional[AliasedInsp[Any]]
+
     def __init__(
         self,
-        prop,
-        parentmapper,
-        adapt_to_entity=None,
+        prop: MapperProperty[_T],
+        parentmapper: _InternalEntityType[Any],
+        adapt_to_entity: Optional[AliasedInsp[Any]] = None,
     ):
-        self.prop = self.property = prop
+        self.prop = prop
         self._parententity = adapt_to_entity or parentmapper
         self._adapt_to_entity = adapt_to_entity
 
-    def __clause_element__(self):
+    @util.ro_non_memoized_property
+    def property(self) -> Optional[MapperProperty[_T]]:
+        return self.prop
+
+    def __clause_element__(self) -> _ORMColumnExprArgument[_T]:
         raise NotImplementedError("%r" % self)
 
     def _bulk_update_tuples(
@@ -480,22 +551,24 @@ class PropComparator(SQLORMOperations[_T]):
 
         """
 
-        return [(self.__clause_element__(), value)]
+        return [(cast("_DMLColumnArgument", self.__clause_element__()), value)]
 
-    def adapt_to_entity(self, adapt_to_entity):
+    def adapt_to_entity(
+        self, adapt_to_entity: AliasedInsp[Any]
+    ) -> PropComparator[_T]:
         """Return a copy of this PropComparator which will use the given
         :class:`.AliasedInsp` to produce corresponding expressions.
         """
         return self.__class__(self.prop, self._parententity, adapt_to_entity)
 
-    @property
-    def _parentmapper(self):
+    @util.ro_non_memoized_property
+    def _parentmapper(self) -> Mapper[Any]:
         """legacy; this is renamed to _parententity to be
         compatible with QueryableAttribute."""
-        return inspect(self._parententity).mapper
+        return self._parententity.mapper
 
-    @property
-    def _propagate_attrs(self):
+    @util.memoized_property
+    def _propagate_attrs(self) -> _PropagateAttrsType:
         # this suits the case in coercions where we don't actually
         # call ``__clause_element__()`` but still need to get
         # resolved._propagate_attrs.  See #6558.
@@ -507,12 +580,14 @@ class PropComparator(SQLORMOperations[_T]):
         )
 
     def _criterion_exists(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
     ) -> ColumnElement[Any]:
         return self.prop.comparator._criterion_exists(criterion, **kwargs)
 
-    @property
-    def adapter(self):
+    @util.ro_non_memoized_property
+    def adapter(self) -> Optional[_ORMAdapterProto[_T]]:
         """Produce a callable that adapts column expressions
         to suit an aliased version of this comparator.
 
@@ -522,20 +597,20 @@ class PropComparator(SQLORMOperations[_T]):
         else:
             return self._adapt_to_entity._adapt_element
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def info(self) -> _InfoType:
-        return self.property.info
+        return self.prop.info
 
     @staticmethod
-    def _any_op(a, b, **kwargs):
+    def _any_op(a: Any, b: Any, **kwargs: Any) -> Any:
         return a.any(b, **kwargs)
 
     @staticmethod
-    def _has_op(left, other, **kwargs):
+    def _has_op(left: Any, other: Any, **kwargs: Any) -> Any:
         return left.has(other, **kwargs)
 
     @staticmethod
-    def _of_type_op(a, class_):
+    def _of_type_op(a: Any, class_: Any) -> Any:
         return a.of_type(class_)
 
     any_op = cast(operators.OperatorType, _any_op)
@@ -545,16 +620,16 @@ class PropComparator(SQLORMOperations[_T]):
     if typing.TYPE_CHECKING:
 
         def operate(
-            self, op: operators.OperatorType, *other: Any, **kwargs: Any
-        ) -> "SQLCoreOperations[Any]":
+            self, op: OperatorType, *other: Any, **kwargs: Any
+        ) -> ColumnElement[Any]:
             ...
 
         def reverse_operate(
-            self, op: operators.OperatorType, other: Any, **kwargs: Any
-        ) -> "SQLCoreOperations[Any]":
+            self, op: OperatorType, other: Any, **kwargs: Any
+        ) -> ColumnElement[Any]:
             ...
 
-    def of_type(self, class_) -> "SQLORMOperations[_T]":
+    def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
         r"""Redefine this object in terms of a polymorphic subclass,
         :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased`
         construct.
@@ -578,9 +653,11 @@ class PropComparator(SQLORMOperations[_T]):
 
         """
 
-        return self.operate(PropComparator.of_type_op, class_)
+        return self.operate(PropComparator.of_type_op, class_)  # type: ignore
 
-    def and_(self, *criteria) -> "SQLORMOperations[_T]":
+    def and_(
+        self, *criteria: _ColumnExpressionArgument[bool]
+    ) -> ColumnElement[bool]:
         """Add additional criteria to the ON clause that's represented by this
         relationship attribute.
 
@@ -606,10 +683,12 @@ class PropComparator(SQLORMOperations[_T]):
             :func:`.with_loader_criteria`
 
         """
-        return self.operate(operators.and_, *criteria)
+        return self.operate(operators.and_, *criteria)  # type: ignore
 
     def any(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
     ) -> ColumnElement[bool]:
         r"""Return a SQL expression representing true if this element
         references a member which meets the given criterion.
@@ -626,10 +705,14 @@ class PropComparator(SQLORMOperations[_T]):
 
         """
 
-        return self.operate(PropComparator.any_op, criterion, **kwargs)
+        return self.operate(  # type: ignore
+            PropComparator.any_op, criterion, **kwargs
+        )
 
     def has(
-        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs
+        self,
+        criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+        **kwargs: Any,
     ) -> ColumnElement[bool]:
         r"""Return a SQL expression representing true if this element
         references a member which meets the given criterion.
@@ -646,7 +729,9 @@ class PropComparator(SQLORMOperations[_T]):
 
         """
 
-        return self.operate(PropComparator.has_op, criterion, **kwargs)
+        return self.operate(  # type: ignore
+            PropComparator.has_op, criterion, **kwargs
+        )
 
 
 class StrategizedProperty(MapperProperty[_T]):
@@ -674,23 +759,30 @@ class StrategizedProperty(MapperProperty[_T]):
         "strategy_key",
     )
     inherit_cache = True
-    strategy_wildcard_key = None
+    strategy_wildcard_key: ClassVar[str]
 
     strategy_key: Tuple[Any, ...]
 
-    def _memoized_attr__wildcard_token(self):
+    _strategies: Dict[Tuple[Any, ...], LoaderStrategy]
+
+    def _memoized_attr__wildcard_token(self) -> Tuple[str]:
         return (
             f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}",
         )
 
-    def _memoized_attr__default_path_loader_key(self):
+    def _memoized_attr__default_path_loader_key(
+        self,
+    ) -> Tuple[str, Tuple[str]]:
         return (
             "loader",
             (f"{self.strategy_wildcard_key}:{path_registry._DEFAULT_TOKEN}",),
         )
 
-    def _get_context_loader(self, context, path):
-        load = None
+    def _get_context_loader(
+        self, context: ORMCompileState, path: AbstractEntityRegistry
+    ) -> Optional[_LoadElement]:
+
+        load: Optional[_LoadElement] = None
 
         search_path = path[self]
 
@@ -714,7 +806,7 @@ class StrategizedProperty(MapperProperty[_T]):
 
         return load
 
-    def _get_strategy(self, key):
+    def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy:
         try:
             return self._strategies[key]
         except KeyError:
@@ -768,11 +860,13 @@ class StrategizedProperty(MapperProperty[_T]):
         ):
             self.strategy.init_class_attribute(mapper)
 
-    _all_strategies = collections.defaultdict(dict)
+    _all_strategies: collections.defaultdict[
+        Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]]
+    ] = collections.defaultdict(dict)
 
     @classmethod
-    def strategy_for(cls, **kw):
-        def decorate(dec_cls):
+    def strategy_for(cls, **kw: Any) -> Callable[[_TLS], _TLS]:
+        def decorate(dec_cls: _TLS) -> _TLS:
             # ensure each subclass of the strategy has its
             # own _strategy_keys collection
             if "_strategy_keys" not in dec_cls.__dict__:
@@ -785,7 +879,9 @@ class StrategizedProperty(MapperProperty[_T]):
         return decorate
 
     @classmethod
-    def _strategy_lookup(cls, requesting_property, *key):
+    def _strategy_lookup(
+        cls, requesting_property: MapperProperty[Any], *key: Any
+    ) -> Type[LoaderStrategy]:
         requesting_property.parent._with_polymorphic_mappers
 
         for prop_cls in cls.__mro__:
@@ -984,10 +1080,10 @@ class MapperOption(ORMOption):
 
     """
 
-    def process_query(self, query):
+    def process_query(self, query: Query[Any]) -> None:
         """Apply a modification to the given :class:`_query.Query`."""
 
-    def process_query_conditionally(self, query):
+    def process_query_conditionally(self, query: Query[Any]) -> None:
         """same as process_query(), except that this option may not
         apply to the given query.
 
@@ -1034,7 +1130,11 @@ class LoaderStrategy:
         "strategy_opts",
     )
 
-    def __init__(self, parent, strategy_key):
+    _strategy_keys: ClassVar[List[Tuple[Any, ...]]]
+
+    def __init__(
+        self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...]
+    ):
         self.parent_property = parent
         self.is_class_level = False
         self.parent = self.parent_property.parent
@@ -1042,12 +1142,18 @@ class LoaderStrategy:
         self.strategy_key = strategy_key
         self.strategy_opts = dict(strategy_key)
 
-    def init_class_attribute(self, mapper):
+    def init_class_attribute(self, mapper: Mapper[Any]) -> None:
         pass
 
     def setup_query(
-        self, compile_state, query_entity, path, loadopt, adapter, **kwargs
-    ):
+        self,
+        compile_state: ORMCompileState,
+        query_entity: _MapperEntity,
+        path: AbstractEntityRegistry,
+        loadopt: Optional[_LoadElement],
+        adapter: Optional[ORMAdapter],
+        **kwargs: Any,
+    ) -> None:
         """Establish column and other state for a given QueryContext.
 
         This method fulfills the contract specified by MapperProperty.setup().
@@ -1059,15 +1165,15 @@ class LoaderStrategy:
 
     def create_row_processor(
         self,
-        context,
-        query_entity,
-        path,
-        loadopt,
-        mapper,
-        result,
-        adapter,
-        populators,
-    ):
+        context: ORMCompileState,
+        query_entity: _MapperEntity,
+        path: AbstractEntityRegistry,
+        loadopt: Optional[_LoadElement],
+        mapper: Mapper[Any],
+        result: Result,
+        adapter: Optional[ORMAdapter],
+        populators: _PopulatorDict,
+    ) -> None:
         """Establish row processing functions for a given QueryContext.
 
         This method fulfills the contract specified by
index ae083054cd17f04cf94560b1b272190cda8e3ac9..d9949eb7a351161dc1e5812d73392e1c7a0c3979 100644 (file)
@@ -16,7 +16,9 @@ as well as some of the attribute loading strategies.
 from __future__ import annotations
 
 from typing import Any
+from typing import Dict
 from typing import Iterable
+from typing import List
 from typing import Mapping
 from typing import Optional
 from typing import Sequence
@@ -65,6 +67,9 @@ _O = TypeVar("_O", bound=object)
 _new_runid = util.counter()
 
 
+_PopulatorDict = Dict[str, List[Tuple[str, Any]]]
+
+
 def instances(cursor, context):
     """Return a :class:`.Result` given an ORM query context.
 
@@ -383,7 +388,7 @@ def get_from_identity(
     mapper: Mapper[_O],
     key: _IdentityKeyType[_O],
     passive: PassiveFlag,
-) -> Union[Optional[_O], LoaderCallableStatus]:
+) -> Union[LoaderCallableStatus, Optional[_O]]:
     """Look up the given key in the given session's identity map,
     check the object for expired state if found.
 
index abe11cc68cdc9c94c74cd9ee3ce74759b41507b7..b37c080eaf190f3041e4e6f7dca78f6021728609 100644 (file)
@@ -23,12 +23,23 @@ import sys
 import threading
 from typing import Any
 from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Deque
+from typing import Dict
 from typing import Generic
+from typing import Iterable
 from typing import Iterator
+from typing import List
+from typing import Mapping
 from typing import Optional
+from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 import weakref
 
 from . import attributes
@@ -39,8 +50,8 @@ from . import properties
 from . import util as orm_util
 from ._typing import _O
 from .base import _class_to_mapper
+from .base import _parse_mapper_argument
 from .base import _state_mapper
-from .base import class_mapper
 from .base import PassiveFlag
 from .base import state_str
 from .interfaces import _MappedAttribute
@@ -58,6 +69,8 @@ from .. import log
 from .. import schema
 from .. import sql
 from .. import util
+from ..event import dispatcher
+from ..event import EventTarget
 from ..sql import base as sql_base
 from ..sql import coercions
 from ..sql import expression
@@ -65,26 +78,68 @@ from ..sql import operators
 from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
+from ..sql.cache_key import MemoizedHasCacheKey
+from ..sql.schema import Table
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..util import HasMemoized
+from ..util import HasMemoized_ro_memoized_attribute
+from ..util.typing import Literal
 
 if TYPE_CHECKING:
     from ._typing import _IdentityKeyType
     from ._typing import _InstanceDict
+    from ._typing import _ORMColumnExprArgument
+    from ._typing import _RegistryType
+    from .decl_api import registry
+    from .dependency import DependencyProcessor
+    from .descriptor_props import Composite
+    from .descriptor_props import Synonym
+    from .events import MapperEvents
     from .instrumentation import ClassManager
+    from .path_registry import AbstractEntityRegistry
+    from .path_registry import CachingEntityRegistry
+    from .properties import ColumnProperty
+    from .relationships import Relationship
     from .state import InstanceState
+    from ..engine import Row
+    from ..engine import RowMapping
+    from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _EquivalentColumnMap
+    from ..sql.base import ReadOnlyColumnCollection
+    from ..sql.elements import ColumnClause
     from ..sql.elements import ColumnElement
     from ..sql.schema import Column
+    from ..sql.schema import Table
+    from ..sql.selectable import FromClause
+    from ..sql.selectable import TableClause
+    from ..sql.util import ColumnAdapter
+    from ..util import OrderedSet
 
-_mapper_registries = weakref.WeakKeyDictionary()
 
+_T = TypeVar("_T", bound=Any)
+_MP = TypeVar("_MP", bound="MapperProperty[Any]")
 
-def _all_registries():
+_WithPolymorphicArg = Union[
+    Literal["*"],
+    Tuple[
+        Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]],
+        Optional["FromClause"],
+    ],
+    Sequence[Union["Mapper[Any]", Type[Any]]],
+]
+
+
+_mapper_registries: weakref.WeakKeyDictionary[
+    _RegistryType, bool
+] = weakref.WeakKeyDictionary()
+
+
+def _all_registries() -> Set[registry]:
     with _CONFIGURE_MUTEX:
         return set(_mapper_registries)
 
 
-def _unconfigured_mappers():
+def _unconfigured_mappers() -> Iterator[Mapper[Any]]:
     for reg in _all_registries():
         for mapper in reg._mappers_to_configure():
             yield mapper
@@ -107,9 +162,11 @@ _CONFIGURE_MUTEX = threading.RLock()
 class Mapper(
     ORMFromClauseRole,
     ORMEntityColumnsClauseRole,
-    sql_base.MemoizedHasCacheKey,
+    MemoizedHasCacheKey,
     InspectionAttr,
     log.Identified,
+    inspection.Inspectable["Mapper[_O]"],
+    EventTarget,
     Generic[_O],
 ):
     """Defines an association between a Python class and a database table or
@@ -123,18 +180,11 @@ class Mapper(
 
     """
 
+    dispatch: dispatcher[Mapper[_O]]
+
     _dispose_called = False
     _ready_for_configure = False
 
-    class_: Type[_O]
-    """The class to which this :class:`_orm.Mapper` is mapped."""
-
-    _identity_class: Type[_O]
-
-    always_refresh: bool
-    allow_partial_pks: bool
-    version_id_col: Optional[ColumnElement[Any]]
-
     @util.deprecated_params(
         non_primary=(
             "1.3",
@@ -148,33 +198,39 @@ class Mapper(
     def __init__(
         self,
         class_: Type[_O],
-        local_table=None,
-        properties=None,
-        primary_key=None,
-        non_primary=False,
-        inherits=None,
-        inherit_condition=None,
-        inherit_foreign_keys=None,
-        always_refresh=False,
-        version_id_col=None,
-        version_id_generator=None,
-        polymorphic_on=None,
-        _polymorphic_map=None,
-        polymorphic_identity=None,
-        concrete=False,
-        with_polymorphic=None,
-        polymorphic_load=None,
-        allow_partial_pks=True,
-        batch=True,
-        column_prefix=None,
-        include_properties=None,
-        exclude_properties=None,
-        passive_updates=True,
-        passive_deletes=False,
-        confirm_deleted_rows=True,
-        eager_defaults=False,
-        legacy_is_orphan=False,
-        _compiled_cache_size=100,
+        local_table: Optional[FromClause] = None,
+        properties: Optional[Mapping[str, MapperProperty[Any]]] = None,
+        primary_key: Optional[Iterable[_ORMColumnExprArgument[Any]]] = None,
+        non_primary: bool = False,
+        inherits: Optional[Union[Mapper[Any], Type[Any]]] = None,
+        inherit_condition: Optional[_ColumnExpressionArgument[bool]] = None,
+        inherit_foreign_keys: Optional[
+            Sequence[_ORMColumnExprArgument[Any]]
+        ] = None,
+        always_refresh: bool = False,
+        version_id_col: Optional[_ORMColumnExprArgument[Any]] = None,
+        version_id_generator: Optional[
+            Union[Literal[False], Callable[[Any], Any]]
+        ] = None,
+        polymorphic_on: Optional[
+            Union[_ORMColumnExprArgument[Any], str, MapperProperty[Any]]
+        ] = None,
+        _polymorphic_map: Optional[Dict[Any, Mapper[Any]]] = None,
+        polymorphic_identity: Optional[Any] = None,
+        concrete: bool = False,
+        with_polymorphic: Optional[_WithPolymorphicArg] = None,
+        polymorphic_load: Optional[Literal["selectin", "inline"]] = None,
+        allow_partial_pks: bool = True,
+        batch: bool = True,
+        column_prefix: Optional[str] = None,
+        include_properties: Optional[Sequence[str]] = None,
+        exclude_properties: Optional[Sequence[str]] = None,
+        passive_updates: bool = True,
+        passive_deletes: bool = False,
+        confirm_deleted_rows: bool = True,
+        eager_defaults: bool = False,
+        legacy_is_orphan: bool = False,
+        _compiled_cache_size: int = 100,
     ):
         r"""Direct constructor for a new :class:`_orm.Mapper` object.
 
@@ -593,8 +649,6 @@ class Mapper(
             self.class_.__name__,
         )
 
-        self.class_manager = None
-
         self._primary_key_argument = util.to_list(primary_key)
         self.non_primary = non_primary
 
@@ -623,17 +677,36 @@ class Mapper(
 
         self.concrete = concrete
         self.single = False
-        self.inherits = inherits
+
+        if inherits is not None:
+            self.inherits = _parse_mapper_argument(inherits)
+        else:
+            self.inherits = None
+
         if local_table is not None:
             self.local_table = coercions.expect(
                 roles.StrictFromClauseRole, local_table
             )
+        elif self.inherits:
+            # note this is a new flow as of 2.0 so that
+            # .local_table need not be Optional
+            self.local_table = self.inherits.local_table
+            self.single = True
         else:
-            self.local_table = None
+            raise sa_exc.ArgumentError(
+                f"Mapper[{self.class_.__name__}(None)] has None for a "
+                "primary table argument and does not specify 'inherits'"
+            )
+
+        if inherit_condition is not None:
+            self.inherit_condition = coercions.expect(
+                roles.OnClauseRole, inherit_condition
+            )
+        else:
+            self.inherit_condition = None
 
-        self.inherit_condition = inherit_condition
         self.inherit_foreign_keys = inherit_foreign_keys
-        self._init_properties = properties or {}
+        self._init_properties = dict(properties) if properties else {}
         self._delete_orphans = []
         self.batch = batch
         self.eager_defaults = eager_defaults
@@ -694,7 +767,10 @@ class Mapper(
         # while a configure_mappers() is occurring (and defer a
         # configure_mappers() until construction succeeds)
         with _CONFIGURE_MUTEX:
-            self.dispatch._events._new_mapper_instance(class_, self)
+
+            cast("MapperEvents", self.dispatch._events)._new_mapper_instance(
+                class_, self
+            )
             self._configure_inheritance()
             self._configure_class_instrumentation()
             self._configure_properties()
@@ -704,16 +780,21 @@ class Mapper(
             self._log("constructed")
             self._expire_memoizations()
 
-    # major attributes initialized at the classlevel so that
-    # they can be Sphinx-documented.
+    def _gen_cache_key(self, anon_map, bindparams):
+        return (self,)
+
+    # ### BEGIN
+    # ATTRIBUTE DECLARATIONS START HERE
 
     is_mapper = True
     """Part of the inspection API."""
 
     represents_outer_join = False
 
+    registry: _RegistryType
+
     @property
-    def mapper(self):
+    def mapper(self) -> Mapper[_O]:
         """Part of the inspection API.
 
         Returns self.
@@ -721,9 +802,6 @@ class Mapper(
         """
         return self
 
-    def _gen_cache_key(self, anon_map, bindparams):
-        return (self,)
-
     @property
     def entity(self):
         r"""Part of the inspection API.
@@ -733,49 +811,109 @@ class Mapper(
         """
         return self.class_
 
-    local_table = None
-    """The :class:`_expression.Selectable` which this :class:`_orm.Mapper`
-    manages.
+    class_: Type[_O]
+    """The class to which this :class:`_orm.Mapper` is mapped."""
+
+    _identity_class: Type[_O]
+
+    _delete_orphans: List[Tuple[str, Type[Any]]]
+    _dependency_processors: List[DependencyProcessor]
+    _memoized_values: Dict[Any, Callable[[], Any]]
+    _inheriting_mappers: util.WeakSequence[Mapper[Any]]
+    _all_tables: Set[Table]
+
+    _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
+    _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
+
+    _props: util.OrderedDict[str, MapperProperty[Any]]
+    _init_properties: Dict[str, MapperProperty[Any]]
+
+    _columntoproperty: _ColumnMapping
+
+    _set_polymorphic_identity: Optional[Callable[[InstanceState[_O]], None]]
+    _validate_polymorphic_identity: Optional[
+        Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None]
+    ]
+
+    tables: Sequence[Table]
+    """A sequence containing the collection of :class:`_schema.Table` objects
+    which this :class:`_orm.Mapper` is aware of.
+
+    If the mapper is mapped to a :class:`_expression.Join`, or an
+    :class:`_expression.Alias`
+    representing a :class:`_expression.Select`, the individual
+    :class:`_schema.Table`
+    objects that comprise the full construct will be represented here.
+
+    This is a *read only* attribute determined during mapper construction.
+    Behavior is undefined if directly modified.
+
+    """
+
+    validators: util.immutabledict[str, Tuple[str, Dict[str, Any]]]
+    """An immutable dictionary of attributes which have been decorated
+    using the :func:`_orm.validates` decorator.
+
+    The dictionary contains string attribute names as keys
+    mapped to the actual validation method.
+
+    """
+
+    always_refresh: bool
+    allow_partial_pks: bool
+    version_id_col: Optional[ColumnElement[Any]]
+
+    with_polymorphic: Optional[
+        Tuple[
+            Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]],
+            Optional["FromClause"],
+        ]
+    ]
+
+    version_id_generator: Optional[Union[Literal[False], Callable[[Any], Any]]]
+
+    local_table: FromClause
+    """The immediate :class:`_expression.FromClause` which this
+    :class:`_orm.Mapper` refers towards.
 
-    Typically is an instance of :class:`_schema.Table` or
-    :class:`_expression.Alias`.
-    May also be ``None``.
+    Typically is an instance of :class:`_schema.Table`, may be any
+    :class:`.FromClause`.
 
     The "local" table is the
     selectable that the :class:`_orm.Mapper` is directly responsible for
     managing from an attribute access and flush perspective.   For
-    non-inheriting mappers, the local table is the same as the
-    "mapped" table.   For joined-table inheritance mappers, local_table
-    will be the particular sub-table of the overall "join" which
-    this :class:`_orm.Mapper` represents.  If this mapper is a
-    single-table inheriting mapper, local_table will be ``None``.
+    non-inheriting mappers, :attr:`.Mapper.local_table` will be the same
+    as :attr:`.Mapper.persist_selectable`.  For inheriting mappers,
+    :attr:`.Mapper.local_table` refers to the specific portion of
+    :attr:`.Mapper.persist_selectable` that includes the columns to which
+    this :class:`.Mapper` is loading/persisting, such as a particular
+    :class:`.Table` within a join.
 
     .. seealso::
 
         :attr:`_orm.Mapper.persist_selectable`.
 
+        :attr:`_orm.Mapper.selectable`.
+
     """
 
-    persist_selectable = None
-    """The :class:`_expression.Selectable` to which this :class:`_orm.Mapper`
+    persist_selectable: FromClause
+    """The :class:`_expression.FromClause` to which this :class:`_orm.Mapper`
     is mapped.
 
-    Typically an instance of :class:`_schema.Table`,
-    :class:`_expression.Join`, or :class:`_expression.Alias`.
-
-    The :attr:`_orm.Mapper.persist_selectable` is separate from
-    :attr:`_orm.Mapper.selectable` in that the former represents columns
-    that are mapped on this class or its superclasses, whereas the
-    latter may be a "polymorphic" selectable that contains additional columns
-    which are in fact mapped on subclasses only.
+    Typically is an instance of :class:`_schema.Table`, may be any
+    :class:`.FromClause`.
 
-    "persist selectable" is the "thing the mapper writes to" and
-    "selectable" is the "thing the mapper selects from".
-
-    :attr:`_orm.Mapper.persist_selectable` is also separate from
-    :attr:`_orm.Mapper.local_table`, which represents the set of columns that
-    are locally mapped on this class directly.
+    The :attr:`_orm.Mapper.persist_selectable` is similar to
+    :attr:`.Mapper.local_table`, but represents the :class:`.FromClause` that
+    represents the inheriting class hierarchy overall in an inheritance
+    scenario.
 
+    :attr.`.Mapper.persist_selectable` is also separate from the
+    :attr:`.Mapper.selectable` attribute, the latter of which may be an
+    alternate subquery used for selecting columns.
+    :attr.`.Mapper.persist_selectable` is oriented towards columns that
+    will be written on a persist operation.
 
     .. seealso::
 
@@ -785,16 +923,15 @@ class Mapper(
 
     """
 
-    inherits = None
+    inherits: Optional[Mapper[Any]]
     """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper`
     inherits from, if any.
 
-    This is a *read only* attribute determined during mapper construction.
-    Behavior is undefined if directly modified.
-
     """
 
-    configured = False
+    inherit_condition: Optional[ColumnElement[bool]]
+
+    configured: bool = False
     """Represent ``True`` if this :class:`_orm.Mapper` has been configured.
 
     This is a *read only* attribute determined during mapper construction.
@@ -806,7 +943,7 @@ class Mapper(
 
     """
 
-    concrete = None
+    concrete: bool
     """Represent ``True`` if this :class:`_orm.Mapper` is a concrete
     inheritance mapper.
 
@@ -815,21 +952,6 @@ class Mapper(
 
     """
 
-    tables = None
-    """An iterable containing the collection of :class:`_schema.Table` objects
-    which this :class:`_orm.Mapper` is aware of.
-
-    If the mapper is mapped to a :class:`_expression.Join`, or an
-    :class:`_expression.Alias`
-    representing a :class:`_expression.Select`, the individual
-    :class:`_schema.Table`
-    objects that comprise the full construct will be represented here.
-
-    This is a *read only* attribute determined during mapper construction.
-    Behavior is undefined if directly modified.
-
-    """
-
     primary_key: Tuple[Column[Any], ...]
     """An iterable containing the collection of :class:`_schema.Column`
     objects
@@ -854,14 +976,6 @@ class Mapper(
 
     """
 
-    class_: Type[_O]
-    """The Python class which this :class:`_orm.Mapper` maps.
-
-    This is a *read only* attribute determined during mapper construction.
-    Behavior is undefined if directly modified.
-
-    """
-
     class_manager: ClassManager[_O]
     """The :class:`.ClassManager` which maintains event listeners
     and class-bound descriptors for this :class:`_orm.Mapper`.
@@ -871,7 +985,7 @@ class Mapper(
 
     """
 
-    single = None
+    single: bool
     """Represent ``True`` if this :class:`_orm.Mapper` is a single table
     inheritance mapper.
 
@@ -882,7 +996,7 @@ class Mapper(
 
     """
 
-    non_primary = None
+    non_primary: bool
     """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary"
     mapper, e.g. a mapper that is used only to select rows but not for
     persistence management.
@@ -892,7 +1006,7 @@ class Mapper(
 
     """
 
-    polymorphic_on = None
+    polymorphic_on: Optional[ColumnElement[Any]]
     """The :class:`_schema.Column` or SQL expression specified as the
     ``polymorphic_on`` argument
     for this :class:`_orm.Mapper`, within an inheritance scenario.
@@ -906,7 +1020,7 @@ class Mapper(
 
     """
 
-    polymorphic_map = None
+    polymorphic_map: Dict[Any, Mapper[Any]]
     """A mapping of "polymorphic identity" identifiers mapped to
     :class:`_orm.Mapper` instances, within an inheritance scenario.
 
@@ -922,7 +1036,7 @@ class Mapper(
 
     """
 
-    polymorphic_identity = None
+    polymorphic_identity: Optional[Any]
     """Represent an identifier which is matched against the
     :attr:`_orm.Mapper.polymorphic_on` column during result row loading.
 
@@ -935,7 +1049,7 @@ class Mapper(
 
     """
 
-    base_mapper = None
+    base_mapper: Mapper[Any]
     """The base-most :class:`_orm.Mapper` in an inheritance chain.
 
     In a non-inheriting scenario, this attribute will always be this
@@ -948,7 +1062,7 @@ class Mapper(
 
     """
 
-    columns = None
+    columns: ReadOnlyColumnCollection[str, Column[Any]]
     """A collection of :class:`_schema.Column` or other scalar expression
     objects maintained by this :class:`_orm.Mapper`.
 
@@ -965,25 +1079,16 @@ class Mapper(
 
     """
 
-    validators = None
-    """An immutable dictionary of attributes which have been decorated
-    using the :func:`_orm.validates` decorator.
-
-    The dictionary contains string attribute names as keys
-    mapped to the actual validation method.
-
-    """
-
-    c = None
+    c: ReadOnlyColumnCollection[str, Column[Any]]
     """A synonym for :attr:`_orm.Mapper.columns`."""
 
-    @property
+    @util.non_memoized_property
     @util.deprecated("1.3", "Use .persist_selectable")
     def mapped_table(self):
         return self.persist_selectable
 
     @util.memoized_property
-    def _path_registry(self) -> PathRegistry:
+    def _path_registry(self) -> CachingEntityRegistry:
         return PathRegistry.per_mapper(self)
 
     def _configure_inheritance(self):
@@ -994,8 +1099,6 @@ class Mapper(
         self._inheriting_mappers = util.WeakSequence()
 
         if self.inherits:
-            if isinstance(self.inherits, type):
-                self.inherits = class_mapper(self.inherits, configure=False)
             if not issubclass(self.class_, self.inherits.class_):
                 raise sa_exc.ArgumentError(
                     "Class '%s' does not inherit from '%s'"
@@ -1011,11 +1114,9 @@ class Mapper(
                     "only allowed from a %s mapper"
                     % (np, self.class_.__name__, np)
                 )
-            # inherit_condition is optional.
-            if self.local_table is None:
-                self.local_table = self.inherits.local_table
+
+            if self.single:
                 self.persist_selectable = self.inherits.persist_selectable
-                self.single = True
             elif self.local_table is not self.inherits.local_table:
                 if self.concrete:
                     self.persist_selectable = self.local_table
@@ -1068,6 +1169,7 @@ class Mapper(
                                     self.local_table.description,
                                 )
                             ) from afe
+                    assert self.inherits.persist_selectable is not None
                     self.persist_selectable = sql.join(
                         self.inherits.persist_selectable,
                         self.local_table,
@@ -1149,6 +1251,7 @@ class Mapper(
         else:
             self._all_tables = set()
             self.base_mapper = self
+            assert self.local_table is not None
             self.persist_selectable = self.local_table
             if self.polymorphic_identity is not None:
                 self.polymorphic_map[self.polymorphic_identity] = self
@@ -1160,21 +1263,34 @@ class Mapper(
                 % self
             )
 
-    def _set_with_polymorphic(self, with_polymorphic):
+    def _set_with_polymorphic(
+        self, with_polymorphic: Optional[_WithPolymorphicArg]
+    ) -> None:
         if with_polymorphic == "*":
             self.with_polymorphic = ("*", None)
         elif isinstance(with_polymorphic, (tuple, list)):
             if isinstance(with_polymorphic[0], (str, tuple, list)):
-                self.with_polymorphic = with_polymorphic
+                self.with_polymorphic = cast(
+                    """Tuple[
+                        Union[
+                            Literal["*"],
+                            Sequence[Union["Mapper[Any]", Type[Any]]],
+                        ],
+                        Optional["FromClause"],
+                    ]""",
+                    with_polymorphic,
+                )
             else:
                 self.with_polymorphic = (with_polymorphic, None)
         elif with_polymorphic is not None:
-            raise sa_exc.ArgumentError("Invalid setting for with_polymorphic")
+            raise sa_exc.ArgumentError(
+                f"Invalid setting for with_polymorphic: {with_polymorphic!r}"
+            )
         else:
             self.with_polymorphic = None
 
         if self.with_polymorphic and self.with_polymorphic[1] is not None:
-            self.with_polymorphic = (
+            self.with_polymorphic = (  # type: ignore
                 self.with_polymorphic[0],
                 coercions.expect(
                     roles.StrictFromClauseRole,
@@ -1191,6 +1307,7 @@ class Mapper(
         if self.with_polymorphic is None:
             self._set_with_polymorphic((subcl,))
         elif self.with_polymorphic[0] != "*":
+            assert isinstance(self.with_polymorphic[0], tuple)
             self._set_with_polymorphic(
                 (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1])
             )
@@ -1241,7 +1358,7 @@ class Mapper(
         # we expect that declarative has applied the class manager
         # already and set up a registry.  if this is None,
         # this raises as of 2.0.
-        manager = attributes.manager_of_class(self.class_)
+        manager = attributes.opt_manager_of_class(self.class_)
 
         if self.non_primary:
             if not manager or not manager.is_mapped:
@@ -1251,6 +1368,8 @@ class Mapper(
                     "Mapper." % self.class_
                 )
             self.class_manager = manager
+
+            assert manager.registry is not None
             self.registry = manager.registry
             self._identity_class = manager.mapper._identity_class
             manager.registry._add_non_primary_mapper(self)
@@ -1275,7 +1394,7 @@ class Mapper(
         manager = instrumentation.register_class(
             self.class_,
             mapper=self,
-            expired_attribute_loader=util.partial(
+            expired_attribute_loader=util.partial(  # type: ignore
                 loading.load_scalar_attributes, self
             ),
             # finalize flag means instrument the __init__ method
@@ -1284,6 +1403,8 @@ class Mapper(
         )
 
         self.class_manager = manager
+
+        assert manager.registry is not None
         self.registry = manager.registry
 
         # The remaining members can be added by any mapper,
@@ -1315,15 +1436,25 @@ class Mapper(
                             {name: (method, validation_opts)}
                         )
 
-    def _set_dispose_flags(self):
+    def _set_dispose_flags(self) -> None:
         self.configured = True
         self._ready_for_configure = True
         self._dispose_called = True
 
         self.__dict__.pop("_configure_failed", None)
 
-    def _configure_pks(self):
-        self.tables = sql_util.find_tables(self.persist_selectable)
+    def _configure_pks(self) -> None:
+        self.tables = cast(
+            "List[Table]", sql_util.find_tables(self.persist_selectable)
+        )
+        for t in self.tables:
+            if not isinstance(t, Table):
+                raise sa_exc.ArgumentError(
+                    f"ORM mappings can only be made against schema-level "
+                    f"Table objects, not TableClause; got "
+                    f"tableclause {t.name !r}"
+                )
+        self._all_tables.update(t for t in self.tables if isinstance(t, Table))
 
         self._pks_by_table = {}
         self._cols_by_table = {}
@@ -1335,16 +1466,16 @@ class Mapper(
         pk_cols = util.column_set(c for c in all_cols if c.primary_key)
 
         # identify primary key columns which are also mapped by this mapper.
-        tables = set(self.tables + [self.persist_selectable])
-        self._all_tables.update(tables)
-        for t in tables:
-            if t.primary_key and pk_cols.issuperset(t.primary_key):
+        for fc in set(self.tables).union([self.persist_selectable]):
+            if fc.primary_key and pk_cols.issuperset(fc.primary_key):
                 # ordering is important since it determines the ordering of
                 # mapper.primary_key (and therefore query.get())
-                self._pks_by_table[t] = util.ordered_column_set(
-                    t.primary_key
-                ).intersection(pk_cols)
-            self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(
+                self._pks_by_table[fc] = util.ordered_column_set(  # type: ignore  # noqa: E501
+                    fc.primary_key
+                ).intersection(
+                    pk_cols
+                )
+            self._cols_by_table[fc] = util.ordered_column_set(fc.c).intersection(  # type: ignore  # noqa: E501
                 all_cols
             )
 
@@ -1386,10 +1517,15 @@ class Mapper(
             self.primary_key = self.inherits.primary_key
         else:
             # determine primary key from argument or persist_selectable pks
+            primary_key: Collection[ColumnElement[Any]]
+
             if self._primary_key_argument:
                 primary_key = [
-                    self.persist_selectable.corresponding_column(c)
-                    for c in self._primary_key_argument
+                    cc if cc is not None else c
+                    for cc, c in (
+                        (self.persist_selectable.corresponding_column(c), c)
+                        for c in self._primary_key_argument
+                    )
                 ]
             else:
                 # if heuristically determined PKs, reduce to the minimal set
@@ -1413,7 +1549,7 @@ class Mapper(
 
         # determine cols that aren't expressed within our tables; mark these
         # as "read only" properties which are refreshed upon INSERT/UPDATE
-        self._readonly_props = set(
+        self._readonly_props = {
             self._columntoproperty[col]
             for col in self._columntoproperty
             if self._columntoproperty[col] not in self._identity_key_props
@@ -1421,12 +1557,12 @@ class Mapper(
                 not hasattr(col, "table")
                 or col.table not in self._cols_by_table
             )
-        )
+        }
 
-    def _configure_properties(self):
+    def _configure_properties(self) -> None:
 
         # TODO: consider using DedupeColumnCollection
-        self.columns = self.c = sql_base.ColumnCollection()
+        self.columns = self.c = sql_base.ColumnCollection()  # type: ignore
 
         # object attribute names mapped to MapperProperty objects
         self._props = util.OrderedDict()
@@ -1454,7 +1590,6 @@ class Mapper(
                 continue
 
             column_key = (self.column_prefix or "") + column.key
-
             if self._should_exclude(
                 column.key,
                 column_key,
@@ -1542,6 +1677,7 @@ class Mapper(
                     col = self.polymorphic_on
                     if isinstance(col, schema.Column) and (
                         self.with_polymorphic is None
+                        or self.with_polymorphic[1] is None
                         or self.with_polymorphic[1].corresponding_column(col)
                         is None
                     ):
@@ -1763,8 +1899,8 @@ class Mapper(
             self.columns.add(col, key)
 
             for col in prop.columns + prop._orig_columns:
-                for col in col.proxy_set:
-                    self._columntoproperty[col] = prop
+                for proxy_col in col.proxy_set:
+                    self._columntoproperty[proxy_col] = prop
 
         prop.key = key
 
@@ -2033,7 +2169,9 @@ class Mapper(
         self._check_configure()
         return iter(self._props.values())
 
-    def _mappers_from_spec(self, spec, selectable):
+    def _mappers_from_spec(
+        self, spec: Any, selectable: Optional[FromClause]
+    ) -> Sequence[Mapper[Any]]:
         """given a with_polymorphic() argument, return the set of mappers it
         represents.
 
@@ -2044,7 +2182,7 @@ class Mapper(
         if spec == "*":
             mappers = list(self.self_and_descendants)
         elif spec:
-            mappers = set()
+            mapper_set = set()
             for m in util.to_list(spec):
                 m = _class_to_mapper(m)
                 if not m.isa(self):
@@ -2053,10 +2191,10 @@ class Mapper(
                     )
 
                 if selectable is None:
-                    mappers.update(m.iterate_to_root())
+                    mapper_set.update(m.iterate_to_root())
                 else:
-                    mappers.add(m)
-            mappers = [m for m in self.self_and_descendants if m in mappers]
+                    mapper_set.add(m)
+            mappers = [m for m in self.self_and_descendants if m in mapper_set]
         else:
             mappers = []
 
@@ -2067,7 +2205,9 @@ class Mapper(
             mappers = [m for m in mappers if m.local_table in tables]
         return mappers
 
-    def _selectable_from_mappers(self, mappers, innerjoin):
+    def _selectable_from_mappers(
+        self, mappers: Iterable[Mapper[Any]], innerjoin: bool
+    ) -> FromClause:
         """given a list of mappers (assumed to be within this mapper's
         inheritance hierarchy), construct an outerjoin amongst those mapper's
         mapped tables.
@@ -2098,13 +2238,13 @@ class Mapper(
     def _single_table_criterion(self):
         if self.single and self.inherits and self.polymorphic_on is not None:
             return self.polymorphic_on._annotate({"parentmapper": self}).in_(
-                m.polymorphic_identity for m in self.self_and_descendants
+                [m.polymorphic_identity for m in self.self_and_descendants]
             )
         else:
             return None
 
     @HasMemoized.memoized_attribute
-    def _with_polymorphic_mappers(self):
+    def _with_polymorphic_mappers(self) -> Sequence[Mapper[Any]]:
         self._check_configure()
 
         if not self.with_polymorphic:
@@ -2124,8 +2264,8 @@ class Mapper(
         """
         self._check_configure()
 
-    @HasMemoized.memoized_attribute
-    def _with_polymorphic_selectable(self):
+    @HasMemoized_ro_memoized_attribute
+    def _with_polymorphic_selectable(self) -> FromClause:
         if not self.with_polymorphic:
             return self.persist_selectable
 
@@ -2143,7 +2283,7 @@ class Mapper(
 
     """
 
-    @HasMemoized.memoized_attribute
+    @HasMemoized_ro_memoized_attribute
     def _insert_cols_evaluating_none(self):
         return dict(
             (
@@ -2250,7 +2390,7 @@ class Mapper(
     @HasMemoized.memoized_instancemethod
     def __clause_element__(self):
 
-        annotations = {
+        annotations: Dict[str, Any] = {
             "entity_namespace": self,
             "parententity": self,
             "parentmapper": self,
@@ -2290,7 +2430,7 @@ class Mapper(
         )
 
     @property
-    def selectable(self):
+    def selectable(self) -> FromClause:
         """The :class:`_schema.FromClause` construct this
         :class:`_orm.Mapper` selects from by default.
 
@@ -2302,8 +2442,11 @@ class Mapper(
         return self._with_polymorphic_selectable
 
     def _with_polymorphic_args(
-        self, spec=None, selectable=False, innerjoin=False
-    ):
+        self,
+        spec: Any = None,
+        selectable: Union[Literal[False, None], FromClause] = False,
+        innerjoin: bool = False,
+    ) -> Tuple[Sequence[Mapper[Any]], FromClause]:
         if selectable not in (None, False):
             selectable = coercions.expect(
                 roles.StrictFromClauseRole, selectable, allow_select=True
@@ -2357,7 +2500,7 @@ class Mapper(
         ]
 
     @HasMemoized.memoized_attribute
-    def _polymorphic_adapter(self):
+    def _polymorphic_adapter(self) -> Optional[sql_util.ColumnAdapter]:
         if self.with_polymorphic:
             return sql_util.ColumnAdapter(
                 self.selectable, equivalents=self._equivalent_columns
@@ -2394,7 +2537,7 @@ class Mapper(
                 yield c
 
     @HasMemoized.memoized_attribute
-    def attrs(self) -> util.ReadOnlyProperties["MapperProperty"]:
+    def attrs(self) -> util.ReadOnlyProperties[MapperProperty[Any]]:
         """A namespace of all :class:`.MapperProperty` objects
         associated this mapper.
 
@@ -2432,7 +2575,7 @@ class Mapper(
         return util.ReadOnlyProperties(self._props)
 
     @HasMemoized.memoized_attribute
-    def all_orm_descriptors(self):
+    def all_orm_descriptors(self) -> util.ReadOnlyProperties[InspectionAttr]:
         """A namespace of all :class:`.InspectionAttr` attributes associated
         with the mapped class.
 
@@ -2503,7 +2646,7 @@ class Mapper(
 
     @HasMemoized.memoized_attribute
     @util.preload_module("sqlalchemy.orm.descriptor_props")
-    def synonyms(self):
+    def synonyms(self) -> util.ReadOnlyProperties[Synonym[Any]]:
         """Return a namespace of all :class:`.Synonym`
         properties maintained by this :class:`_orm.Mapper`.
 
@@ -2523,7 +2666,7 @@ class Mapper(
         return self.class_
 
     @HasMemoized.memoized_attribute
-    def column_attrs(self):
+    def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]:
         """Return a namespace of all :class:`.ColumnProperty`
         properties maintained by this :class:`_orm.Mapper`.
 
@@ -2536,9 +2679,9 @@ class Mapper(
         """
         return self._filter_properties(properties.ColumnProperty)
 
-    @util.preload_module("sqlalchemy.orm.relationships")
     @HasMemoized.memoized_attribute
-    def relationships(self):
+    @util.preload_module("sqlalchemy.orm.relationships")
+    def relationships(self) -> util.ReadOnlyProperties[Relationship[Any]]:
         """A namespace of all :class:`.Relationship` properties
         maintained by this :class:`_orm.Mapper`.
 
@@ -2567,7 +2710,7 @@ class Mapper(
 
     @HasMemoized.memoized_attribute
     @util.preload_module("sqlalchemy.orm.descriptor_props")
-    def composites(self):
+    def composites(self) -> util.ReadOnlyProperties[Composite[Any]]:
         """Return a namespace of all :class:`.Composite`
         properties maintained by this :class:`_orm.Mapper`.
 
@@ -2582,7 +2725,9 @@ class Mapper(
             util.preloaded.orm_descriptor_props.Composite
         )
 
-    def _filter_properties(self, type_):
+    def _filter_properties(
+        self, type_: Type[_MP]
+    ) -> util.ReadOnlyProperties[_MP]:
         self._check_configure()
         return util.ReadOnlyProperties(
             util.OrderedDict(
@@ -2610,7 +2755,7 @@ class Mapper(
         )
 
     @HasMemoized.memoized_attribute
-    def _equivalent_columns(self):
+    def _equivalent_columns(self) -> _EquivalentColumnMap:
         """Create a map of all equivalent columns, based on
         the determination of column pairs that are equated to
         one another based on inherit condition.  This is designed
@@ -2630,18 +2775,18 @@ class Mapper(
             }
 
         """
-        result = util.column_dict()
+        result: _EquivalentColumnMap = {}
 
         def visit_binary(binary):
             if binary.operator == operators.eq:
                 if binary.left in result:
                     result[binary.left].add(binary.right)
                 else:
-                    result[binary.left] = util.column_set((binary.right,))
+                    result[binary.left] = {binary.right}
                 if binary.right in result:
                     result[binary.right].add(binary.left)
                 else:
-                    result[binary.right] = util.column_set((binary.left,))
+                    result[binary.right] = {binary.left}
 
         for mapper in self.base_mapper.self_and_descendants:
             if mapper.inherit_condition is not None:
@@ -2711,13 +2856,13 @@ class Mapper(
 
         return False
 
-    def common_parent(self, other):
+    def common_parent(self, other: Mapper[Any]) -> bool:
         """Return true if the given mapper shares a
         common inherited parent as this mapper."""
 
         return self.base_mapper is other.base_mapper
 
-    def is_sibling(self, other):
+    def is_sibling(self, other: Mapper[Any]) -> bool:
         """return true if the other mapper is an inheriting sibling to this
         one.  common parent but different branch
 
@@ -2728,7 +2873,9 @@ class Mapper(
             and not other.isa(self)
         )
 
-    def _canload(self, state, allow_subtypes):
+    def _canload(
+        self, state: InstanceState[Any], allow_subtypes: bool
+    ) -> bool:
         s = self.primary_mapper()
         if self.polymorphic_on is not None or allow_subtypes:
             return _state_mapper(state).isa(s)
@@ -2738,19 +2885,19 @@ class Mapper(
     def isa(self, other: Mapper[Any]) -> bool:
         """Return True if the this mapper inherits from the given mapper."""
 
-        m = self
+        m: Optional[Mapper[Any]] = self
         while m and m is not other:
             m = m.inherits
         return bool(m)
 
-    def iterate_to_root(self):
-        m = self
+    def iterate_to_root(self) -> Iterator[Mapper[Any]]:
+        m: Optional[Mapper[Any]] = self
         while m:
             yield m
             m = m.inherits
 
     @HasMemoized.memoized_attribute
-    def self_and_descendants(self):
+    def self_and_descendants(self) -> Sequence[Mapper[Any]]:
         """The collection including this mapper and all descendant mappers.
 
         This includes not just the immediately inheriting mappers but
@@ -2765,7 +2912,7 @@ class Mapper(
             stack.extend(item._inheriting_mappers)
         return util.WeakSequence(descendants)
 
-    def polymorphic_iterator(self):
+    def polymorphic_iterator(self) -> Iterator[Mapper[Any]]:
         """Iterate through the collection including this mapper and
         all descendant mappers.
 
@@ -2778,18 +2925,18 @@ class Mapper(
         """
         return iter(self.self_and_descendants)
 
-    def primary_mapper(self):
+    def primary_mapper(self) -> Mapper[Any]:
         """Return the primary mapper corresponding to this mapper's class key
         (class)."""
 
         return self.class_manager.mapper
 
     @property
-    def primary_base_mapper(self):
+    def primary_base_mapper(self) -> Mapper[Any]:
         return self.class_manager.mapper.base_mapper
 
     def _result_has_identity_key(self, result, adapter=None):
-        pk_cols = self.primary_key
+        pk_cols: Sequence[ColumnClause[Any]] = self.primary_key
         if adapter:
             pk_cols = [adapter.columns[c] for c in pk_cols]
         rk = result.keys()
@@ -2799,25 +2946,35 @@ class Mapper(
         else:
             return True
 
-    def identity_key_from_row(self, row, identity_token=None, adapter=None):
+    def identity_key_from_row(
+        self,
+        row: Optional[Union[Row, RowMapping]],
+        identity_token: Optional[Any] = None,
+        adapter: Optional[ColumnAdapter] = None,
+    ) -> _IdentityKeyType[_O]:
         """Return an identity-map key for use in storing/retrieving an
         item from the identity map.
 
-        :param row: A :class:`.Row` instance.  The columns which are
-         mapped by this :class:`_orm.Mapper` should be locatable in the row,
-         preferably via the :class:`_schema.Column`
-         object directly (as is the case
-         when a :func:`_expression.select` construct is executed), or
-         via string names of the form ``<tablename>_<colname>``.
+        :param row: A :class:`.Row` or :class:`.RowMapping` produced from a
+         result set that selected from the ORM mapped primary key columns.
+
+         .. versionchanged:: 2.0
+            :class:`.Row` or :class:`.RowMapping` are accepted
+            for the "row" argument
 
         """
-        pk_cols = self.primary_key
+        pk_cols: Sequence[ColumnClause[Any]] = self.primary_key
         if adapter:
             pk_cols = [adapter.columns[c] for c in pk_cols]
 
+        if hasattr(row, "_mapping"):
+            mapping = row._mapping  # type: ignore
+        else:
+            mapping = cast("Mapping[Any, Any]", row)
+
         return (
             self._identity_class,
-            tuple(row[column] for column in pk_cols),
+            tuple(mapping[column] for column in pk_cols),  # type: ignore
             identity_token,
         )
 
@@ -2852,12 +3009,12 @@ class Mapper(
 
         """
         state = attributes.instance_state(instance)
-        return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
+        return self._identity_key_from_state(state, PassiveFlag.PASSIVE_OFF)
 
     def _identity_key_from_state(
         self,
         state: InstanceState[_O],
-        passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE,
     ) -> _IdentityKeyType[_O]:
         dict_ = state.dict
         manager = state.manager
@@ -2884,7 +3041,7 @@ class Mapper(
         """
         state = attributes.instance_state(instance)
         identity_key = self._identity_key_from_state(
-            state, attributes.PASSIVE_OFF
+            state, PassiveFlag.PASSIVE_OFF
         )
         return identity_key[1]
 
@@ -2913,14 +3070,14 @@ class Mapper(
 
     @HasMemoized.memoized_attribute
     def _all_pk_cols(self):
-        collection = set()
+        collection: Set[ColumnClause[Any]] = set()
         for table in self.tables:
             collection.update(self._pks_by_table[table])
         return collection
 
     @HasMemoized.memoized_attribute
     def _should_undefer_in_wildcard(self):
-        cols = set(self.primary_key)
+        cols: Set[ColumnElement[Any]] = set(self.primary_key)
         if self.polymorphic_on is not None:
             cols.add(self.polymorphic_on)
         return cols
@@ -2951,11 +3108,11 @@ class Mapper(
         state = attributes.instance_state(obj)
         dict_ = attributes.instance_dict(obj)
         return self._get_committed_state_attr_by_column(
-            state, dict_, column, passive=attributes.PASSIVE_OFF
+            state, dict_, column, passive=PassiveFlag.PASSIVE_OFF
         )
 
     def _get_committed_state_attr_by_column(
-        self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
+        self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE
     ):
 
         prop = self._columntoproperty[column]
@@ -2978,7 +3135,7 @@ class Mapper(
         col_attribute_names = set(attribute_names).intersection(
             state.mapper.column_attrs.keys()
         )
-        tables = set(
+        tables: Set[FromClause] = set(
             chain(
                 *[
                     sql_util.find_tables(c, check_columns=True)
@@ -3002,7 +3159,7 @@ class Mapper(
                     state,
                     state.dict,
                     leftcol,
-                    passive=attributes.PASSIVE_NO_INITIALIZE,
+                    passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
                 )
                 if leftval in orm_util._none_set:
                     raise _OptGetColumnsNotAvailable()
@@ -3014,7 +3171,7 @@ class Mapper(
                     state,
                     state.dict,
                     rightcol,
-                    passive=attributes.PASSIVE_NO_INITIALIZE,
+                    passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
                 )
                 if rightval in orm_util._none_set:
                     raise _OptGetColumnsNotAvailable()
@@ -3022,7 +3179,7 @@ class Mapper(
                     None, rightval, type_=binary.right.type
                 )
 
-        allconds = []
+        allconds: List[ColumnElement[bool]] = []
 
         start = False
 
@@ -3035,6 +3192,9 @@ class Mapper(
             elif not isinstance(mapper.local_table, expression.TableClause):
                 return None
             if start and not mapper.single:
+                assert mapper.inherits
+                assert not mapper.concrete
+                assert mapper.inherit_condition is not None
                 allconds.append(mapper.inherit_condition)
                 tables.add(mapper.local_table)
 
@@ -3043,11 +3203,13 @@ class Mapper(
         # descendant-most class should all be present and joined to each
         # other.
         try:
-            allconds[0] = visitors.cloned_traverse(
+            _traversed = visitors.cloned_traverse(
                 allconds[0], {}, {"binary": visit_binary}
             )
         except _OptGetColumnsNotAvailable:
             return None
+        else:
+            allconds[0] = _traversed
 
         cond = sql.and_(*allconds)
 
@@ -3145,6 +3307,8 @@ class Mapper(
             for pk in self.primary_key
         ]
 
+        in_expr: ColumnElement[Any]
+
         if len(primary_key) > 1:
             in_expr = sql.tuple_(*primary_key)
         else:
@@ -3209,11 +3373,22 @@ class Mapper(
             traverse all objects without relying on cascades.
 
         """
-        visited_states = set()
+        visited_states: Set[InstanceState[Any]] = set()
         prp, mpp = object(), object()
 
         assert state.mapper.isa(self)
 
+        # this is actually a recursive structure, fully typing it seems
+        # a little too difficult for what it's worth here
+        visitables: Deque[
+            Tuple[
+                Deque[Any],
+                object,
+                Optional[InstanceState[Any]],
+                Optional[_InstanceDict],
+            ]
+        ]
+
         visitables = deque(
             [(deque(state.mapper._props.values()), prp, state, state.dict)]
         )
@@ -3226,8 +3401,10 @@ class Mapper(
 
             if item_type is prp:
                 prop = iterator.popleft()
-                if type_ not in prop.cascade:
+                if not prop.cascade or type_ not in prop.cascade:
                     continue
+                assert parent_state is not None
+                assert parent_dict is not None
                 queue = deque(
                     prop.cascade_iterator(
                         type_,
@@ -3267,7 +3444,7 @@ class Mapper(
 
     @HasMemoized.memoized_attribute
     def _sorted_tables(self):
-        table_to_mapper = {}
+        table_to_mapper: Dict[Table, Mapper[Any]] = {}
 
         for mapper in self.base_mapper.self_and_descendants:
             for t in mapper.tables:
@@ -3316,9 +3493,9 @@ class Mapper(
             ret[t] = table_to_mapper[t]
         return ret
 
-    def _memo(self, key, callable_):
+    def _memo(self, key: Any, callable_: Callable[[], _T]) -> _T:
         if key in self._memoized_values:
-            return self._memoized_values[key]
+            return cast(_T, self._memoized_values[key])
         else:
             self._memoized_values[key] = value = callable_()
             return value
@@ -3328,14 +3505,22 @@ class Mapper(
         """memoized map of tables to collections of columns to be
         synchronized upwards to the base mapper."""
 
-        result = util.defaultdict(list)
+        result: util.defaultdict[
+            Table,
+            List[
+                Tuple[
+                    Mapper[Any],
+                    List[Tuple[ColumnElement[Any], ColumnElement[Any]]],
+                ]
+            ],
+        ] = util.defaultdict(list)
 
         for table in self._sorted_tables:
             cols = set(table.c)
             for m in self.iterate_to_root():
                 if m._inherits_equated_pairs and cols.intersection(
                     reduce(
-                        set.union,
+                        set.union,  # type: ignore
                         [l.proxy_set for l, r in m._inherits_equated_pairs],
                     )
                 ):
@@ -3440,7 +3625,7 @@ def _configure_registries(registries, cascade):
             else:
                 return
 
-            Mapper.dispatch._for_class(Mapper).before_configured()
+            Mapper.dispatch._for_class(Mapper).before_configured()  # type: ignore # noqa: E501
             # initialize properties on all mappers
             # note that _mapper_registry is unordered, which
             # may randomly conceal/reveal issues related to
@@ -3449,7 +3634,7 @@ def _configure_registries(registries, cascade):
             _do_configure_registries(registries, cascade)
         finally:
             _already_compiling = False
-    Mapper.dispatch._for_class(Mapper).after_configured()
+    Mapper.dispatch._for_class(Mapper).after_configured()  # type: ignore
 
 
 @util.preload_module("sqlalchemy.orm.decl_api")
@@ -3480,7 +3665,7 @@ def _do_configure_registries(registries, cascade):
                     "Original exception was: %s"
                     % (mapper, mapper._configure_failed)
                 )
-                e._configure_failed = mapper._configure_failed
+                e._configure_failed = mapper._configure_failed  # type: ignore
                 raise e
 
             if not mapper.configured:
@@ -3636,7 +3821,7 @@ def _event_on_init(state, args, kwargs):
             instrumenting_mapper._set_polymorphic_identity(state)
 
 
-class _ColumnMapping(dict):
+class _ColumnMapping(Dict["ColumnElement[Any]", "MapperProperty[Any]"]):
     """Error reporting helper for mapper._columntoproperty."""
 
     __slots__ = ("mapper",)
index e2cf1d5b04298c7f8ab6d76fa76f0eef3611d70f..361cea9757731aa326f440125539949b700d11e1 100644 (file)
@@ -13,22 +13,70 @@ from __future__ import annotations
 from functools import reduce
 from itertools import chain
 import logging
+import operator
 from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Tuple
+from typing import TYPE_CHECKING
 from typing import Union
 
 from . import base as orm_base
+from ._typing import insp_is_mapper_property
 from .. import exc
-from .. import inspection
 from .. import util
 from ..sql import visitors
 from ..sql.cache_key import HasCacheKey
 
+if TYPE_CHECKING:
+    from ._typing import _InternalEntityType
+    from .interfaces import MapperProperty
+    from .mapper import Mapper
+    from .relationships import Relationship
+    from .util import AliasedInsp
+    from ..sql.cache_key import _CacheKeyTraversalType
+    from ..sql.elements import BindParameter
+    from ..sql.visitors import anon_map
+    from ..util.typing import TypeGuard
+
+    def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]:
+        ...
+
+    def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]:
+        ...
+
+else:
+    is_root = operator.attrgetter("is_root")
+    is_entity = operator.attrgetter("is_entity")
+
+
+_SerializedPath = List[Any]
+
+_PathElementType = Union[
+    str, "_InternalEntityType[Any]", "MapperProperty[Any]"
+]
+
+# the representation is in fact
+# a tuple with alternating:
+# [_InternalEntityType[Any], Union[str, MapperProperty[Any]],
+# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...]
+# this might someday be a tuple of 2-tuples instead, but paths can be
+# chopped at odd intervals as well so this is less flexible
+_PathRepresentation = Tuple[_PathElementType, ...]
+
+_OddPathRepresentation = Sequence["_InternalEntityType[Any]"]
+_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]]
+
+
 log = logging.getLogger(__name__)
 
 
-def _unreduce_path(path):
+def _unreduce_path(path: _SerializedPath) -> PathRegistry:
     return PathRegistry.deserialize(path)
 
 
@@ -67,17 +115,18 @@ class PathRegistry(HasCacheKey):
     is_token = False
     is_root = False
     has_entity = False
+    is_entity = False
 
-    path: Tuple
-    natural_path: Tuple
-    parent: Union["PathRegistry", None]
+    path: _PathRepresentation
+    natural_path: _PathRepresentation
+    parent: Optional[PathRegistry]
+    root: RootRegistry
 
-    root: "PathRegistry"
-    _cache_key_traversal = [
+    _cache_key_traversal: _CacheKeyTraversalType = [
         ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
     ]
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         try:
             return other is not None and self.path == other._path_for_compare
         except AttributeError:
@@ -87,7 +136,7 @@ class PathRegistry(HasCacheKey):
             )
             return False
 
-    def __ne__(self, other):
+    def __ne__(self, other: Any) -> bool:
         try:
             return other is None or self.path != other._path_for_compare
         except AttributeError:
@@ -98,74 +147,88 @@ class PathRegistry(HasCacheKey):
             return True
 
     @property
-    def _path_for_compare(self):
+    def _path_for_compare(self) -> Optional[_PathRepresentation]:
         return self.path
 
-    def set(self, attributes, key, value):
+    def set(self, attributes: Dict[Any, Any], key: Any, value: Any) -> None:
         log.debug("set '%s' on path '%s' to '%s'", key, self, value)
         attributes[(key, self.natural_path)] = value
 
-    def setdefault(self, attributes, key, value):
+    def setdefault(
+        self, attributes: Dict[Any, Any], key: Any, value: Any
+    ) -> None:
         log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value)
         attributes.setdefault((key, self.natural_path), value)
 
-    def get(self, attributes, key, value=None):
+    def get(
+        self, attributes: Dict[Any, Any], key: Any, value: Optional[Any] = None
+    ) -> Any:
         key = (key, self.natural_path)
         if key in attributes:
             return attributes[key]
         else:
             return value
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.path)
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return id(self)
 
-    def __getitem__(self, key: Any) -> "PathRegistry":
+    def __getitem__(self, key: Any) -> PathRegistry:
         raise NotImplementedError()
 
+    # TODO: what are we using this for?
     @property
-    def length(self):
+    def length(self) -> int:
         return len(self.path)
 
-    def pairs(self):
-        path = self.path
-        for i in range(0, len(path), 2):
-            yield path[i], path[i + 1]
-
-    def contains_mapper(self, mapper):
-        for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]:
+    def pairs(
+        self,
+    ) -> Iterator[
+        Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]]
+    ]:
+        odd_path = cast(_OddPathRepresentation, self.path)
+        even_path = cast(_EvenPathRepresentation, odd_path)
+        for i in range(0, len(odd_path), 2):
+            yield odd_path[i], even_path[i + 1]
+
+    def contains_mapper(self, mapper: Mapper[Any]) -> bool:
+        _m_path = cast(_OddPathRepresentation, self.path)
+        for path_mapper in [_m_path[i] for i in range(0, len(_m_path), 2)]:
             if path_mapper.is_mapper and path_mapper.isa(mapper):
                 return True
         else:
             return False
 
-    def contains(self, attributes, key):
+    def contains(self, attributes: Dict[Any, Any], key: Any) -> bool:
         return (key, self.path) in attributes
 
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         return _unreduce_path, (self.serialize(),)
 
     @classmethod
-    def _serialize_path(cls, path):
+    def _serialize_path(cls, path: _PathRepresentation) -> _SerializedPath:
+        _m_path = cast(_OddPathRepresentation, path)
+        _p_path = cast(_EvenPathRepresentation, path)
+
         return list(
             zip(
-                [
+                tuple(
                     m.class_ if (m.is_mapper or m.is_aliased_class) else str(m)
-                    for m in [path[i] for i in range(0, len(path), 2)]
-                ],
-                [
-                    path[i].key if (path[i].is_property) else str(path[i])
-                    for i in range(1, len(path), 2)
-                ]
-                + [None],
+                    for m in [_m_path[i] for i in range(0, len(_m_path), 2)]
+                ),
+                tuple(
+                    p.key if insp_is_mapper_property(p) else str(p)
+                    for p in [_p_path[i] for i in range(1, len(_p_path), 2)]
+                )
+                + (None,),
             )
         )
 
     @classmethod
-    def _deserialize_path(cls, path):
-        def _deserialize_mapper_token(mcls):
+    def _deserialize_path(cls, path: _SerializedPath) -> _PathRepresentation:
+        def _deserialize_mapper_token(mcls: Any) -> Any:
             return (
                 # note: we likely dont want configure=True here however
                 # this is maintained at the moment for backwards compatibility
@@ -174,15 +237,15 @@ class PathRegistry(HasCacheKey):
                 else PathToken._intern[mcls]
             )
 
-        def _deserialize_key_token(mcls, key):
+        def _deserialize_key_token(mcls: Any, key: Any) -> Any:
             if key is None:
                 return None
             elif key in PathToken._intern:
                 return PathToken._intern[key]
             else:
-                return orm_base._inspect_mapped_class(
-                    mcls, configure=True
-                ).attrs[key]
+                mp = orm_base._inspect_mapped_class(mcls, configure=True)
+                assert mp is not None
+                return mp.attrs[key]
 
         p = tuple(
             chain(
@@ -199,28 +262,63 @@ class PathRegistry(HasCacheKey):
             p = p[0:-1]
         return p
 
-    def serialize(self) -> Sequence[Any]:
+    def serialize(self) -> _SerializedPath:
         path = self.path
         return self._serialize_path(path)
 
     @classmethod
-    def deserialize(cls, path: Sequence[Any]) -> PathRegistry:
+    def deserialize(cls, path: _SerializedPath) -> PathRegistry:
         assert path is not None
         p = cls._deserialize_path(path)
         return cls.coerce(p)
 
+    @overload
     @classmethod
-    def per_mapper(cls, mapper):
+    def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry:
+        ...
+
+    @overload
+    @classmethod
+    def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry:
+        ...
+
+    @classmethod
+    def per_mapper(
+        cls, mapper: _InternalEntityType[Any]
+    ) -> AbstractEntityRegistry:
         if mapper.is_mapper:
             return CachingEntityRegistry(cls.root, mapper)
         else:
             return SlotsEntityRegistry(cls.root, mapper)
 
     @classmethod
-    def coerce(cls, raw):
-        return reduce(lambda prev, next: prev[next], raw, cls.root)
+    def coerce(cls, raw: _PathRepresentation) -> PathRegistry:
+        def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry:
+            return prev[next_]
+
+        # can't quite get mypy to appreciate this one :)
+        return reduce(_red, raw, cls.root)  # type: ignore
+
+    def __add__(self, other: PathRegistry) -> PathRegistry:
+        def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry:
+            return prev[next_]
 
-    def token(self, token):
+        return reduce(_red, other.path, self)
+
+    def __str__(self) -> str:
+        return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]"
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.path!r})"
+
+
+class CreatesToken(PathRegistry):
+    __slots__ = ()
+
+    is_aliased_class: bool
+    is_root: bool
+
+    def token(self, token: str) -> TokenRegistry:
         if token.endswith(f":{_WILDCARD_TOKEN}"):
             return TokenRegistry(self, token)
         elif token.endswith(f":{_DEFAULT_TOKEN}"):
@@ -228,34 +326,47 @@ class PathRegistry(HasCacheKey):
         else:
             raise exc.ArgumentError(f"invalid token: {token}")
 
-    def __add__(self, other):
-        return reduce(lambda prev, next: prev[next], other.path, self)
-
-    def __str__(self):
-        return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]"
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.path!r})"
-
 
-class RootRegistry(PathRegistry):
+class RootRegistry(CreatesToken):
     """Root registry, defers to mappers so that
     paths are maintained per-root-mapper.
 
     """
 
+    __slots__ = ()
+
     inherit_cache = True
 
     path = natural_path = ()
     has_entity = False
     is_aliased_class = False
     is_root = True
+    is_unnatural = False
+
+    @overload
+    def __getitem__(self, entity: str) -> TokenRegistry:
+        ...
+
+    @overload
+    def __getitem__(
+        self, entity: _InternalEntityType[Any]
+    ) -> AbstractEntityRegistry:
+        ...
 
-    def __getitem__(self, entity):
+    def __getitem__(
+        self, entity: Union[str, _InternalEntityType[Any]]
+    ) -> Union[TokenRegistry, AbstractEntityRegistry]:
         if entity in PathToken._intern:
+            if TYPE_CHECKING:
+                assert isinstance(entity, str)
             return TokenRegistry(self, PathToken._intern[entity])
         else:
-            return inspection.inspect(entity)._path_registry
+            try:
+                return entity._path_registry  # type: ignore
+            except AttributeError:
+                raise IndexError(
+                    f"invalid argument for RootRegistry.__getitem__: {entity}"
+                )
 
 
 PathRegistry.root = RootRegistry()
@@ -264,17 +375,19 @@ PathRegistry.root = RootRegistry()
 class PathToken(orm_base.InspectionAttr, HasCacheKey, str):
     """cacheable string token"""
 
-    _intern = {}
+    _intern: Dict[str, PathToken] = {}
 
-    def _gen_cache_key(self, anon_map, bindparams):
+    def _gen_cache_key(
+        self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+    ) -> Tuple[Any, ...]:
         return (str(self),)
 
     @property
-    def _path_for_compare(self):
+    def _path_for_compare(self) -> Optional[_PathRepresentation]:
         return None
 
     @classmethod
-    def intern(cls, strvalue):
+    def intern(cls, strvalue: str) -> PathToken:
         if strvalue in cls._intern:
             return cls._intern[strvalue]
         else:
@@ -287,7 +400,10 @@ class TokenRegistry(PathRegistry):
 
     inherit_cache = True
 
-    def __init__(self, parent, token):
+    token: str
+    parent: CreatesToken
+
+    def __init__(self, parent: CreatesToken, token: str):
         token = PathToken.intern(token)
 
         self.token = token
@@ -299,21 +415,33 @@ class TokenRegistry(PathRegistry):
 
     is_token = True
 
-    def generate_for_superclasses(self):
-        if not self.parent.is_aliased_class and not self.parent.is_root:
-            for ent in self.parent.mapper.iterate_to_root():
-                yield TokenRegistry(self.parent.parent[ent], self.token)
+    def generate_for_superclasses(self) -> Iterator[PathRegistry]:
+        parent = self.parent
+        if is_root(parent):
+            yield self
+            return
+
+        if TYPE_CHECKING:
+            assert isinstance(parent, AbstractEntityRegistry)
+        if not parent.is_aliased_class:
+            for mp_ent in parent.mapper.iterate_to_root():
+                yield TokenRegistry(parent.parent[mp_ent], self.token)
         elif (
-            self.parent.is_aliased_class
-            and self.parent.entity._is_with_polymorphic
+            parent.is_aliased_class
+            and cast(
+                "AliasedInsp[Any]",
+                parent.entity,
+            )._is_with_polymorphic
         ):
             yield self
-            for ent in self.parent.entity._with_polymorphic_entities:
-                yield TokenRegistry(self.parent.parent[ent], self.token)
+            for ent in cast(
+                "AliasedInsp[Any]", parent.entity
+            )._with_polymorphic_entities:
+                yield TokenRegistry(parent.parent[ent], self.token)
         else:
             yield self
 
-    def __getitem__(self, entity):
+    def __getitem__(self, entity: Any) -> Any:
         try:
             return self.path[entity]
         except TypeError as err:
@@ -321,23 +449,42 @@ class TokenRegistry(PathRegistry):
 
 
 class PropRegistry(PathRegistry):
-    is_unnatural = False
+    __slots__ = (
+        "prop",
+        "parent",
+        "path",
+        "natural_path",
+        "has_entity",
+        "entity",
+        "mapper",
+        "_wildcard_path_loader_key",
+        "_default_path_loader_key",
+        "_loader_key",
+        "is_unnatural",
+    )
     inherit_cache = True
 
-    def __init__(self, parent, prop):
+    prop: MapperProperty[Any]
+    mapper: Optional[Mapper[Any]]
+    entity: Optional[_InternalEntityType[Any]]
+
+    def __init__(
+        self, parent: AbstractEntityRegistry, prop: MapperProperty[Any]
+    ):
         # restate this path in terms of the
         # given MapperProperty's parent.
-        insp = inspection.inspect(parent[-1])
-        natural_parent = parent
+        insp = cast("_InternalEntityType[Any]", parent[-1])
+        natural_parent: AbstractEntityRegistry = parent
+        self.is_unnatural = False
 
-        if not insp.is_aliased_class or insp._use_mapper_path:
+        if not insp.is_aliased_class or insp._use_mapper_path:  # type: ignore
             parent = natural_parent = parent.parent[prop.parent]
         elif (
             insp.is_aliased_class
             and insp.with_polymorphic_mappers
             and prop.parent in insp.with_polymorphic_mappers
         ):
-            subclass_entity = parent[-1]._entity_for_mapper(prop.parent)
+            subclass_entity: _InternalEntityType[Any] = parent[-1]._entity_for_mapper(prop.parent)  # type: ignore  # noqa: E501
             parent = parent.parent[subclass_entity]
 
             # when building a path where with_polymorphic() is in use,
@@ -388,43 +535,74 @@ class PropRegistry(PathRegistry):
         self.parent = parent
         self.path = parent.path + (prop,)
         self.natural_path = natural_parent.natural_path + (prop,)
+        self.has_entity = prop._links_to_entity
+        if prop._is_relationship:
+            if TYPE_CHECKING:
+                assert isinstance(prop, Relationship)
+            self.entity = prop.entity
+            self.mapper = prop.mapper
+        else:
+            self.entity = None
+            self.mapper = None
 
         self._wildcard_path_loader_key = (
             "loader",
-            parent.path + self.prop._wildcard_token,
+            parent.path + self.prop._wildcard_token,  # type: ignore
         )
         self._default_path_loader_key = self.prop._default_path_loader_key
         self._loader_key = ("loader", self.natural_path)
 
-    @util.memoized_property
-    def has_entity(self):
-        return self.prop._links_to_entity
+    @property
+    def entity_path(self) -> AbstractEntityRegistry:
+        assert self.entity is not None
+        return self[self.entity]
 
-    @util.memoized_property
-    def entity(self):
-        return self.prop.entity
+    @overload
+    def __getitem__(self, entity: slice) -> _PathRepresentation:
+        ...
 
-    @property
-    def mapper(self):
-        return self.prop.mapper
+    @overload
+    def __getitem__(self, entity: int) -> _PathElementType:
+        ...
 
-    @property
-    def entity_path(self):
-        return self[self.entity]
+    @overload
+    def __getitem__(
+        self, entity: _InternalEntityType[Any]
+    ) -> AbstractEntityRegistry:
+        ...
 
-    def __getitem__(self, entity):
+    def __getitem__(
+        self, entity: Union[int, slice, _InternalEntityType[Any]]
+    ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]:
         if isinstance(entity, (int, slice)):
             return self.path[entity]
         else:
             return SlotsEntityRegistry(self, entity)
 
 
-class AbstractEntityRegistry(PathRegistry):
-    __slots__ = ()
+class AbstractEntityRegistry(CreatesToken):
+    __slots__ = (
+        "key",
+        "parent",
+        "is_aliased_class",
+        "path",
+        "entity",
+        "natural_path",
+    )
 
     has_entity = True
-
-    def __init__(self, parent, entity):
+    is_entity = True
+
+    parent: Union[RootRegistry, PropRegistry]
+    key: _InternalEntityType[Any]
+    entity: _InternalEntityType[Any]
+    is_aliased_class: bool
+
+    def __init__(
+        self,
+        parent: Union[RootRegistry, PropRegistry],
+        entity: _InternalEntityType[Any],
+    ):
         self.key = entity
         self.parent = parent
         self.is_aliased_class = entity.is_aliased_class
@@ -447,11 +625,11 @@ class AbstractEntityRegistry(PathRegistry):
         if parent.path and (self.is_aliased_class or parent.is_unnatural):
             # this is an infrequent code path used only for loader strategies
             # that also make use of of_type().
-            if entity.mapper.isa(parent.natural_path[-1].entity):
+            if entity.mapper.isa(parent.natural_path[-1].entity):  # type: ignore # noqa: E501
                 self.natural_path = parent.natural_path + (entity.mapper,)
             else:
                 self.natural_path = parent.natural_path + (
-                    parent.natural_path[-1].entity,
+                    parent.natural_path[-1].entity,  # type: ignore
                 )
         # it seems to make sense that since these paths get mixed up
         # with statements that are cached or not, we should make
@@ -465,19 +643,35 @@ class AbstractEntityRegistry(PathRegistry):
             self.natural_path = self.path
 
     @property
-    def entity_path(self):
+    def entity_path(self) -> PathRegistry:
         return self
 
     @property
-    def mapper(self):
-        return inspection.inspect(self.entity).mapper
+    def mapper(self) -> Mapper[Any]:
+        return self.entity.mapper
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return True
 
-    __nonzero__ = __bool__
+    @overload
+    def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
+        ...
+
+    @overload
+    def __getitem__(self, entity: str) -> TokenRegistry:
+        ...
+
+    @overload
+    def __getitem__(self, entity: int) -> _PathElementType:
+        ...
 
-    def __getitem__(self, entity):
+    @overload
+    def __getitem__(self, entity: slice) -> _PathRepresentation:
+        ...
+
+    def __getitem__(
+        self, entity: Any
+    ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]:
         if isinstance(entity, (int, slice)):
             return self.path[entity]
         elif entity in PathToken._intern:
@@ -491,31 +685,40 @@ class SlotsEntityRegistry(AbstractEntityRegistry):
     # version
     inherit_cache = True
 
-    __slots__ = (
-        "key",
-        "parent",
-        "is_aliased_class",
-        "entity",
-        "path",
-        "natural_path",
-    )
+
+class _ERDict(Dict[Any, Any]):
+    def __init__(self, registry: CachingEntityRegistry):
+        self.registry = registry
+
+    def __missing__(self, key: Any) -> PropRegistry:
+        self[key] = item = PropRegistry(self.registry, key)
+
+        return item
 
 
-class CachingEntityRegistry(AbstractEntityRegistry, dict):
+class CachingEntityRegistry(AbstractEntityRegistry):
     # for long lived mapper, return dict based caching
     # version that creates reference cycles
 
+    __slots__ = ("_cache",)
+
     inherit_cache = True
 
-    def __getitem__(self, entity):
+    def __init__(
+        self,
+        parent: Union[RootRegistry, PropRegistry],
+        entity: _InternalEntityType[Any],
+    ):
+        super().__init__(parent, entity)
+        self._cache = _ERDict(self)
+
+    def pop(self, key: Any, default: Any) -> Any:
+        return self._cache.pop(key, default)
+
+    def __getitem__(self, entity: Any) -> Any:
         if isinstance(entity, (int, slice)):
             return self.path[entity]
         elif isinstance(entity, PathToken):
             return TokenRegistry(self, entity)
         else:
-            return dict.__getitem__(self, entity)
-
-    def __missing__(self, key):
-        self[key] = item = PropRegistry(self, key)
-
-        return item
+            return self._cache[entity]
index c01825b6d679b8e806fc76d0cb99ddf21f11b816..9f37e84571e929e631d0b5206f5135899f4b600f 100644 (file)
@@ -19,6 +19,8 @@ from typing import cast
 from typing import List
 from typing import Optional
 from typing import Set
+from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 
 from . import attributes
@@ -38,17 +40,22 @@ from .util import _orm_full_deannotate
 from .. import exc as sa_exc
 from .. import ForeignKey
 from .. import log
-from .. import sql
 from .. import util
 from ..sql import coercions
 from ..sql import roles
 from ..sql import sqltypes
 from ..sql.schema import Column
+from ..sql.schema import SchemaConst
 from ..util.typing import de_optionalize_union_types
 from ..util.typing import de_stringify_annotation
 from ..util.typing import is_fwd_ref
 from ..util.typing import NoneType
 
+if TYPE_CHECKING:
+    from ._typing import _ORMColumnExprArgument
+    from ..sql._typing import _InfoType
+    from ..sql.elements import ColumnElement
+
 _T = TypeVar("_T", bound=Any)
 _PT = TypeVar("_PT", bound=Any)
 
@@ -78,6 +85,10 @@ class ColumnProperty(
     inherit_cache = True
     _links_to_entity = False
 
+    columns: List[ColumnElement[Any]]
+
+    _is_polymorphic_discriminator: bool
+
     __slots__ = (
         "_orig_columns",
         "columns",
@@ -99,7 +110,19 @@ class ColumnProperty(
     )
 
     def __init__(
-        self, column: sql.ColumnElement[_T], *additional_columns, **kwargs
+        self,
+        column: _ORMColumnExprArgument[_T],
+        *additional_columns: _ORMColumnExprArgument[Any],
+        group: Optional[str] = None,
+        deferred: bool = False,
+        raiseload: bool = False,
+        comparator_factory: Optional[Type[PropComparator]] = None,
+        descriptor: Optional[Any] = None,
+        active_history: bool = False,
+        expire_on_flush: bool = True,
+        info: Optional[_InfoType] = None,
+        doc: Optional[str] = None,
+        _instrument: bool = True,
     ):
         super(ColumnProperty, self).__init__()
         columns = (column,) + additional_columns
@@ -112,23 +135,24 @@ class ColumnProperty(
             )
             for c in columns
         ]
-        self.parent = self.key = None
-        self.group = kwargs.pop("group", None)
-        self.deferred = kwargs.pop("deferred", False)
-        self.raiseload = kwargs.pop("raiseload", False)
-        self.instrument = kwargs.pop("_instrument", True)
-        self.comparator_factory = kwargs.pop(
-            "comparator_factory", self.__class__.Comparator
+        self.group = group
+        self.deferred = deferred
+        self.raiseload = raiseload
+        self.instrument = _instrument
+        self.comparator_factory = (
+            comparator_factory
+            if comparator_factory is not None
+            else self.__class__.Comparator
         )
-        self.descriptor = kwargs.pop("descriptor", None)
-        self.active_history = kwargs.pop("active_history", False)
-        self.expire_on_flush = kwargs.pop("expire_on_flush", True)
+        self.descriptor = descriptor
+        self.active_history = active_history
+        self.expire_on_flush = expire_on_flush
 
-        if "info" in kwargs:
-            self.info = kwargs.pop("info")
+        if info is not None:
+            self.info = info
 
-        if "doc" in kwargs:
-            self.doc = kwargs.pop("doc")
+        if doc is not None:
+            self.doc = doc
         else:
             for col in reversed(self.columns):
                 doc = getattr(col, "doc", None)
@@ -138,12 +162,6 @@ class ColumnProperty(
             else:
                 self.doc = None
 
-        if kwargs:
-            raise TypeError(
-                "%s received unexpected keyword argument(s): %s"
-                % (self.__class__.__name__, ", ".join(sorted(kwargs.keys())))
-            )
-
         util.set_creation_order(self)
 
         self.strategy_key = (
@@ -445,7 +463,10 @@ class MappedColumn(
         self.deferred = kw.pop("deferred", False)
         self.column = cast("Column[_T]", Column(*arg, **kw))
         self.foreign_keys = self.column.foreign_keys
-        self._has_nullable = "nullable" in kw
+        self._has_nullable = "nullable" in kw and kw.get("nullable") not in (
+            None,
+            SchemaConst.NULL_UNSPECIFIED,
+        )
         util.set_creation_order(self)
 
     def _copy(self, **kw):
index a754bd4f2a6435bf86da32aa68adab22d0106f13..395d01a1eac66cc3eef369498481630943e3a039 100644 (file)
@@ -30,6 +30,7 @@ from typing import Optional
 from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from . import exc as orm_exc
 from . import interfaces
@@ -77,6 +78,8 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 
 if TYPE_CHECKING:
     from ..sql.selectable import _SetupJoinsElement
+    from ..sql.selectable import Alias
+    from ..sql.selectable import Subquery
 
 __all__ = ["Query", "QueryContext"]
 
@@ -2769,14 +2772,14 @@ class AliasOption(interfaces.LoaderOption):
         "for entities to be matched up to a query that is established "
         "via :meth:`.Query.from_statement` and now does nothing.",
     )
-    def __init__(self, alias):
+    def __init__(self, alias: Union[Alias, Subquery]):
         r"""Return a :class:`.MapperOption` that will indicate to the
         :class:`_query.Query`
         that the main table has been aliased.
 
         """
 
-    def process_compile_state(self, compile_state):
+    def process_compile_state(self, compile_state: ORMCompileState):
         pass
 
 
index 58c7c4efd546984d24a8e1ec138103f00a66747d..66021c9c2028f6d693221330a66b183d46f2491c 100644 (file)
@@ -21,7 +21,10 @@ import re
 import typing
 from typing import Any
 from typing import Callable
+from typing import Dict
 from typing import Optional
+from typing import Sequence
+from typing import Tuple
 from typing import Type
 from typing import TypeVar
 from typing import Union
@@ -30,6 +33,7 @@ import weakref
 from . import attributes
 from . import strategy_options
 from .base import _is_mapped_class
+from .base import class_mapper
 from .base import state_str
 from .interfaces import _IntrospectsAnnotations
 from .interfaces import MANYTOMANY
@@ -53,7 +57,9 @@ from ..sql import expression
 from ..sql import operators
 from ..sql import roles
 from ..sql import visitors
-from ..sql.elements import SQLCoreOperations
+from ..sql._typing import _ColumnExpressionArgument
+from ..sql._typing import _HasClauseElement
+from ..sql.elements import ColumnClause
 from ..sql.util import _deep_deannotate
 from ..sql.util import _shallow_annotate
 from ..sql.util import adapt_criterion_to_null
@@ -61,11 +67,14 @@ from ..sql.util import ClauseAdapter
 from ..sql.util import join_condition
 from ..sql.util import selectables_overlap
 from ..sql.util import visit_binary_product
+from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
+    from ._typing import _EntityType
     from .mapper import Mapper
     from .util import AliasedClass
     from .util import AliasedInsp
+    from ..sql.elements import ColumnElement
 
 _T = TypeVar("_T", bound=Any)
 _PT = TypeVar("_PT", bound=Any)
@@ -81,6 +90,34 @@ _RelationshipArgumentType = Union[
     Callable[[], "AliasedClass[_T]"],
 ]
 
+_LazyLoadArgumentType = Literal[
+    "select",
+    "joined",
+    "selectin",
+    "subquery",
+    "raise",
+    "raise_on_sql",
+    "noload",
+    "immediate",
+    "dynamic",
+    True,
+    False,
+    None,
+]
+
+
+_RelationshipJoinConditionArgument = Union[
+    str, _ColumnExpressionArgument[bool]
+]
+_ORMOrderByArgument = Union[
+    Literal[False], str, _ColumnExpressionArgument[Any]
+]
+_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
+_ORMColCollectionArgument = Union[
+    str,
+    Sequence[Union[ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole]],
+]
+
 
 def remote(expr):
     """Annotate a portion of a primaryjoin expression
@@ -144,6 +181,7 @@ class Relationship(
     inherit_cache = True
 
     _links_to_entity = True
+    _is_relationship = True
 
     _persistence_only = dict(
         passive_deletes=False,
@@ -159,38 +197,39 @@ class Relationship(
         self,
         argument: Optional[_RelationshipArgumentType[_T]] = None,
         secondary=None,
+        *,
+        uselist=None,
+        collection_class=None,
         primaryjoin=None,
         secondaryjoin=None,
-        foreign_keys=None,
-        uselist=None,
+        back_populates=None,
         order_by=False,
         backref=None,
-        back_populates=None,
+        cascade_backrefs=False,
         overlaps=None,
         post_update=False,
-        cascade=False,
+        cascade="save-update, merge",
         viewonly=False,
-        lazy="select",
-        collection_class=None,
-        passive_deletes=_persistence_only["passive_deletes"],
-        passive_updates=_persistence_only["passive_updates"],
+        lazy: _LazyLoadArgumentType = "select",
+        passive_deletes=False,
+        passive_updates=True,
+        active_history=False,
+        enable_typechecks=True,
+        foreign_keys=None,
         remote_side=None,
-        enable_typechecks=_persistence_only["enable_typechecks"],
         join_depth=None,
         comparator_factory=None,
         single_parent=False,
         innerjoin=False,
         distinct_target_key=None,
-        doc=None,
-        active_history=_persistence_only["active_history"],
-        cascade_backrefs=_persistence_only["cascade_backrefs"],
         load_on_pending=False,
-        bake_queries=True,
-        _local_remote_pairs=None,
         query_class=None,
         info=None,
         omit_join=None,
         sync_backref=None,
+        doc=None,
+        bake_queries=True,
+        _local_remote_pairs=None,
         _legacy_inactive_history_style=False,
     ):
         super(Relationship, self).__init__()
@@ -250,7 +289,6 @@ class Relationship(
 
         self.omit_join = omit_join
         self.local_remote_pairs = _local_remote_pairs
-        self.bake_queries = bake_queries
         self.load_on_pending = load_on_pending
         self.comparator_factory = comparator_factory or Relationship.Comparator
         self.comparator = self.comparator_factory(self, None)
@@ -267,12 +305,7 @@ class Relationship(
         else:
             self._overlaps = ()
 
-        if cascade is not False:
-            self.cascade = cascade
-        elif self.viewonly:
-            self.cascade = "none"
-        else:
-            self.cascade = "save-update, merge"
+        self.cascade = cascade
 
         self.order_by = order_by
 
@@ -539,9 +572,9 @@ class Relationship(
 
         def _criterion_exists(
             self,
-            criterion: Optional[SQLCoreOperations[Any]] = None,
+            criterion: Optional[_ColumnExpressionArgument[bool]] = None,
             **kwargs: Any,
-        ) -> Exists[bool]:
+        ) -> Exists:
             if getattr(self, "_of_type", None):
                 info = inspect(self._of_type)
                 target_mapper, to_selectable, is_aliased_class = (
@@ -898,7 +931,12 @@ class Relationship(
 
     comparator: Comparator[_T]
 
-    def _with_parent(self, instance, alias_secondary=True, from_entity=None):
+    def _with_parent(
+        self,
+        instance: object,
+        alias_secondary: bool = True,
+        from_entity: Optional[_EntityType[Any]] = None,
+    ) -> ColumnElement[bool]:
         assert instance is not None
         adapt_source = None
         if from_entity is not None:
@@ -1502,7 +1540,7 @@ class Relationship(
             argument = argument
 
         if isinstance(argument, type):
-            entity = mapperlib.class_mapper(argument, configure=False)
+            entity = class_mapper(argument, configure=False)
         else:
             try:
                 entity = inspect(argument)
@@ -1568,7 +1606,7 @@ class Relationship(
         """Test that this relationship is legal, warn about
         inheritance conflicts."""
         mapperlib = util.preloaded.orm_mapper
-        if self.parent.non_primary and not mapperlib.class_mapper(
+        if self.parent.non_primary and not class_mapper(
             self.parent.class_, configure=False
         ).has_property(self.key):
             raise sa_exc.ArgumentError(
@@ -1585,29 +1623,23 @@ class Relationship(
             )
 
     @property
-    def cascade(self):
+    def cascade(self) -> CascadeOptions:
         """Return the current cascade setting for this
         :class:`.Relationship`.
         """
         return self._cascade
 
     @cascade.setter
-    def cascade(self, cascade):
+    def cascade(self, cascade: Union[str, CascadeOptions]):
         self._set_cascade(cascade)
 
-    def _set_cascade(self, cascade):
-        cascade = CascadeOptions(cascade)
+    def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]):
+        cascade = CascadeOptions(cascade_arg)
 
         if self.viewonly:
-            non_viewonly = set(cascade).difference(
-                CascadeOptions._viewonly_cascades
+            cascade = CascadeOptions(
+                cascade.intersection(CascadeOptions._viewonly_cascades)
             )
-            if non_viewonly:
-                raise sa_exc.ArgumentError(
-                    'Cascade settings "%s" apply to persistence operations '
-                    "and should not be combined with a viewonly=True "
-                    "relationship." % (", ".join(sorted(non_viewonly)))
-                )
 
         if "mapper" in self.__dict__:
             self._check_cascade_settings(cascade)
@@ -1754,8 +1786,8 @@ class Relationship(
             relationship = Relationship(
                 parent,
                 self.secondary,
-                pj,
-                sj,
+                primaryjoin=pj,
+                secondaryjoin=sj,
                 foreign_keys=foreign_keys,
                 back_populates=self.key,
                 **kwargs,
index 5b1d0bb087e16a34061035fb21c822b4b0fbcea2..74035ec0aaa33e65f968a3176eaccf67ba10760b 100644 (file)
@@ -39,6 +39,7 @@ from . import persistence
 from . import query
 from . import state as statelib
 from ._typing import _O
+from ._typing import insp_is_mapper
 from ._typing import is_composite_class
 from ._typing import is_user_defined_option
 from .base import _class_to_mapper
@@ -69,12 +70,14 @@ from ..engine.util import TransactionalContext
 from ..event import dispatcher
 from ..event import EventTarget
 from ..inspection import inspect
+from ..inspection import Inspectable
 from ..sql import coercions
 from ..sql import dml
 from ..sql import roles
 from ..sql import Select
 from ..sql import visitors
 from ..sql.base import CompileState
+from ..sql.schema import Table
 from ..sql.selectable import ForUpdateArg
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..util import IdentitySet
@@ -90,6 +93,7 @@ if typing.TYPE_CHECKING:
     from .path_registry import PathRegistry
     from ..engine import Result
     from ..engine import Row
+    from ..engine import RowMapping
     from ..engine.base import Transaction
     from ..engine.base import TwoPhaseTransaction
     from ..engine.interfaces import _CoreAnyExecuteParams
@@ -103,6 +107,7 @@ if typing.TYPE_CHECKING:
     from ..sql.base import Executable
     from ..sql.elements import ClauseElement
     from ..sql.schema import Table
+    from ..sql.selectable import TableClause
 
 __all__ = [
     "Session",
@@ -184,7 +189,7 @@ class _SessionClassMethods:
         ident: Union[Any, Tuple[Any, ...]] = None,
         *,
         instance: Optional[Any] = None,
-        row: Optional[Row] = None,
+        row: Optional[Union[Row, RowMapping]] = None,
         identity_token: Optional[Any] = None,
     ) -> _IdentityKeyType[Any]:
         """Return an identity key.
@@ -2050,9 +2055,12 @@ class Session(_SessionClassMethods, EventTarget):
             else:
                 self.__binds[key] = bind
         else:
-            if insp.is_selectable:
+            if TYPE_CHECKING:
+                assert isinstance(insp, Inspectable)
+
+            if isinstance(insp, Table):
                 self.__binds[insp] = bind
-            elif insp.is_mapper:
+            elif insp_is_mapper(insp):
                 self.__binds[insp.class_] = bind
                 for _selectable in insp._all_tables:
                     self.__binds[_selectable] = bind
@@ -2211,7 +2219,7 @@ class Session(_SessionClassMethods, EventTarget):
         # we don't have self.bind and either have self.__binds
         # or we don't have self.__binds (which is legacy).  Look at the
         # mapper and the clause
-        if mapper is clause is None:
+        if mapper is None and clause is None:
             if self.bind:
                 return self.bind
             else:
@@ -2350,7 +2358,10 @@ class Session(_SessionClassMethods, EventTarget):
         key = mapper.identity_key_from_primary_key(
             primary_key_identity, identity_token=identity_token
         )
-        return loading.get_from_identity(self, mapper, key, passive)
+
+        # work around: https://github.com/python/typing/discussions/1143
+        return_value = loading.get_from_identity(self, mapper, key, passive)
+        return return_value
 
     @util.non_memoized_property
     @contextlib.contextmanager
index 85e015193771529f22448405e584a75c8a5b4eeb..2d85ba7f6490d273663961d828a13e2da73c6c9e 100644 (file)
@@ -2162,10 +2162,11 @@ class JoinedLoader(AbstractRelationshipLoader):
         else:
             to_adapt = self._gen_pooled_aliased_class(compile_state)
 
-        clauses = inspect(to_adapt)._memo(
+        to_adapt_insp = inspect(to_adapt)
+        clauses = to_adapt_insp._memo(
             ("joinedloader_ormadapter", self),
             orm_util.ORMAdapter,
-            to_adapt,
+            to_adapt_insp,
             equivalents=self.mapper._equivalent_columns,
             adapt_required=True,
             allow_label_resolve=False,
index 4699781a42bfd90d56e75a25c11fd0829b7e5f9b..3934de5355325803ec906110e5fee9bd9f1d3f5c 100644 (file)
@@ -11,8 +11,16 @@ import re
 import types
 import typing
 from typing import Any
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
 from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Match
 from typing import Optional
+from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TypeVar
@@ -20,32 +28,35 @@ from typing import Union
 import weakref
 
 from . import attributes  # noqa
-from .base import _class_to_mapper  # noqa
-from .base import _never_set  # noqa
-from .base import _none_set  # noqa
-from .base import attribute_str  # noqa
-from .base import class_mapper  # noqa
-from .base import InspectionAttr  # noqa
-from .base import instance_str  # noqa
-from .base import object_mapper  # noqa
-from .base import object_state  # noqa
-from .base import state_attribute_str  # noqa
-from .base import state_class_str  # noqa
-from .base import state_str  # noqa
+from ._typing import _O
+from ._typing import insp_is_aliased_class
+from ._typing import insp_is_mapper
+from ._typing import prop_is_relationship
+from .base import _class_to_mapper as _class_to_mapper
+from .base import _never_set as _never_set
+from .base import _none_set as _none_set
+from .base import attribute_str as attribute_str
+from .base import class_mapper as class_mapper
+from .base import InspectionAttr as InspectionAttr
+from .base import instance_str as instance_str
+from .base import object_mapper as object_mapper
+from .base import object_state as object_state
+from .base import state_attribute_str as state_attribute_str
+from .base import state_class_str as state_class_str
+from .base import state_str as state_str
 from .interfaces import CriteriaOption
-from .interfaces import MapperProperty  # noqa
+from .interfaces import MapperProperty as MapperProperty
 from .interfaces import ORMColumnsClauseRole
 from .interfaces import ORMEntityColumnsClauseRole
 from .interfaces import ORMFromClauseRole
-from .interfaces import PropComparator  # noqa
-from .path_registry import PathRegistry  # noqa
+from .interfaces import PropComparator as PropComparator
+from .path_registry import PathRegistry as PathRegistry
 from .. import event
 from .. import exc as sa_exc
 from .. import inspection
 from .. import sql
 from .. import util
 from ..engine.result import result_tuple
-from ..sql import base as sql_base
 from ..sql import coercions
 from ..sql import expression
 from ..sql import lambdas
@@ -54,19 +65,39 @@ from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql.annotation import SupportsCloneAnnotations
 from ..sql.base import ColumnCollection
+from ..sql.cache_key import HasCacheKey
+from ..sql.cache_key import MemoizedHasCacheKey
+from ..sql.elements import ColumnElement
 from ..sql.selectable import FromClause
 from ..util.langhelpers import MemoizedSlots
 from ..util.typing import de_stringify_annotation
 from ..util.typing import is_origin_of
+from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
     from ._typing import _EntityType
     from ._typing import _IdentityKeyType
     from ._typing import _InternalEntityType
+    from ._typing import _ORMColumnExprArgument
+    from .context import _MapperEntity
+    from .context import ORMCompileState
     from .mapper import Mapper
+    from .relationships import Relationship
     from ..engine import Row
+    from ..engine import RowMapping
+    from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _EquivalentColumnMap
+    from ..sql._typing import _FromClauseArgument
+    from ..sql._typing import _OnClauseArgument
     from ..sql._typing import _PropagateAttrsType
+    from ..sql.base import ReadOnlyColumnCollection
+    from ..sql.elements import BindParameter
+    from ..sql.selectable import _ColumnsClauseElement
     from ..sql.selectable import Alias
+    from ..sql.selectable import Subquery
+    from ..sql.visitors import _ET
+    from ..sql.visitors import anon_map
+    from ..sql.visitors import ExternallyTraversible
 
 _T = TypeVar("_T", bound=Any)
 
@@ -84,7 +115,7 @@ all_cascades = frozenset(
 )
 
 
-class CascadeOptions(frozenset):
+class CascadeOptions(FrozenSet[str]):
     """Keeps track of the options sent to
     :paramref:`.relationship.cascade`"""
 
@@ -104,6 +135,13 @@ class CascadeOptions(frozenset):
         "delete_orphan",
     )
 
+    save_update: bool
+    delete: bool
+    refresh_expire: bool
+    merge: bool
+    expunge: bool
+    delete_orphan: bool
+
     def __new__(cls, value_list):
         if isinstance(value_list, str) or value_list is None:
             return cls.from_string(value_list)
@@ -127,7 +165,7 @@ class CascadeOptions(frozenset):
             values.clear()
         values.discard("all")
 
-        self = frozenset.__new__(CascadeOptions, values)
+        self = super().__new__(cls, values)  # type: ignore
         self.save_update = "save-update" in values
         self.delete = "delete" in values
         self.refresh_expire = "refresh-expire" in values
@@ -238,7 +276,7 @@ def polymorphic_union(
 
     """
 
-    colnames = util.OrderedSet()
+    colnames: util.OrderedSet[str] = util.OrderedSet()
     colnamemaps = {}
     types = {}
     for key in table_map:
@@ -299,13 +337,13 @@ def polymorphic_union(
 
 
 def identity_key(
-    class_: Optional[Type[Any]] = None,
+    class_: Optional[Type[_T]] = None,
     ident: Union[Any, Tuple[Any, ...]] = None,
     *,
-    instance: Optional[Any] = None,
-    row: Optional[Row] = None,
+    instance: Optional[_T] = None,
+    row: Optional[Union[Row, RowMapping]] = None,
     identity_token: Optional[Any] = None,
-) -> _IdentityKeyType:
+) -> _IdentityKeyType[_T]:
     r"""Generate "identity key" tuples, as are used as keys in the
     :attr:`.Session.identity_map` dictionary.
 
@@ -351,7 +389,7 @@ def identity_key(
     * ``identity_key(class, row=row, identity_token=token)``
 
       This form is similar to the class/tuple form, except is passed a
-      database result row as a :class:`.Row` object.
+      database result row as a :class:`.Row` or :class:`.RowMapping` object.
 
       E.g.::
 
@@ -375,7 +413,7 @@ def identity_key(
             if ident is None:
                 raise sa_exc.ArgumentError("ident or row is required")
             return mapper.identity_key_from_primary_key(
-                util.to_list(ident), identity_token=identity_token
+                tuple(util.to_list(ident)), identity_token=identity_token
             )
         else:
             return mapper.identity_key_from_row(
@@ -394,24 +432,26 @@ class ORMAdapter(sql_util.ColumnAdapter):
 
     """
 
-    is_aliased_class = False
-    aliased_insp = None
+    is_aliased_class: bool
+    aliased_insp: Optional[AliasedInsp[Any]]
 
     def __init__(
         self,
-        entity,
-        equivalents=None,
-        adapt_required=False,
-        allow_label_resolve=True,
-        anonymize_labels=False,
+        entity: _InternalEntityType[Any],
+        equivalents: Optional[_EquivalentColumnMap] = None,
+        adapt_required: bool = False,
+        allow_label_resolve: bool = True,
+        anonymize_labels: bool = False,
     ):
-        info = inspection.inspect(entity)
 
-        self.mapper = info.mapper
-        selectable = info.selectable
-        if info.is_aliased_class:
+        self.mapper = entity.mapper
+        selectable = entity.selectable
+        if insp_is_aliased_class(entity):
             self.is_aliased_class = True
-            self.aliased_insp = info
+            self.aliased_insp = entity
+        else:
+            self.is_aliased_class = False
+            self.aliased_insp = None
 
         sql_util.ColumnAdapter.__init__(
             self,
@@ -428,7 +468,7 @@ class ORMAdapter(sql_util.ColumnAdapter):
         return not entity or entity.isa(self.mapper)
 
 
-class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
+class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
     r"""Represents an "aliased" form of a mapped class for usage with Query.
 
     The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias`
@@ -489,19 +529,20 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
 
     def __init__(
         self,
-        mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"],
-        alias=None,
-        name=None,
-        flat=False,
-        adapt_on_names=False,
-        #  TODO: None for default here?
-        with_polymorphic_mappers=(),
-        with_polymorphic_discriminator=None,
-        base_alias=None,
-        use_mapper_path=False,
-        represents_outer_join=False,
+        mapped_class_or_ac: _EntityType[_O],
+        alias: Optional[FromClause] = None,
+        name: Optional[str] = None,
+        flat: bool = False,
+        adapt_on_names: bool = False,
+        with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]] = None,
+        with_polymorphic_discriminator: Optional[ColumnElement[Any]] = None,
+        base_alias: Optional[AliasedInsp[Any]] = None,
+        use_mapper_path: bool = False,
+        represents_outer_join: bool = False,
     ):
-        insp = inspection.inspect(mapped_class_or_ac)
+        insp = cast(
+            "_InternalEntityType[_O]", inspection.inspect(mapped_class_or_ac)
+        )
         mapper = insp.mapper
 
         nest_adapters = False
@@ -519,6 +560,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
         elif insp.is_aliased_class:
             nest_adapters = True
 
+        assert alias is not None
         self._aliased_insp = AliasedInsp(
             self,
             insp,
@@ -540,7 +582,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
         self.__name__ = f"aliased({mapper.class_.__name__})"
 
     @classmethod
-    def _reconstitute_from_aliased_insp(cls, aliased_insp):
+    def _reconstitute_from_aliased_insp(
+        cls, aliased_insp: AliasedInsp[_O]
+    ) -> AliasedClass[_O]:
         obj = cls.__new__(cls)
         obj.__name__ = f"aliased({aliased_insp.mapper.class_.__name__})"
         obj._aliased_insp = aliased_insp
@@ -555,7 +599,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
 
         return obj
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         try:
             _aliased_insp = self.__dict__["_aliased_insp"]
         except KeyError:
@@ -584,7 +628,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
 
         return attr
 
-    def _get_from_serialized(self, key, mapped_class, aliased_insp):
+    def _get_from_serialized(
+        self, key: str, mapped_class: _O, aliased_insp: AliasedInsp[_O]
+    ) -> Any:
         # this method is only used in terms of the
         # sqlalchemy.ext.serializer extension
         attr = getattr(mapped_class, key)
@@ -605,23 +651,25 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
 
         return attr
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<AliasedClass at 0x%x; %s>" % (
             id(self),
             self._aliased_insp._target.__name__,
         )
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self._aliased_insp)
 
 
+@inspection._self_inspects
 class AliasedInsp(
     ORMEntityColumnsClauseRole,
     ORMFromClauseRole,
-    sql_base.HasCacheKey,
+    HasCacheKey,
     InspectionAttr,
     MemoizedSlots,
-    Generic[_T],
+    inspection.Inspectable["AliasedInsp[_O]"],
+    Generic[_O],
 ):
     """Provide an inspection interface for an
     :class:`.AliasedClass` object.
@@ -685,19 +733,36 @@ class AliasedInsp(
         "_nest_adapters",
     )
 
+    mapper: Mapper[_O]
+    selectable: FromClause
+    _adapter: sql_util.ColumnAdapter
+    with_polymorphic_mappers: Sequence[Mapper[Any]]
+    _with_polymorphic_entities: Sequence[AliasedInsp[Any]]
+
+    _weak_entity: weakref.ref[AliasedClass[_O]]
+    """the AliasedClass that refers to this AliasedInsp"""
+
+    _target: Union[_O, AliasedClass[_O]]
+    """the thing referred towards by the AliasedClass/AliasedInsp.
+
+    In the vast majority of cases, this is the mapped class.  However
+    it may also be another AliasedClass (alias of alias).
+
+    """
+
     def __init__(
         self,
-        entity: _EntityType,
-        inspected: _InternalEntityType,
-        selectable,
-        name,
-        with_polymorphic_mappers,
-        polymorphic_on,
-        _base_alias,
-        _use_mapper_path,
-        adapt_on_names,
-        represents_outer_join,
-        nest_adapters,
+        entity: AliasedClass[_O],
+        inspected: _InternalEntityType[_O],
+        selectable: FromClause,
+        name: Optional[str],
+        with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]],
+        polymorphic_on: Optional[ColumnElement[Any]],
+        _base_alias: Optional[AliasedInsp[Any]],
+        _use_mapper_path: bool,
+        adapt_on_names: bool,
+        represents_outer_join: bool,
+        nest_adapters: bool,
     ):
 
         mapped_class_or_ac = inspected.entity
@@ -752,23 +817,22 @@ class AliasedInsp(
         )
 
         if nest_adapters:
+            # supports "aliased class of aliased class" use case
+            assert isinstance(inspected, AliasedInsp)
             self._adapter = inspected._adapter.wrap(self._adapter)
 
         self._adapt_on_names = adapt_on_names
         self._target = mapped_class_or_ac
-        # self._target = mapper.class_  # mapped_class_or_ac
 
     @classmethod
     def _alias_factory(
         cls,
-        element: Union[
-            Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"
-        ],
-        alias=None,
-        name=None,
-        flat=False,
-        adapt_on_names=False,
-    ) -> Union["AliasedClass[_T]", "Alias"]:
+        element: Union[_EntityType[_O], FromClause],
+        alias: Optional[Union[Alias, Subquery]] = None,
+        name: Optional[str] = None,
+        flat: bool = False,
+        adapt_on_names: bool = False,
+    ) -> Union[AliasedClass[_O], FromClause]:
 
         if isinstance(element, FromClause):
             if adapt_on_names:
@@ -793,16 +857,16 @@ class AliasedInsp(
     @classmethod
     def _with_polymorphic_factory(
         cls,
-        base,
-        classes,
-        selectable=False,
-        flat=False,
-        polymorphic_on=None,
-        aliased=False,
-        innerjoin=False,
-        adapt_on_names=False,
-        _use_mapper_path=False,
-    ):
+        base: Union[_O, Mapper[_O]],
+        classes: Iterable[Type[Any]],
+        selectable: Union[Literal[False, None], FromClause] = False,
+        flat: bool = False,
+        polymorphic_on: Optional[ColumnElement[Any]] = None,
+        aliased: bool = False,
+        innerjoin: bool = False,
+        adapt_on_names: bool = False,
+        _use_mapper_path: bool = False,
+    ) -> AliasedClass[_O]:
 
         primary_mapper = _class_to_mapper(base)
 
@@ -816,7 +880,9 @@ class AliasedInsp(
             classes, selectable, innerjoin=innerjoin
         )
         if aliased or flat:
+            assert selectable is not None
             selectable = selectable._anonymous_fromclause(flat=flat)
+
         return AliasedClass(
             base,
             selectable,
@@ -828,7 +894,7 @@ class AliasedInsp(
         )
 
     @property
-    def entity(self):
+    def entity(self) -> AliasedClass[_O]:
         # to eliminate reference cycles, the AliasedClass is held weakly.
         # this produces some situations where the AliasedClass gets lost,
         # particularly when one is created internally and only the AliasedInsp
@@ -844,7 +910,7 @@ class AliasedInsp(
     is_aliased_class = True
     "always returns True"
 
-    def _memoized_method___clause_element__(self):
+    def _memoized_method___clause_element__(self) -> FromClause:
         return self.selectable._annotate(
             {
                 "parentmapper": self.mapper,
@@ -856,7 +922,7 @@ class AliasedInsp(
         )
 
     @property
-    def entity_namespace(self):
+    def entity_namespace(self) -> AliasedClass[_O]:
         return self.entity
 
     _cache_key_traversal = [
@@ -866,7 +932,7 @@ class AliasedInsp(
     ]
 
     @property
-    def class_(self):
+    def class_(self) -> Type[_O]:
         """Return the mapped class ultimately represented by this
         :class:`.AliasedInsp`."""
         return self.mapper.class_
@@ -878,7 +944,7 @@ class AliasedInsp(
         else:
             return PathRegistry.per_mapper(self)
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return {
             "entity": self.entity,
             "mapper": self.mapper,
@@ -893,8 +959,8 @@ class AliasedInsp(
             "nest_adapters": self._nest_adapters,
         }
 
-    def __setstate__(self, state):
-        self.__init__(
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__init__(  # type: ignore
             state["entity"],
             state["mapper"],
             state["alias"],
@@ -908,7 +974,7 @@ class AliasedInsp(
             state["nest_adapters"],
         )
 
-    def _merge_with(self, other):
+    def _merge_with(self, other: AliasedInsp[_O]) -> AliasedInsp[_O]:
         # assert self._is_with_polymorphic
         # assert other._is_with_polymorphic
 
@@ -929,7 +995,6 @@ class AliasedInsp(
             classes, None, innerjoin=not other.represents_outer_join
         )
         selectable = selectable._anonymous_fromclause(flat=True)
-
         return AliasedClass(
             primary_mapper,
             selectable,
@@ -937,10 +1002,13 @@ class AliasedInsp(
             with_polymorphic_discriminator=other.polymorphic_on,
             use_mapper_path=other._use_mapper_path,
             represents_outer_join=other.represents_outer_join,
-        )
+        )._aliased_insp
 
-    def _adapt_element(self, elem, key=None):
-        d = {
+    def _adapt_element(
+        self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None
+    ) -> _ORMColumnExprArgument[_T]:
+        assert isinstance(elem, ColumnElement)
+        d: Dict[str, Any] = {
             "parententity": self,
             "parentmapper": self.mapper,
         }
@@ -1084,35 +1152,45 @@ class LoaderCriteriaOption(CriteriaOption):
         ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
     ]
 
+    root_entity: Optional[Type[Any]]
+    entity: Optional[_InternalEntityType[Any]]
+    where_criteria: Union[ColumnElement[bool], lambdas.DeferredLambdaElement]
+    deferred_where_criteria: bool
+    include_aliases: bool
+    propagate_to_loaders: bool
+
     def __init__(
         self,
-        entity_or_base,
-        where_criteria,
-        loader_only=False,
-        include_aliases=False,
-        propagate_to_loaders=True,
-        track_closure_variables=True,
+        entity_or_base: _EntityType[Any],
+        where_criteria: _ColumnExpressionArgument[bool],
+        loader_only: bool = False,
+        include_aliases: bool = False,
+        propagate_to_loaders: bool = True,
+        track_closure_variables: bool = True,
     ):
-        entity = inspection.inspect(entity_or_base, False)
+        entity = cast(
+            "_InternalEntityType[Any]",
+            inspection.inspect(entity_or_base, False),
+        )
         if entity is None:
-            self.root_entity = entity_or_base
+            self.root_entity = cast("Type[Any]", entity_or_base)
             self.entity = None
         else:
             self.root_entity = None
             self.entity = entity
 
         if callable(where_criteria):
+            if self.root_entity is not None:
+                wrap_entity = self.root_entity
+            else:
+                assert entity is not None
+                wrap_entity = entity.entity
+
             self.deferred_where_criteria = True
             self.where_criteria = lambdas.DeferredLambdaElement(
-                where_criteria,
+                where_criteria,  # type: ignore
                 roles.WhereHavingRole,
-                lambda_args=(
-                    _WrapUserEntity(
-                        self.root_entity
-                        if self.root_entity is not None
-                        else self.entity.entity,
-                    ),
-                ),
+                lambda_args=(_WrapUserEntity(wrap_entity),),
                 opts=lambdas.LambdaOptions(
                     track_closure_variables=track_closure_variables
                 ),
@@ -1126,22 +1204,27 @@ class LoaderCriteriaOption(CriteriaOption):
         self.include_aliases = include_aliases
         self.propagate_to_loaders = propagate_to_loaders
 
-    def _all_mappers(self):
+    def _all_mappers(self) -> Iterator[Mapper[Any]]:
+
         if self.entity:
-            for ent in self.entity.mapper.self_and_descendants:
-                yield ent
+            for mp_ent in self.entity.mapper.self_and_descendants:
+                yield mp_ent
         else:
+            assert self.root_entity
             stack = list(self.root_entity.__subclasses__())
             while stack:
                 subclass = stack.pop(0)
-                ent = inspection.inspect(subclass, raiseerr=False)
+                ent = cast(
+                    "_InternalEntityType[Any]",
+                    inspection.inspect(subclass, raiseerr=False),
+                )
                 if ent:
                     for mp in ent.mapper.self_and_descendants:
                         yield mp
                 else:
                     stack.extend(subclass.__subclasses__())
 
-    def _should_include(self, compile_state):
+    def _should_include(self, compile_state: ORMCompileState) -> bool:
         if (
             compile_state.select_statement._annotations.get(
                 "for_loader_criteria", None
@@ -1151,21 +1234,29 @@ class LoaderCriteriaOption(CriteriaOption):
             return False
         return True
 
-    def _resolve_where_criteria(self, ext_info):
+    def _resolve_where_criteria(
+        self, ext_info: _InternalEntityType[Any]
+    ) -> ColumnElement[bool]:
         if self.deferred_where_criteria:
-            crit = self.where_criteria._resolve_with_args(ext_info.entity)
+            crit = cast(
+                "ColumnElement[bool]",
+                self.where_criteria._resolve_with_args(ext_info.entity),
+            )
         else:
-            crit = self.where_criteria
+            crit = self.where_criteria  # type: ignore
+        assert isinstance(crit, ColumnElement)
         return sql_util._deep_annotate(
             crit, {"for_loader_criteria": self}, detect_subquery_cols=True
         )
 
     def process_compile_state_replaced_entities(
-        self, compile_state, mapper_entities
-    ):
-        return self.process_compile_state(compile_state)
+        self,
+        compile_state: ORMCompileState,
+        mapper_entities: Iterable[_MapperEntity],
+    ) -> None:
+        self.process_compile_state(compile_state)
 
-    def process_compile_state(self, compile_state):
+    def process_compile_state(self, compile_state: ORMCompileState) -> None:
         """Apply a modification to a given :class:`.CompileState`."""
 
         # if options to limit the criteria to immediate query only,
@@ -1173,7 +1264,7 @@ class LoaderCriteriaOption(CriteriaOption):
 
         self.get_global_criteria(compile_state.global_attributes)
 
-    def get_global_criteria(self, attributes):
+    def get_global_criteria(self, attributes: Dict[Any, Any]) -> None:
         for mp in self._all_mappers():
             load_criteria = attributes.setdefault(
                 ("additional_entity_criteria", mp), []
@@ -1183,14 +1274,14 @@ class LoaderCriteriaOption(CriteriaOption):
 
 
 inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
-inspection._inspects(AliasedInsp)(lambda target: target)
 
 
 @inspection._self_inspects
 class Bundle(
     ORMColumnsClauseRole,
     SupportsCloneAnnotations,
-    sql_base.MemoizedHasCacheKey,
+    MemoizedHasCacheKey,
+    inspection.Inspectable["Bundle"],
     InspectionAttr,
 ):
     """A grouping of SQL expressions that are returned by a :class:`.Query`
@@ -1227,7 +1318,11 @@ class Bundle(
 
     _propagate_attrs: _PropagateAttrsType = util.immutabledict()
 
-    def __init__(self, name, *exprs, **kw):
+    exprs: List[_ColumnsClauseElement]
+
+    def __init__(
+        self, name: str, *exprs: _ColumnExpressionArgument[Any], **kw: Any
+    ):
         r"""Construct a new :class:`.Bundle`.
 
         e.g.::
@@ -1246,37 +1341,43 @@ class Bundle(
 
         """
         self.name = self._label = name
-        self.exprs = exprs = [
+        coerced_exprs = [
             coercions.expect(
                 roles.ColumnsClauseRole, expr, apply_propagate_attrs=self
             )
             for expr in exprs
         ]
+        self.exprs = coerced_exprs
 
         self.c = self.columns = ColumnCollection(
             (getattr(col, "key", col._label), col)
-            for col in [e._annotations.get("bundle", e) for e in exprs]
-        )
+            for col in [e._annotations.get("bundle", e) for e in coerced_exprs]
+        ).as_readonly()
         self.single_entity = kw.pop("single_entity", self.single_entity)
 
-    def _gen_cache_key(self, anon_map, bindparams):
+    def _gen_cache_key(
+        self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+    ) -> Tuple[Any, ...]:
         return (self.__class__, self.name, self.single_entity) + tuple(
             [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs]
         )
 
     @property
-    def mapper(self):
+    def mapper(self) -> Mapper[Any]:
         return self.exprs[0]._annotations.get("parentmapper", None)
 
     @property
-    def entity(self):
+    def entity(self) -> _InternalEntityType[Any]:
         return self.exprs[0]._annotations.get("parententity", None)
 
     @property
-    def entity_namespace(self):
+    def entity_namespace(
+        self,
+    ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
         return self.c
 
-    columns = None
+    columns: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+
     """A namespace of SQL expressions referred to by this :class:`.Bundle`.
 
         e.g.::
@@ -1301,7 +1402,7 @@ class Bundle(
 
     """
 
-    c = None
+    c: ReadOnlyColumnCollection[str, ColumnElement[Any]]
     """An alias for :attr:`.Bundle.columns`."""
 
     def _clone(self):
@@ -1400,32 +1501,30 @@ class _ORMJoin(expression.Join):
 
     def __init__(
         self,
-        left,
-        right,
-        onclause=None,
-        isouter=False,
-        full=False,
-        _left_memo=None,
-        _right_memo=None,
-        _extra_criteria=(),
+        left: _FromClauseArgument,
+        right: _FromClauseArgument,
+        onclause: Optional[_OnClauseArgument] = None,
+        isouter: bool = False,
+        full: bool = False,
+        _left_memo: Optional[Any] = None,
+        _right_memo: Optional[Any] = None,
+        _extra_criteria: Sequence[ColumnElement[bool]] = (),
     ):
-        left_info = inspection.inspect(left)
+        left_info = cast(
+            "Union[FromClause, _InternalEntityType[Any]]",
+            inspection.inspect(left),
+        )
 
-        right_info = inspection.inspect(right)
+        right_info = cast(
+            "Union[FromClause, _InternalEntityType[Any]]",
+            inspection.inspect(right),
+        )
         adapt_to = right_info.selectable
 
         # used by joined eager loader
         self._left_memo = _left_memo
         self._right_memo = _right_memo
 
-        # legacy, for string attr name ON clause.  if that's removed
-        # then the "_joined_from_info" concept can go
-        left_orm_info = getattr(left, "_joined_from_info", left_info)
-        self._joined_from_info = right_info
-        if isinstance(onclause, str):
-            onclause = getattr(left_orm_info.entity, onclause)
-        # ####
-
         if isinstance(onclause, attributes.QueryableAttribute):
             on_selectable = onclause.comparator._source_selectable()
             prop = onclause.property
@@ -1477,20 +1576,23 @@ class _ORMJoin(expression.Join):
         augment_onclause = onclause is None and _extra_criteria
         expression.Join.__init__(self, left, right, onclause, isouter, full)
 
+        assert self.onclause is not None
+
         if augment_onclause:
             self.onclause &= sql.and_(*_extra_criteria)
 
         if (
             not prop
             and getattr(right_info, "mapper", None)
-            and right_info.mapper.single
+            and right_info.mapper.single  # type: ignore
         ):
+            right_info = cast("_InternalEntityType[Any]", right_info)
             # if single inheritance target and we are using a manual
             # or implicit ON clause, augment it the same way we'd augment the
             # WHERE.
             single_crit = right_info.mapper._single_table_criterion
             if single_crit is not None:
-                if right_info.is_aliased_class:
+                if insp_is_aliased_class(right_info):
                     single_crit = right_info._adapter.traverse(single_crit)
                 self.onclause = self.onclause & single_crit
 
@@ -1525,19 +1627,27 @@ class _ORMJoin(expression.Join):
 
     def join(
         self,
-        right,
-        onclause=None,
-        isouter=False,
-        full=False,
-        join_to_left=None,
-    ):
+        right: _FromClauseArgument,
+        onclause: Optional[_OnClauseArgument] = None,
+        isouter: bool = False,
+        full: bool = False,
+    ) -> _ORMJoin:
         return _ORMJoin(self, right, onclause, full=full, isouter=isouter)
 
-    def outerjoin(self, right, onclause=None, full=False, join_to_left=None):
+    def outerjoin(
+        self,
+        right: _FromClauseArgument,
+        onclause: Optional[_OnClauseArgument] = None,
+        full: bool = False,
+    ) -> _ORMJoin:
         return _ORMJoin(self, right, onclause, isouter=True, full=full)
 
 
-def with_parent(instance, prop, from_entity=None):
+def with_parent(
+    instance: object,
+    prop: attributes.QueryableAttribute[Any],
+    from_entity: Optional[_EntityType[Any]] = None,
+) -> ColumnElement[bool]:
     """Create filtering criterion that relates this query's primary entity
     to the given related instance, using established
     :func:`_orm.relationship()`
@@ -1588,6 +1698,8 @@ def with_parent(instance, prop, from_entity=None):
       .. versionadded:: 1.2
 
     """
+    prop_t: Relationship[Any]
+
     if isinstance(prop, str):
         raise sa_exc.ArgumentError(
             "with_parent() accepts class-bound mapped attributes, not strings"
@@ -1595,12 +1707,19 @@ def with_parent(instance, prop, from_entity=None):
     elif isinstance(prop, attributes.QueryableAttribute):
         if prop._of_type:
             from_entity = prop._of_type
-        prop = prop.property
+        if not prop_is_relationship(prop.property):
+            raise sa_exc.ArgumentError(
+                f"Expected relationship property for with_parent(), "
+                f"got {prop.property}"
+            )
+        prop_t = prop.property
+    else:
+        prop_t = prop
 
-    return prop._with_parent(instance, from_entity=from_entity)
+    return prop_t._with_parent(instance, from_entity=from_entity)
 
 
-def has_identity(object_):
+def has_identity(object_: object) -> bool:
     """Return True if the given object has a database
     identity.
 
@@ -1616,7 +1735,7 @@ def has_identity(object_):
     return state.has_identity
 
 
-def was_deleted(object_):
+def was_deleted(object_: object) -> bool:
     """Return True if the given object was deleted
     within a session flush.
 
@@ -1633,27 +1752,32 @@ def was_deleted(object_):
     return state.was_deleted
 
 
-def _entity_corresponds_to(given, entity):
+def _entity_corresponds_to(
+    given: _InternalEntityType[Any], entity: _InternalEntityType[Any]
+) -> bool:
     """determine if 'given' corresponds to 'entity', in terms
     of an entity passed to Query that would match the same entity
     being referred to elsewhere in the query.
 
     """
-    if entity.is_aliased_class:
-        if given.is_aliased_class:
+    if insp_is_aliased_class(entity):
+        if insp_is_aliased_class(given):
             if entity._base_alias() is given._base_alias():
                 return True
         return False
-    elif given.is_aliased_class:
+    elif insp_is_aliased_class(given):
         if given._use_mapper_path:
             return entity in given.with_polymorphic_mappers
         else:
             return entity is given
 
+    assert insp_is_mapper(given)
     return entity.common_parent(given)
 
 
-def _entity_corresponds_to_use_path_impl(given, entity):
+def _entity_corresponds_to_use_path_impl(
+    given: _InternalEntityType[Any], entity: _InternalEntityType[Any]
+) -> bool:
     """determine if 'given' corresponds to 'entity', in terms
     of a path of loader options where a mapped attribute is taken to
     be a member of a parent entity.
@@ -1673,13 +1797,13 @@ def _entity_corresponds_to_use_path_impl(given, entity):
 
 
     """
-    if given.is_aliased_class:
+    if insp_is_aliased_class(given):
         return (
-            entity.is_aliased_class
+            insp_is_aliased_class(entity)
             and not entity._use_mapper_path
             and (given is entity or entity in given._with_polymorphic_entities)
         )
-    elif not entity.is_aliased_class:
+    elif not insp_is_aliased_class(entity):
         return given.isa(entity.mapper)
     else:
         return (
@@ -1688,7 +1812,7 @@ def _entity_corresponds_to_use_path_impl(given, entity):
         )
 
 
-def _entity_isa(given, mapper):
+def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool:
     """determine if 'given' "is a" mapper, in terms of the given
     would load rows of type 'mapper'.
 
@@ -1703,42 +1827,6 @@ def _entity_isa(given, mapper):
         return given.isa(mapper)
 
 
-def randomize_unitofwork():
-    """Use random-ordering sets within the unit of work in order
-    to detect unit of work sorting issues.
-
-    This is a utility function that can be used to help reproduce
-    inconsistent unit of work sorting issues.   For example,
-    if two kinds of objects A and B are being inserted, and
-    B has a foreign key reference to A - the A must be inserted first.
-    However, if there is no relationship between A and B, the unit of work
-    won't know to perform this sorting, and an operation may or may not
-    fail, depending on how the ordering works out.   Since Python sets
-    and dictionaries have non-deterministic ordering, such an issue may
-    occur on some runs and not on others, and in practice it tends to
-    have a great dependence on the state of the interpreter.  This leads
-    to so-called "heisenbugs" where changing entirely irrelevant aspects
-    of the test program still cause the failure behavior to change.
-
-    By calling ``randomize_unitofwork()`` when a script first runs, the
-    ordering of a key series of sets within the unit of work implementation
-    are randomized, so that the script can be minimized down to the
-    fundamental mapping and operation that's failing, while still reproducing
-    the issue on at least some runs.
-
-    This utility is also available when running the test suite via the
-    ``--reversetop`` flag.
-
-    """
-    from sqlalchemy.orm import unitofwork, session, mapper, dependency
-    from sqlalchemy.util import topological
-    from sqlalchemy.testing.util import RandomSet
-
-    topological.set = (
-        unitofwork.set
-    ) = session.set = mapper.set = dependency.set = RandomSet
-
-
 def _getitem(iterable_query, item):
     """calculate __getitem__ in terms of an iterable query object
     that also has a slice() method.
@@ -1780,16 +1868,21 @@ def _getitem(iterable_query, item):
             return list(iterable_query[item : item + 1])[0]
 
 
-def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type):
+def _is_mapped_annotation(
+    raw_annotation: Union[type, str], cls: Type[Any]
+) -> bool:
     annotated = de_stringify_annotation(cls, raw_annotation)
     return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
 
 
-def _cleanup_mapped_str_annotation(annotation):
+def _cleanup_mapped_str_annotation(annotation: str) -> str:
     # fix up an annotation that comes in as the form:
     # 'Mapped[List[Address]]'  so that it instead looks like:
     # 'Mapped[List["Address"]]' , which will allow us to get
     # "Address" as a string
+
+    inner: Optional[Match[str]]
+
     mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
     if mm and mm.group(1) == "Mapped":
         stack = []
@@ -1839,8 +1932,8 @@ def _extract_mapped_subtype(
     else:
         if (
             not hasattr(annotated, "__origin__")
-            or not issubclass(annotated.__origin__, attr_cls)
-            and not issubclass(attr_cls, annotated.__origin__)
+            or not issubclass(annotated.__origin__, attr_cls)  # type: ignore
+            and not issubclass(attr_cls, annotated.__origin__)  # type: ignore
         ):
             our_annotated_str = (
                 annotated.__name__
@@ -1853,9 +1946,9 @@ def _extract_mapped_subtype(
                 f'"{attr_cls.__name__}[{our_annotated_str}]".'
             )
 
-        if len(annotated.__args__) != 1:
+        if len(annotated.__args__) != 1:  # type: ignore
             raise sa_exc.ArgumentError(
                 "Expected sub-type for Mapped[] annotation"
             )
 
-        return annotated.__args__[0]
+        return annotated.__args__[0]  # type: ignore
index ea21e01c66633efbca2a2aab828f627c12599f05..605f75ec4fb080111979701bc5f4961fd85ebe8d 100644 (file)
@@ -389,7 +389,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
 
 
 def bindparam(
-    key: str,
+    key: Optional[str],
     value: Any = _NoArg.NO_ARG,
     type_: Optional[TypeEngine[_T]] = None,
     unique: bool = False,
@@ -521,6 +521,11 @@ def bindparam(
       key, or if its length is too long and truncation is
       required.
 
+      If omitted, an "anonymous" name is generated for the bound parameter;
+      when given a value to bind, the end result is equivalent to calling upon
+      the :func:`.literal` function with a value to bind, particularly
+      if the :paramref:`.bindparam.unique` parameter is also provided.
+
     :param value:
       Initial value for this bind param.  Will be used at statement
       execution time as the value for this parameter passed to the
index b0a717a1a35f2a1533523045bce66a44ba4695d9..53d29b628f5ce16ef43efaab91496bfa1515a7d1 100644 (file)
@@ -2,13 +2,14 @@ from __future__ import annotations
 
 import operator
 from typing import Any
+from typing import Callable
 from typing import Dict
+from typing import Set
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
-from sqlalchemy.sql.base import Executable
 from . import roles
 from .. import util
 from ..inspection import Inspectable
@@ -16,6 +17,7 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
+    from .base import Executable
     from .compiler import Compiled
     from .compiler import DDLCompiler
     from .compiler import SQLCompiler
@@ -27,17 +29,20 @@ if TYPE_CHECKING:
     from .elements import quoted_name
     from .elements import SQLCoreOperations
     from .elements import TextClause
+    from .lambdas import LambdaElement
     from .roles import ColumnsClauseRole
     from .roles import FromClauseRole
     from .schema import Column
     from .schema import DefaultGenerator
     from .schema import Sequence
+    from .schema import Table
     from .selectable import Alias
     from .selectable import FromClause
     from .selectable import Join
     from .selectable import NamedFromClause
     from .selectable import ReturnsRows
     from .selectable import Select
+    from .selectable import Selectable
     from .selectable import SelectBase
     from .selectable import Subquery
     from .selectable import TableClause
@@ -46,7 +51,6 @@ if TYPE_CHECKING:
     from .type_api import TypeEngine
     from ..util.typing import TypeGuard
 
-
 _T = TypeVar("_T", bound=Any)
 
 
@@ -89,7 +93,11 @@ sets; select(...), insert().returning(...), etc.
 """
 
 _ColumnExpressionArgument = Union[
-    "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
+    "ColumnElement[_T]",
+    _HasClauseElement,
+    roles.ExpressionElementRole[_T],
+    Callable[[], "ColumnElement[_T]"],
+    "LambdaElement",
 ]
 """narrower "column expression" argument.
 
@@ -103,6 +111,7 @@ overall which brings in the TextClause object also.
 
 """
 
+
 _InfoType = Dict[Any, Any]
 """the .info dictionary accepted and used throughout Core /ORM"""
 
@@ -169,6 +178,8 @@ _PropagateAttrsType = util.immutabledict[str, Any]
 
 _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
 
+_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]]
+
 if TYPE_CHECKING:
 
     def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]:
@@ -195,6 +206,9 @@ if TYPE_CHECKING:
     def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
         ...
 
+    def is_selectable(t: Any) -> TypeGuard[Selectable]:
+        ...
+
     def is_select_base(
         t: Union[Executable, ReturnsRows]
     ) -> TypeGuard[SelectBase]:
@@ -224,6 +238,7 @@ else:
     is_from_clause = operator.attrgetter("_is_from_clause")
     is_tuple_type = operator.attrgetter("_is_tuple_type")
     is_table_value_type = operator.attrgetter("_is_table_value")
+    is_selectable = operator.attrgetter("is_selectable")
     is_select_base = operator.attrgetter("_is_select_base")
     is_select_statement = operator.attrgetter("_is_select_statement")
     is_table = operator.attrgetter("_is_table")
index f7692dbc2ac51a8450e438ad08fbb6813403d246..f81878d55d8f2ac4ab1f69009b2c8559ef3efa82 100644 (file)
@@ -218,7 +218,7 @@ def _generative(fn: _Fn) -> _Fn:
 
     """
 
-    @util.decorator
+    @util.decorator  # type: ignore
     def _generative(
         fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
     ) -> _SelfGenerativeType:
@@ -244,7 +244,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
         for name in names
     ]
 
-    @util.decorator
+    @util.decorator  # type: ignore
     def check(fn, *args, **kw):
         # make pylance happy by not including "self" in the argument
         # list
@@ -260,7 +260,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
                 raise exc.InvalidRequestError(msg)
         return fn(self, *args, **kw)
 
-    return check
+    return check  # type: ignore
 
 
 def _clone(element, **kw):
@@ -1750,15 +1750,14 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
                 self._collection.append((k, col))
         self._colset.update(c for (k, c) in self._collection)
 
-        # https://github.com/python/mypy/issues/12610
         self._index.update(
-            (idx, c) for idx, (k, c) in enumerate(self._collection)  # type: ignore  # noqa: E501
+            (idx, c) for idx, (k, c) in enumerate(self._collection)
         )
         for col in replace_col:
             self.replace(col)
 
     def extend(self, iter_: Iterable[_NAMEDCOL]) -> None:
-        self._populate_separate_keys((col.key, col) for col in iter_)  # type: ignore  # noqa: E501
+        self._populate_separate_keys((col.key, col) for col in iter_)
 
     def remove(self, column: _NAMEDCOL) -> None:
         if column not in self._colset:
@@ -1772,9 +1771,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
             (k, c) for (k, c) in self._collection if c is not column
         ]
 
-        # https://github.com/python/mypy/issues/12610
         self._index.update(
-            {idx: col for idx, (k, col) in enumerate(self._collection)}  # type: ignore  # noqa: E501
+            {idx: col for idx, (k, col) in enumerate(self._collection)}
         )
         # delete higher index
         del self._index[len(self._collection)]
@@ -1827,9 +1825,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
 
         self._index.clear()
 
-        # https://github.com/python/mypy/issues/12610
         self._index.update(
-            {idx: col for idx, (k, col) in enumerate(self._collection)}  # type: ignore  # noqa: E501
+            {idx: col for idx, (k, col) in enumerate(self._collection)}
         )
         self._index.update(self._collection)
 
index 4bf45da9cb47fbbd7469773f7ff620f92672ee59..0659709ab4385e8aaaa9c915a36abce84a822dc7 100644 (file)
@@ -214,6 +214,7 @@ def expect(
         Type[roles.ExpressionElementRole[Any]],
         Type[roles.LimitOffsetRole],
         Type[roles.WhereHavingRole],
+        Type[roles.OnClauseRole],
     ],
     element: Any,
     **kw: Any,
index 938be0f8177091c29abe4e752acdea82cdad82e6..c524a2602c9c066bc0ce634e87ffdd0c77188e74 100644 (file)
@@ -1078,7 +1078,7 @@ class SQLCompiler(Compiled):
         return list(self.insert_prefetch) + list(self.update_prefetch)
 
     @util.memoized_property
-    def _global_attributes(self):
+    def _global_attributes(self) -> Dict[Any, Any]:
         return {}
 
     @util.memoized_instancemethod
index 6ac7c24483c2ee76c7c99946f25bcd13a94bdee9..052af6ac9da102fba6905307116b9a8391bb15db 100644 (file)
@@ -14,6 +14,7 @@ from __future__ import annotations
 import typing
 from typing import Any
 from typing import Callable
+from typing import Iterable
 from typing import List
 from typing import Optional
 from typing import Sequence as typing_Sequence
@@ -1143,7 +1144,7 @@ class SchemaDropper(InvokeDDLBase):
 
 
 def sort_tables(
-    tables: typing_Sequence["Table"],
+    tables: Iterable["Table"],
     skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None,
     extra_dependencies: Optional[
         typing_Sequence[Tuple["Table", "Table"]]
index ea0fa799620ed17f29fda04bc782b1e6f9a74853..34d5127ab73301ef0b2fdfb50aa460c5acbbc8f1 100644 (file)
@@ -293,11 +293,18 @@ class ClauseElement(
 
     __visit_name__ = "clause"
 
-    _propagate_attrs: _PropagateAttrsType = util.immutabledict()
-    """like annotations, however these propagate outwards liberally
-    as SQL constructs are built, and are set up at construction time.
+    if TYPE_CHECKING:
 
-    """
+        @util.memoized_property
+        def _propagate_attrs(self) -> _PropagateAttrsType:
+            """like annotations, however these propagate outwards liberally
+            as SQL constructs are built, and are set up at construction time.
+
+            """
+            ...
+
+    else:
+        _propagate_attrs = util.EMPTY_DICT
 
     @util.ro_memoized_property
     def description(self) -> Optional[str]:
@@ -343,7 +350,9 @@ class ClauseElement(
     def _from_objects(self) -> List[FromClause]:
         return []
 
-    def _set_propagate_attrs(self, values):
+    def _set_propagate_attrs(
+        self: SelfClauseElement, values: Mapping[str, Any]
+    ) -> SelfClauseElement:
         # usually, self._propagate_attrs is empty here.  one case where it's
         # not is a subquery against ORM select, that is then pulled as a
         # property of an aliased class.   should all be good
@@ -526,13 +535,10 @@ class ClauseElement(
             if unique:
                 bind._convert_to_unique()
 
-        return cast(
-            SelfClauseElement,
-            cloned_traverse(
-                self,
-                {"maintain_key": True, "detect_subquery_cols": True},
-                {"bindparam": visit_bindparam},
-            ),
+        return cloned_traverse(
+            self,
+            {"maintain_key": True, "detect_subquery_cols": True},
+            {"bindparam": visit_bindparam},
         )
 
     def compare(self, other, **kw):
@@ -730,7 +736,9 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
     # redefined with the specific types returned by ColumnElement hierarchies
     if typing.TYPE_CHECKING:
 
-        _propagate_attrs: _PropagateAttrsType
+        @util.non_memoized_property
+        def _propagate_attrs(self) -> _PropagateAttrsType:
+            ...
 
         def operate(
             self, op: OperatorType, *other: Any, **kwargs: Any
@@ -2064,10 +2072,11 @@ class TextClause(
     roles.OrderByRole,
     roles.FromClauseRole,
     roles.SelectStatementRole,
-    roles.BinaryElementRole[Any],
     roles.InElementRole,
     Executable,
     DQLDMLClauseElement,
+    roles.BinaryElementRole[Any],
+    inspection.Inspectable["TextClause"],
 ):
     """Represent a literal SQL text fragment.
 
index da15c305fc0efe701bd166bcbb2e26fba09c0352..4b220188f72125deed0882717f534fa5f790097e 100644 (file)
@@ -444,7 +444,7 @@ class DeferredLambdaElement(LambdaElement):
     def _invoke_user_fn(self, fn, *arg):
         return fn(*self.lambda_args)
 
-    def _resolve_with_args(self, *lambda_args):
+    def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
         assert isinstance(self._rec, AnalyzedFunction)
         tracker_fn = self._rec.tracker_instrumented_fn
         expr = tracker_fn(*lambda_args)
@@ -478,7 +478,7 @@ class DeferredLambdaElement(LambdaElement):
         for deferred_copy_internals in self._transforms:
             expr = deferred_copy_internals(expr)
 
-        return expr
+        return expr  # type: ignore
 
     def _copy_internals(
         self, clone=_clone, deferred_copy_internals=None, **kw
index 577d868fdb5b8c5e4bcb29cde4086fb990410922..231c70a5bad54188ee498372e6a03eaafd200910 100644 (file)
@@ -22,9 +22,7 @@ if TYPE_CHECKING:
     from .base import _EntityNamespace
     from .base import ColumnCollection
     from .base import ReadOnlyColumnCollection
-    from .elements import ClauseElement
     from .elements import ColumnClause
-    from .elements import ColumnElement
     from .elements import Label
     from .elements import NamedColumn
     from .selectable import _SelectIterable
@@ -271,7 +269,14 @@ class StatementRole(SQLRole):
     __slots__ = ()
     _role_name = "Executable SQL or text() construct"
 
-    _propagate_attrs: _PropagateAttrsType = util.immutabledict()
+    if TYPE_CHECKING:
+
+        @util.memoized_property
+        def _propagate_attrs(self) -> _PropagateAttrsType:
+            ...
+
+    else:
+        _propagate_attrs = util.EMPTY_DICT
 
 
 class SelectStatementRole(StatementRole, ReturnsRowsRole):
index 92b9cc62c2ddd21d37253b258d96b1af6783c04e..52ba60a62c8f991b93c53d03953948a93554d251 100644 (file)
@@ -144,9 +144,9 @@ class SchemaConst(Enum):
     NULL_UNSPECIFIED = 3
     """Symbol indicating the "nullable" keyword was not passed to a Column.
 
-    Normally we would expect None to be acceptable for this but some backends
-    such as that of SQL Server place special signficance on a "nullability"
-    value of None.
+    This is used to distinguish between the use case of passing
+    ``nullable=None`` to a :class:`.Column`, which has special meaning
+    on some backends such as SQL Server.
 
     """
 
@@ -308,7 +308,9 @@ class HasSchemaAttr(SchemaItem):
     schema: Optional[str]
 
 
-class Table(DialectKWArgs, HasSchemaAttr, TableClause):
+class Table(
+    DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"]
+):
     r"""Represent a table in a database.
 
     e.g.::
@@ -1318,117 +1320,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
     inherit_cache = True
     key: str
 
-    @overload
-    def __init__(
-        self,
-        *args: SchemaEventTarget,
-        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
-        default: Optional[Any] = None,
-        doc: Optional[str] = None,
-        key: Optional[str] = None,
-        index: Optional[bool] = None,
-        unique: Optional[bool] = None,
-        info: Optional[_InfoType] = None,
-        nullable: Optional[
-            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
-        ] = NULL_UNSPECIFIED,
-        onupdate: Optional[Any] = None,
-        primary_key: bool = False,
-        server_default: Optional[_ServerDefaultType] = None,
-        server_onupdate: Optional[FetchedValue] = None,
-        quote: Optional[bool] = None,
-        system: bool = False,
-        comment: Optional[str] = None,
-        _proxies: Optional[Any] = None,
-        **dialect_kwargs: Any,
-    ):
-        ...
-
-    @overload
-    def __init__(
-        self,
-        __name: str,
-        *args: SchemaEventTarget,
-        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
-        default: Optional[Any] = None,
-        doc: Optional[str] = None,
-        key: Optional[str] = None,
-        index: Optional[bool] = None,
-        unique: Optional[bool] = None,
-        info: Optional[_InfoType] = None,
-        nullable: Optional[
-            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
-        ] = NULL_UNSPECIFIED,
-        onupdate: Optional[Any] = None,
-        primary_key: bool = False,
-        server_default: Optional[_ServerDefaultType] = None,
-        server_onupdate: Optional[FetchedValue] = None,
-        quote: Optional[bool] = None,
-        system: bool = False,
-        comment: Optional[str] = None,
-        _proxies: Optional[Any] = None,
-        **dialect_kwargs: Any,
-    ):
-        ...
-
-    @overload
     def __init__(
         self,
-        __type: _TypeEngineArgument[_T],
-        *args: SchemaEventTarget,
-        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
-        default: Optional[Any] = None,
-        doc: Optional[str] = None,
-        key: Optional[str] = None,
-        index: Optional[bool] = None,
-        unique: Optional[bool] = None,
-        info: Optional[_InfoType] = None,
-        nullable: Optional[
-            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
-        ] = NULL_UNSPECIFIED,
-        onupdate: Optional[Any] = None,
-        primary_key: bool = False,
-        server_default: Optional[_ServerDefaultType] = None,
-        server_onupdate: Optional[FetchedValue] = None,
-        quote: Optional[bool] = None,
-        system: bool = False,
-        comment: Optional[str] = None,
-        _proxies: Optional[Any] = None,
-        **dialect_kwargs: Any,
-    ):
-        ...
-
-    @overload
-    def __init__(
-        self,
-        __name: str,
-        __type: _TypeEngineArgument[_T],
+        __name_pos: Optional[
+            Union[str, _TypeEngineArgument[_T], SchemaEventTarget]
+        ] = None,
+        __type_pos: Optional[
+            Union[_TypeEngineArgument[_T], SchemaEventTarget]
+        ] = None,
         *args: SchemaEventTarget,
-        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
-        default: Optional[Any] = None,
-        doc: Optional[str] = None,
-        key: Optional[str] = None,
-        index: Optional[bool] = None,
-        unique: Optional[bool] = None,
-        info: Optional[_InfoType] = None,
-        nullable: Optional[
-            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
-        ] = NULL_UNSPECIFIED,
-        onupdate: Optional[Any] = None,
-        primary_key: bool = False,
-        server_default: Optional[_ServerDefaultType] = None,
-        server_onupdate: Optional[FetchedValue] = None,
-        quote: Optional[bool] = None,
-        system: bool = False,
-        comment: Optional[str] = None,
-        _proxies: Optional[Any] = None,
-        **dialect_kwargs: Any,
-    ):
-        ...
-
-    def __init__(
-        self,
-        *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget],
         name: Optional[str] = None,
         type_: Optional[_TypeEngineArgument[_T]] = None,
         autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
@@ -1440,7 +1340,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         info: Optional[_InfoType] = None,
         nullable: Optional[
             Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
-        ] = NULL_UNSPECIFIED,
+        ] = SchemaConst.NULL_UNSPECIFIED,
         onupdate: Optional[Any] = None,
         primary_key: bool = False,
         server_default: Optional[_ServerDefaultType] = None,
@@ -1953,7 +1853,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
 
         """  # noqa: E501, RST201, RST202
 
-        l_args = list(args)
+        l_args = [__name_pos, __type_pos] + list(args)
         del args
 
         if l_args:
@@ -1963,6 +1863,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
                         "May not pass name positionally and as a keyword."
                     )
                 name = l_args.pop(0)  # type: ignore
+            elif l_args[0] is None:
+                l_args.pop(0)
         if l_args:
             coltype = l_args[0]
 
@@ -1972,6 +1874,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
                         "May not pass type_ positionally and as a keyword."
                     )
                 type_ = l_args.pop(0)  # type: ignore
+            elif l_args[0] is None:
+                l_args.pop(0)
 
         if name is not None:
             name = quoted_name(name, quote)
@@ -1989,7 +1893,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         self.primary_key = primary_key
 
         self._user_defined_nullable = udn = nullable
-
         if udn is not NULL_UNSPECIFIED:
             self.nullable = udn
         else:
@@ -5128,7 +5031,7 @@ class MetaData(HasSchemaAttr):
     def clear(self) -> None:
         """Clear all Table objects from this MetaData."""
 
-        dict.clear(self.tables)
+        dict.clear(self.tables)  # type: ignore
         self._schemas.clear()
         self._fk_memos.clear()
 
index aab3c678c5e30ab3fcde9984f72ead3c3ecdde93..9d4d1d6c79c90ea7387c4f806c58c886dfa0e364 100644 (file)
@@ -1223,7 +1223,9 @@ class Join(roles.DMLTableRole, FromClause):
     @util.preload_module("sqlalchemy.sql.util")
     def _populate_column_collection(self):
         sqlutil = util.preloaded.sql_util
-        columns = [c for c in self.left.c] + [c for c in self.right.c]
+        columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [
+            c for c in self.right.c
+        ]
 
         self.primary_key.extend(  # type: ignore
             sqlutil.reduce_columns(
index 2843431549cec6ba6b969864ed757c34aad3e994..d08fef60a9b93fa42c677dcf0e8f2954c7951faf 100644 (file)
@@ -17,7 +17,9 @@ from typing import AbstractSet
 from typing import Any
 from typing import Callable
 from typing import cast
+from typing import Collection
 from typing import Dict
+from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import Optional
@@ -32,15 +34,15 @@ from . import coercions
 from . import operators
 from . import roles
 from . import visitors
+from ._typing import is_text_clause
 from .annotation import _deep_annotate as _deep_annotate
 from .annotation import _deep_deannotate as _deep_deannotate
 from .annotation import _shallow_annotate as _shallow_annotate
 from .base import _expand_cloned
 from .base import _from_objects
-from .base import ColumnSet
-from .cache_key import HasCacheKey  # noqa
-from .ddl import sort_tables  # noqa
-from .elements import _find_columns
+from .cache_key import HasCacheKey as HasCacheKey
+from .ddl import sort_tables as sort_tables
+from .elements import _find_columns as _find_columns
 from .elements import _label_reference
 from .elements import _textual_label_reference
 from .elements import BindParameter
@@ -67,10 +69,13 @@ from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
+    from ._typing import _EquivalentColumnMap
     from ._typing import _TypeEngineArgument
+    from .elements import TextClause
     from .roles import FromClauseRole
     from .selectable import _JoinTargetElement
     from .selectable import _OnClauseElement
+    from .selectable import _SelectIterable
     from .selectable import Selectable
     from .visitors import _TraverseCallableType
     from .visitors import ExternallyTraversible
@@ -752,7 +757,29 @@ def splice_joins(
     return ret
 
 
-def reduce_columns(columns, *clauses, **kw):
+@overload
+def reduce_columns(
+    columns: Iterable[ColumnElement[Any]],
+    *clauses: Optional[ClauseElement],
+    **kw: bool,
+) -> Sequence[ColumnElement[Any]]:
+    ...
+
+
+@overload
+def reduce_columns(
+    columns: _SelectIterable,
+    *clauses: Optional[ClauseElement],
+    **kw: bool,
+) -> Sequence[Union[ColumnElement[Any], TextClause]]:
+    ...
+
+
+def reduce_columns(
+    columns: _SelectIterable,
+    *clauses: Optional[ClauseElement],
+    **kw: bool,
+) -> Collection[Union[ColumnElement[Any], TextClause]]:
     r"""given a list of columns, return a 'reduced' set based on natural
     equivalents.
 
@@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw):
     ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
     only_synonyms = kw.pop("only_synonyms", False)
 
-    columns = util.ordered_column_set(columns)
+    column_set = util.OrderedSet(columns)
+    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
+        c for c in column_set if is_text_clause(c)  # type: ignore
+    )
 
     omit = util.column_set()
-    for col in columns:
+    for col in cset_no_text:
         for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
-            for c in columns:
+            for c in cset_no_text:
                 if c is col:
                     continue
                 try:
@@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw):
         def visit_binary(binary):
             if binary.operator == operators.eq:
                 cols = util.column_set(
-                    chain(*[c.proxy_set for c in columns.difference(omit)])
+                    chain(
+                        *[c.proxy_set for c in cset_no_text.difference(omit)]
+                    )
                 )
                 if binary.left in cols and binary.right in cols:
-                    for c in reversed(columns):
+                    for c in reversed(cset_no_text):
                         if c.shares_lineage(binary.right) and (
                             not only_synonyms or c.name == binary.left.name
                         ):
@@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw):
             if clause is not None:
                 visitors.traverse(clause, {}, {"binary": visit_binary})
 
-    return ColumnSet(columns.difference(omit))
+    return column_set.difference(omit)
 
 
 def criterion_as_pairs(
@@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
     def __init__(
         self,
         selectable: Selectable,
-        equivalents: Optional[
-            Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
-        ] = None,
+        equivalents: Optional[_EquivalentColumnMap] = None,
         include_fn: Optional[Callable[[ClauseElement], bool]] = None,
         exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
         adapt_on_names: bool = False,
@@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
 
 
 class _ColumnLookup(Protocol):
-    def __getitem__(
-        self, key: ColumnElement[Any]
-    ) -> Optional[ColumnElement[Any]]:
+    @overload
+    def __getitem__(self, key: None) -> None:
+        ...
+
+    @overload
+    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]:
+        ...
+
+    @overload
+    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]:
+        ...
+
+    @overload
+    def __getitem__(self, key: _ET) -> _ET:
+        ...
+
+    def __getitem__(self, key: Any) -> Any:
         ...
 
 
@@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter):
     def __init__(
         self,
         selectable: Selectable,
-        equivalents: Optional[
-            Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
-        ] = None,
+        equivalents: Optional[_EquivalentColumnMap] = None,
         adapt_required: bool = False,
         include_fn: Optional[Callable[[ClauseElement], bool]] = None,
         exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
@@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter):
 
         return ac
 
-    def traverse(self, obj):
+    @overload
+    def traverse(self, obj: Literal[None]) -> None:
+        ...
+
+    @overload
+    def traverse(self, obj: _ET) -> _ET:
+        ...
+
+    def traverse(
+        self, obj: Optional[ExternallyTraversible]
+    ) -> Optional[ExternallyTraversible]:
         return self.columns[obj]
 
     def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
@@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter):
     adapt_clause = traverse
     adapt_list = ClauseAdapter.copy_and_process
 
-    def adapt_check_present(self, col):
+    def adapt_check_present(
+        self, col: ColumnElement[Any]
+    ) -> Optional[ColumnElement[Any]]:
         newcol = self.columns[col]
 
         if newcol is col and self._corresponding_column(col, True) is None:
index 7363f9ddc28adecd09e90329b8b84bed91886dd6..e0a66fbcf48edf7345f688245064337c168e5cbb 100644 (file)
@@ -961,12 +961,16 @@ def cloned_traverse(
     ...
 
 
+# a bit of controversy here, as the clone of the lead element
+# *could* in theory replace with an entirely different kind of element.
+# however this is really not how cloned_traverse is ever used internally
+# at least.
 @overload
 def cloned_traverse(
-    obj: ExternallyTraversible,
+    obj: _ET,
     opts: Mapping[str, Any],
     visitors: Mapping[str, _TraverseCallableType[Any]],
-) -> ExternallyTraversible:
+) -> _ET:
     ...
 
 
index b908585128dcac01cef7f7b13dd2b7cf884e5d91..16924a0a1b23a3820794e045df44fc828514edff 100644 (file)
@@ -166,14 +166,6 @@ def setup_options(make_option):
         help="write out generated follower idents to <file>, "
         "when -n<num> is used",
     )
-    make_option(
-        "--reversetop",
-        action="store_true",
-        dest="reversetop",
-        default=False,
-        help="Use a random-ordering set implementation in the ORM "
-        "(helps reveal dependency issues)",
-    )
     make_option(
         "--requirements",
         action="callback",
@@ -475,14 +467,6 @@ def _prep_testing_database(options, file_config):
             provision.drop_all_schema_objects(cfg, cfg.db)
 
 
-@post
-def _reverse_topological(options, file_config):
-    if options.reversetop:
-        from sqlalchemy.orm.util import randomize_unitofwork
-
-        randomize_unitofwork()
-
-
 @post
 def _post_setup_options(opt, file_config):
     from sqlalchemy.testing import config
index 086b008de6b0fe26370513cb1a6e5d80fd543be9..ed69450903b321677e94209b671d8671209e6cee 100644 (file)
@@ -26,6 +26,7 @@ from typing import Mapping
 from typing import NoReturn
 from typing import Optional
 from typing import overload
+from typing import Sequence
 from typing import Set
 from typing import Tuple
 from typing import TypeVar
@@ -287,8 +288,8 @@ OrderedDict = dict
 sort_dictionary = _ordered_dictionary_sort
 
 
-class WeakSequence:
-    def __init__(self, __elements=()):
+class WeakSequence(Sequence[_T]):
+    def __init__(self, __elements: Sequence[_T] = ()):
         # adapted from weakref.WeakKeyDictionary, prevent reference
         # cycles in the collection itself
         def _remove(item, selfref=weakref.ref(self)):
index 88deac28f8c8abcb8ebf5b25d373bff47704aba5..b02bca28f80239b2249387db6fe867321d2c06d5 100644 (file)
@@ -92,7 +92,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
 
         new = dict.__new__(self.__class__)
         dict.__init__(new, self)
-        dict.update(new, __d)
+        dict.update(new, __d)  # type: ignore
         return new
 
     def _union_w_kw(
@@ -105,7 +105,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
         new = dict.__new__(self.__class__)
         dict.__init__(new, self)
         if __d:
-            dict.update(new, __d)
+            dict.update(new, __d)  # type: ignore
         dict.update(new, kw)  # type: ignore
         return new
 
@@ -118,7 +118,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
                 if new is None:
                     new = dict.__new__(self.__class__)
                     dict.__init__(new, self)
-                dict.update(new, d)
+                dict.update(new, d)  # type: ignore
         if new is None:
             return self
 
index 5c536b675fa40bda4ac96ce755e8d3e5e541229e..7c80ef4e0216da7c22a2721d5ae43dd00095ea7a 100644 (file)
@@ -13,7 +13,6 @@ from __future__ import annotations
 import re
 from typing import Any
 from typing import Callable
-from typing import cast
 from typing import Dict
 from typing import Match
 from typing import Optional
@@ -79,7 +78,7 @@ def warn_deprecated_limited(
 
 
 def deprecated_cls(
-    version: str, message: str, constructor: str = "__init__"
+    version: str, message: str, constructor: Optional[str] = "__init__"
 ) -> Callable[[Type[_T]], Type[_T]]:
     header = ".. deprecated:: %s %s" % (version, (message or ""))
 
@@ -288,7 +287,9 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]:
 
         check_any_kw = spec.varkw
 
-        @decorator
+        # latest mypy has opinions here, not sure if they implemented
+        # Concatenate or something
+        @decorator  # type: ignore
         def warned(fn: _F, *args: Any, **kwargs: Any) -> _F:
             for m in check_defaults:
                 if (defaults[m] is None and kwargs[m] is not None) or (
@@ -332,7 +333,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]:
                     for param, (version, message) in specs.items()
                 },
             )
-        decorated = cast(_F, warned)(fn)
+        decorated = warned(fn)  # type: ignore
         decorated.__doc__ = doc
         return decorated  # type: ignore[no-any-return]
 
@@ -352,7 +353,7 @@ def _sanitize_restructured_text(text: str) -> str:
 
 def _decorate_cls_with_warning(
     cls: Type[_T],
-    constructor: str,
+    constructor: Optional[str],
     wtype: Type[exc.SADeprecationWarning],
     message: str,
     version: str,
@@ -418,7 +419,7 @@ def _decorate_with_warning(
     else:
         doc_only = ""
 
-    @decorator
+    @decorator  # type: ignore
     def warned(fn: _F, *args: Any, **kwargs: Any) -> _F:
         skip_warning = not enable_warnings or kwargs.pop(
             "_sa_skip_warning", False
@@ -435,9 +436,9 @@ def _decorate_with_warning(
 
         doc = inject_docstring_text(doc, docstring_header, 1)
 
-    decorated = cast(_F, warned)(func)
+    decorated = warned(func)  # type: ignore
     decorated.__doc__ = doc
-    decorated._sa_warn = lambda: _warn_with_version(
+    decorated._sa_warn = lambda: _warn_with_version(  # type: ignore
         message, version, wtype, stacklevel=3
     )
     return decorated  # type: ignore[no-any-return]
index 9b3692d595a295efe7b7b46466e8c9e620406eed..49c5d693afa7c527b9311a50c1830444750a54fd 100644 (file)
@@ -238,6 +238,8 @@ def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
 
 _Fn = TypeVar("_Fn", bound="Callable[..., Any]")
 
+# this seems to be in flux in recent mypy versions
+
 
 def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
     """A signature-matching decorator factory."""
index 260250b2cf45bd06232c4dbb5d9c4fc25ed3e53d..ee3227d775f6567853121997bb8621068153c874 100644 (file)
@@ -23,6 +23,8 @@ _FN = TypeVar("_FN", bound=Callable[..., Any])
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import default as engine_default
+    from sqlalchemy.orm import descriptor_props as orm_descriptor_props
+    from sqlalchemy.orm import relationships as orm_relationships
     from sqlalchemy.orm import session as orm_session
     from sqlalchemy.orm import util as orm_util
     from sqlalchemy.sql import dml as sql_dml
index b3f3b93870872244409e931ee673c9974c48deb4..d192dc06bfc0d9e98206b7b11dd3f9087bd499af 100644 (file)
@@ -187,7 +187,9 @@ def is_union(type_):
     return is_origin_of(type_, "Union")
 
 
-def is_origin_of(type_, *names, module=None):
+def is_origin_of(
+    type_: Any, *names: str, module: Optional[str] = None
+) -> bool:
     """return True if the given type has an __origin__ with the given name
     and optional module."""
 
@@ -200,7 +202,7 @@ def is_origin_of(type_, *names, module=None):
     )
 
 
-def _get_type_name(type_):
+def _get_type_name(type_: Type[Any]) -> str:
     if compat.py310:
         return type_.__name__
     else:
@@ -208,4 +210,4 @@ def _get_type_name(type_):
         if typ_name is None:
             typ_name = getattr(type_, "_name", None)
 
-        return typ_name
+        return typ_name  # type: ignore
index e727ee1e49581bf80b7bc1963d1814ef02715cea..d16f03c032848922629c546b6c1b37712b0f9e9e 100644 (file)
@@ -56,7 +56,8 @@ incremental = true
 
 [[tool.mypy.overrides]]
 
-# ad-hoc ignores
+#####################################################################
+# modules / packages explicitly not checked by Mypy at all right now.
 module = [
     "sqlalchemy.engine.reflection",  # interim, should be strict
 
@@ -86,7 +87,8 @@ module = [
 warn_unused_ignores = false
 ignore_errors = true
 
-# strict checking
+################################################
+# modules explicitly for Mypy strict checking
 [[tool.mypy.overrides]]
 
 module = [
@@ -98,6 +100,11 @@ module = [
     "sqlalchemy.engine.*",
     "sqlalchemy.pool.*",
 
+    # uncomment, trying to make sure mypy
+    # is at a baseline
+    # "sqlalchemy.orm._orm_constructors",
+
+    "sqlalchemy.orm.path_registry",
     "sqlalchemy.orm.scoping",
     "sqlalchemy.orm.session",
     "sqlalchemy.orm.state",
@@ -114,7 +121,8 @@ warn_unused_ignores = false
 ignore_errors = false
 strict = true
 
-# partial checking
+################################################
+# modules explicitly for Mypy non-strict checking
 [[tool.mypy.overrides]]
 
 module = [
@@ -135,6 +143,12 @@ module = [
     "sqlalchemy.sql.traversals",
     "sqlalchemy.sql.util",
 
+    "sqlalchemy.orm._orm_constructors",
+
+    "sqlalchemy.orm.interfaces",
+    "sqlalchemy.orm.mapper",
+    "sqlalchemy.orm.util",
+
     "sqlalchemy.util.*",
 ]
 
index c5c897956a167b619ae4c508e11905f401a79798..e8b57a0c0226d2c1d296b9ef71574d11b5a1ed74 100644 (file)
@@ -40,8 +40,8 @@ class Address(Base):
 u1 = User()
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*\[builtins.str\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\]
     reveal_type(User.email_addresses)
 
-    # EXPECTED_TYPE: builtins.set\*\[builtins.str\]
+    # EXPECTED_TYPE: builtins.set\*?\[builtins.str\]
     reveal_type(u1.email_addresses)
index e97a9598b0d6e7f84f146886e6673c6a54713d0b..fe2742072c306c461b3986d7d6bf948c3eb39cb6 100644 (file)
@@ -8,7 +8,6 @@ from typing import Set
 
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
-from sqlalchemy import String
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -42,8 +41,8 @@ class Address(Base):
 
     id = mapped_column(Integer, primary_key=True)
     user_id = mapped_column(ForeignKey("user.id"))
-    email = mapped_column(String, nullable=False)
-    email_name = mapped_column("email_name", String, nullable=False)
+    email: Mapped[str]
+    email_name: Mapped[str] = mapped_column("email_name")
 
     user_style_one: Mapped[User] = relationship()
     user_style_two: Mapped["User"] = relationship()
@@ -56,14 +55,14 @@ if typing.TYPE_CHECKING:
     # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\]
     reveal_type(User.extra_name)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
     reveal_type(Address.email)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
     reveal_type(Address.email_name)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[experimental_relationship.Address\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\]
     reveal_type(User.addresses_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*\[experimental_relationship.Address\]\]
+    # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\]
     reveal_type(User.addresses_style_two)
index 7d97024afe03c6649f3627bc526f48dae8e93b03..d9f97ebcff5156abe69d05d8d298b393c4ef0deb 100644 (file)
@@ -47,7 +47,7 @@ expr2 = Interval.contains(7)
 expr3 = Interval.intersects(i2)
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: builtins.int\*
+    # EXPECTED_TYPE: builtins.int\*?
     reveal_type(i1.length)
 
     # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
index 6bfabbd30abb7f1d130c97cc5fc42e1009fbaff5..ab2970656ed0655aed507ea631e02241d2f1edeb 100644 (file)
@@ -69,10 +69,10 @@ expr3 = Interval.radius.in_([0.5, 5.2])
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: builtins.int\*
+    # EXPECTED_TYPE: builtins.int\*?
     reveal_type(i1.length)
 
-    # EXPECTED_TYPE: builtins.float\*
+    # EXPECTED_TYPE: builtins.float\*?
     reveal_type(i2.radius)
 
     # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
index b20beeb3a3e1236f7fcab4b3822661e66d432ab8..14f4ad845a82a7ef0fd5a342042bf5dcf20255e4 100644 (file)
@@ -14,68 +14,67 @@ class Base(DeclarativeBase):
 class X(Base):
     __tablename__ = "x"
 
+    # these are fine - pk, column is not null, have the attribute be
+    # non-optional, fine
     id: Mapped[int] = mapped_column(primary_key=True)
     int_id: Mapped[int] = mapped_column(Integer, primary_key=True)
 
-    # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types
+    # but this is also "fine" because the developer may wish to have the object
+    # in a pending state with None for the id for some period of time.
+    # "primary_key=True" will still be interpreted correctly in DDL
     err_int_id: Mapped[Optional[int]] = mapped_column(
         Integer, primary_key=True
     )
 
-    id_name: Mapped[int] = mapped_column("id_name", primary_key=True)
-    int_id_name: Mapped[int] = mapped_column(
-        "int_id_name", Integer, primary_key=True
-    )
-
-    # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types
+    # also fine, X(err_int_id_name) is None when you first make the
+    # object
     err_int_id_name: Mapped[Optional[int]] = mapped_column(
         "err_int_id_name", Integer, primary_key=True
     )
 
-    # note we arent getting into primary_key=True / nullable=True here.
-    # leaving that as undefined for now
+    id_name: Mapped[int] = mapped_column("id_name", primary_key=True)
+    int_id_name: Mapped[int] = mapped_column(
+        "int_id_name", Integer, primary_key=True
+    )
 
     a: Mapped[str] = mapped_column()
     b: Mapped[Optional[str]] = mapped_column()
 
-    # can't detect error because no SQL type is present
+    # this can't be detected because we don't know the type
     c: Mapped[str] = mapped_column(nullable=True)
     d: Mapped[str] = mapped_column(nullable=False)
 
     e: Mapped[Optional[str]] = mapped_column(nullable=True)
 
-    # can't detect error because no SQL type is present
     f: Mapped[Optional[str]] = mapped_column(nullable=False)
 
     g: Mapped[str] = mapped_column(String)
     h: Mapped[Optional[str]] = mapped_column(String)
 
-    # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types
+    # this probably is wrong.  however at the moment it seems better to
+    # decouple the right hand arguments from declaring things about the
+    # left side since it mostly doesn't work in any case.
     i: Mapped[str] = mapped_column(String, nullable=True)
 
     j: Mapped[str] = mapped_column(String, nullable=False)
 
     k: Mapped[Optional[str]] = mapped_column(String, nullable=True)
 
-    # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type
     l: Mapped[Optional[str]] = mapped_column(String, nullable=False)
 
     a_name: Mapped[str] = mapped_column("a_name")
     b_name: Mapped[Optional[str]] = mapped_column("b_name")
 
-    # can't detect error because no SQL type is present
     c_name: Mapped[str] = mapped_column("c_name", nullable=True)
     d_name: Mapped[str] = mapped_column("d_name", nullable=False)
 
     e_name: Mapped[Optional[str]] = mapped_column("e_name", nullable=True)
 
-    # can't detect error because no SQL type is present
     f_name: Mapped[Optional[str]] = mapped_column("f_name", nullable=False)
 
     g_name: Mapped[str] = mapped_column("g_name", String)
     h_name: Mapped[Optional[str]] = mapped_column("h_name", String)
 
-    # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types
     i_name: Mapped[str] = mapped_column("i_name", String, nullable=True)
 
     j_name: Mapped[str] = mapped_column("j_name", String, nullable=False)
@@ -86,7 +85,6 @@ class X(Base):
 
     l_name: Mapped[Optional[str]] = mapped_column(
         "l_name",
-        # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type
         String,
         nullable=False,
     )
index 78b0a467cefd175a474d07fda1bc3fdfe073857f..f9b9b2ffe50f26ffed193fb2ccc5c4f9b72fdab6 100644 (file)
@@ -3,6 +3,7 @@ import typing
 from sqlalchemy import Boolean
 from sqlalchemy import column
 from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
 
 
@@ -32,6 +33,7 @@ expr7 = c1 + "x"
 
 expr8 = c2 + 10
 
+stmt = select(column("q")).where(lambda: column("g") > 5).where(c2 == 5)
 
 if typing.TYPE_CHECKING:
 
index b43dcd594be4fc41ee8bd1c92eca823f19ea7413..af7d292be7ae8c42f0793fad7d6be489213d70a1 100644 (file)
@@ -101,45 +101,45 @@ class Address(Base):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.Address\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\]
     reveal_type(User.addresses_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\]
     reveal_type(User.addresses_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[Any\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(User.addresses_style_three)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(User.addresses_style_three_cast)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[Any\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(User.addresses_style_four)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
     reveal_type(Address.user_style_one_typed)
 
     # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
     reveal_type(Address.user_style_two_typed)
 
     # reveal_type(Address.user_style_six)
 
     # reveal_type(Address.user_style_seven)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_eight)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_nine)
 
     # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_ten)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*\[builtins.str, trad_relationship_uselist.User\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\]
     reveal_type(Address.user_style_ten_typed)
index 473ccb28247a519c51c3d982ef2f694dcd91bbde..ce131dd00413ecec69d5054e344e8f73d560a67d 100644 (file)
@@ -60,29 +60,29 @@ class Address(Base):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.Address\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\]
     reveal_type(User.addresses_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.Address\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\]
     reveal_type(User.addresses_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
     reveal_type(Address.user_style_one_typed)
 
     # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
     reveal_type(Address.user_style_two_typed)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
     reveal_type(Address.user_style_three)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
     reveal_type(Address.user_style_four)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.User\]\]
+    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_five)
index 0c8e3c4f647dda48dccb16cbe61f8a28e336fabb..15961c703a4f6ec2799cef8c8010fdd3f450ba9a 100644 (file)
@@ -17,7 +17,7 @@ class User(Base):
     __tablename__ = "user"
 
     id = mapped_column(Integer, primary_key=True)
-    name = mapped_column(String, nullable=True)
+    name: Mapped[Optional[str]] = mapped_column(String, nullable=True)
 
     addresses: Mapped[List["Address"]] = relationship(
         "Address", back_populates="user"
index 466e636a7850ec06a435a710ebe324e926688adb..d29909c3c99265c8c6d4bd11207546bfa558c358 100644 (file)
@@ -43,11 +43,11 @@ class Address(Base):
 
     @declared_attr
     def email_address(cls) -> Column[String]:
-        # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa
+        # EXPECTED_MYPY: Argument 1 to "Column" has incompatible type "bool";
         return Column(True)
 
     @declared_attr
     # EXPECTED_MYPY: Invalid type comment or annotation
     def thisisweird(cls) -> Column(String):
-        # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa
+        # EXPECTED_MYPY: Argument 1 to "Column" has incompatible type "bool";
         return Column(False)
index 43443b7f64f2e72bc8cade6da210e409ee3ebc5c..7830fcee68fc05a12815344fd214865b5f786560 100644 (file)
@@ -18,6 +18,7 @@ from sqlalchemy.orm.instrumentation import register_class
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
@@ -110,7 +111,13 @@ class DisposeTest(_ExtBase, fixtures.TestBase):
         class MyClass:
             __sa_instrumentation_manager__ = MyClassState
 
-        assert attributes.manager_of_class(MyClass) is None
+        assert attributes.opt_manager_of_class(MyClass) is None
+
+        with expect_raises_message(
+            sa.orm.exc.UnmappedClassError,
+            r"Can't locate an instrumentation manager for class .*MyClass",
+        ):
+            attributes.manager_of_class(MyClass)
 
         t = Table(
             "my_table",
@@ -120,7 +127,7 @@ class DisposeTest(_ExtBase, fixtures.TestBase):
 
         registry.map_imperatively(MyClass, t)
 
-        manager = attributes.manager_of_class(MyClass)
+        manager = attributes.opt_manager_of_class(MyClass)
         is_not(manager, None)
         is_(manager, MyClass.xyz)
 
@@ -128,7 +135,7 @@ class DisposeTest(_ExtBase, fixtures.TestBase):
 
         registry.dispose()
 
-        manager = attributes.manager_of_class(MyClass)
+        manager = attributes.opt_manager_of_class(MyClass)
         is_(manager, None)
 
         assert not hasattr(MyClass, "xyz")
@@ -532,9 +539,9 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
         register_class(Known)
         k, u = Known(), Unknown()
 
-        assert instrumentation.manager_of_class(Unknown) is None
-        assert instrumentation.manager_of_class(Known) is not None
-        assert instrumentation.manager_of_class(None) is None
+        assert instrumentation.opt_manager_of_class(Unknown) is None
+        assert instrumentation.opt_manager_of_class(Known) is not None
+        assert instrumentation.opt_manager_of_class(None) is None
 
         assert attributes.instance_state(k) is not None
         assert_raises((AttributeError, KeyError), attributes.instance_state, u)
@@ -583,7 +590,10 @@ class FinderTest(_ExtBase, fixtures.ORMTest):
             )
 
         register_class(A)
-        ne_(type(manager_of_class(A)), instrumentation.ClassManager)
+        ne_(
+            type(attributes.opt_manager_of_class(A)),
+            instrumentation.ClassManager,
+        )
 
     def test_nativeext_submanager(self):
         class Mine(instrumentation.ClassManager):
index 67abc8971bfcda08ceacb7ad5ce4dde3ae330e8e..b50cbc2bae316627bb22e54cd90b8d065d3942cb 100644 (file)
@@ -391,6 +391,33 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest):
             polymorphic_identity=0,
         )
 
+    def test_polymorphic_on_not_present_col_partial_wpoly(self):
+        """fix for partial with_polymorphic().
+
+        found_during_type_annotation
+
+        """
+        t2, t1 = self.tables.t2, self.tables.t1
+        Parent = self.classes.Parent
+        t1t2_join = select(t1.c.x).select_from(t1.join(t2)).alias()
+
+        def go():
+            t1t2_join_2 = select(t1.c.q).select_from(t1.join(t2)).alias()
+            self.mapper_registry.map_imperatively(
+                Parent,
+                t2,
+                polymorphic_on=t1t2_join.c.x,
+                with_polymorphic=("*", None),
+                polymorphic_identity=0,
+            )
+
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            "Could not map polymorphic_on column 'x' to the mapped table - "
+            "polymorphic loads will not function properly",
+            go,
+        )
+
     def test_polymorphic_on_not_present_col(self):
         t2, t1 = self.tables.t2, self.tables.t1
         Parent = self.classes.Parent
index 51f37f02861d9e4ed7f92ad344bc9810c823aa05..5a171e3722ecf52fd1b4831e0a252064f00ae4d1 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import attributes
 from sqlalchemy.orm import backref
+from sqlalchemy.orm import CascadeOptions
 from sqlalchemy.orm import class_mapper
 from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import exc as orm_exc
@@ -4217,11 +4218,12 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest):
         eq_(s.query(Language).count(), 0)
 
 
-class ViewonlyFlagWarningTest(fixtures.MappedTest):
-    """test for #4993.
+class ViewonlyCascadeUpdate(fixtures.MappedTest):
+    """Test that cascades are trimmed accordingly when viewonly is set.
 
-    In 1.4, this moves to test/orm/test_cascade, deprecation warnings
-    become errors, will then be for #4994.
+    Originally #4993 and #4994 this was raising an error for invalid
+    cascades.  in 2.0 this is simplified to just remove the write
+    cascades, allows the default cascade to be reasonable.
 
     """
 
@@ -4250,21 +4252,17 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest):
             pass
 
     @testing.combinations(
-        ({"delete"}, {"delete"}),
+        ({"delete"}, {"none"}),
         (
             {"all, delete-orphan"},
-            {"delete", "delete-orphan", "merge", "save-update"},
+            {"refresh-expire", "expunge"},
         ),
-        ({"save-update, expunge"}, {"save-update"}),
+        ({"save-update, expunge"}, {"expunge"}),
     )
-    def test_write_cascades(self, setting, settings_that_warn):
+    def test_write_cascades(self, setting, expected):
         Order = self.classes.Order
 
-        assert_raises_message(
-            sa_exc.ArgumentError,
-            'Cascade settings "%s" apply to persistence '
-            "operations" % (", ".join(sorted(settings_that_warn))),
-            relationship,
+        r = relationship(
             Order,
             primaryjoin=(
                 self.tables.users.c.id == foreign(self.tables.orders.c.user_id)
@@ -4272,6 +4270,7 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest):
             cascade=", ".join(sorted(setting)),
             viewonly=True,
         )
+        eq_(r.cascade, CascadeOptions(expected))
 
     def test_expunge_cascade(self):
         User, Order, orders, users = (
@@ -4425,23 +4424,6 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest):
 
         eq_(umapper.attrs["orders"].cascade, set())
 
-    def test_write_cascade_disallowed_w_viewonly(self):
-
-        Order = self.classes.Order
-
-        assert_raises_message(
-            sa_exc.ArgumentError,
-            'Cascade settings "delete, delete-orphan, merge, save-update" '
-            "apply to persistence operations",
-            relationship,
-            Order,
-            primaryjoin=(
-                self.tables.users.c.id == foreign(self.tables.orders.c.user_id)
-            ),
-            cascade="all, delete, delete-orphan",
-            viewonly=True,
-        )
-
 
 class CollectionCascadesNoBackrefTest(fixtures.TestBase):
     """test the removal of cascade_backrefs behavior
index 3b10103001e719e5f8a5074707a5981d9f01404d..437129af16cda65615c5db531559c3257bcaef11 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_warns_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.fixtures import fixture_session
@@ -739,9 +740,15 @@ class MiscTest(fixtures.MappedTest):
 
         assert instrumentation.manager_of_class(A) is manager
         instrumentation.unregister_class(A)
-        assert instrumentation.manager_of_class(A) is None
+        assert instrumentation.opt_manager_of_class(A) is None
         assert not hasattr(A, "x")
 
+        with expect_raises_message(
+            sa.orm.exc.UnmappedClassError,
+            r"Can't locate an instrumentation manager for class .*A",
+        ):
+            instrumentation.manager_of_class(A)
+
         assert A.__init__ == object.__init__
 
     def test_compileonattr_rel_backref_a(self):
index f71ab3032781cd41199cec50632738defd8f52ca..43a34eae4c00eab2f0ed9b9cbc5c7a4bdeab8122 100644 (file)
@@ -1374,6 +1374,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             [User(name="fred")],
         )
 
+    def test_str_not_accepted_orm_join(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        with expect_raises_message(
+            sa.exc.ArgumentError,
+            "ON clause, typically a SQL expression or ORM "
+            "relationship attribute expected, got 'addresses'.",
+        ):
+            outerjoin(User, Address, "addresses")
+
     def test_aliased_classes(self):
         User, Address = self.classes.User, self.classes.Address
 
@@ -1409,7 +1419,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
         eq_(result, [(user8, address3)])
 
         result = (
-            q.select_from(outerjoin(User, AdAlias, "addresses"))
+            q.select_from(outerjoin(User, AdAlias, User.addresses))
             .filter(AdAlias.email_address == "ed@bettyboop.com")
             .all()
         )
@@ -1504,7 +1514,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
         q = sess.query(Order)
         q = (
             q.add_entity(Item)
-            .select_from(join(Order, Item, "items"))
+            .select_from(join(Order, Item, Order.items))
             .order_by(Order.id, Item.id)
         )
         result = q.all()
@@ -1513,7 +1523,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
         IAlias = aliased(Item)
         q = (
             sess.query(Order, IAlias)
-            .select_from(join(Order, IAlias, "items"))
+            .select_from(join(Order, IAlias, Order.items))
             .filter(IAlias.description == "item 3")
         )
         result = q.all()
@@ -2569,18 +2579,6 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
             s.query(Node).join(Node.children)._compile_context,
         )
 
-    def test_explicit_join_1(self):
-        Node = self.classes.Node
-        n1 = aliased(Node)
-        n2 = aliased(Node)
-
-        self.assert_compile(
-            join(Node, n1, "children").join(n2, "children"),
-            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
-            "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
-            use_default_dialect=True,
-        )
-
     def test_explicit_join_2(self):
         Node = self.classes.Node
         n1 = aliased(Node)
@@ -2598,12 +2596,8 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
         n1 = aliased(Node)
         n2 = aliased(Node)
 
-        # the join_to_left=False here is unfortunate.   the default on this
-        # flag should be False.
         self.assert_compile(
-            join(Node, n1, Node.children).join(
-                n2, Node.children, join_to_left=False
-            ),
+            join(Node, n1, Node.children).join(n2, Node.children),
             "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
             "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id",
             use_default_dialect=True,
@@ -2646,7 +2640,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
 
         node = (
             sess.query(Node)
-            .select_from(join(Node, n1, "children"))
+            .select_from(join(Node, n1, n1.children))
             .filter(n1.data == "n122")
             .first()
         )
@@ -2660,7 +2654,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
 
         node = (
             sess.query(Node)
-            .select_from(join(Node, n1, "children").join(n2, "children"))
+            .select_from(join(Node, n1, Node.children).join(n2, n1.children))
             .filter(n2.data == "n122")
             .first()
         )
@@ -2676,7 +2670,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
         node = (
             sess.query(Node)
             .select_from(
-                join(Node, n1, Node.id == n1.parent_id).join(n2, "children")
+                join(Node, n1, Node.id == n1.parent_id).join(n2, n1.children)
             )
             .filter(n2.data == "n122")
             .first()
@@ -2691,7 +2685,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
 
         node = (
             sess.query(Node)
-            .select_from(join(Node, n1, "parent").join(n2, "parent"))
+            .select_from(join(Node, n1, Node.parent).join(n2, n1.parent))
             .filter(
                 and_(Node.data == "n122", n1.data == "n12", n2.data == "n1")
             )
@@ -2708,7 +2702,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
         eq_(
             list(
                 sess.query(Node)
-                .select_from(join(Node, n1, "parent").join(n2, "parent"))
+                .select_from(join(Node, n1, Node.parent).join(n2, n1.parent))
                 .filter(
                     and_(
                         Node.data == "n122", n1.data == "n12", n2.data == "n1"
@@ -3085,7 +3079,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest):
         n1 = aliased(Node)
         eq_(
             sess.query(Node)
-            .select_from(join(Node, n1, "children"))
+            .select_from(join(Node, n1, Node.children))
             .filter(n1.data.in_(["n3", "n7"]))
             .order_by(Node.id)
             .all(),
index 980c82fbe2aff6824264f7316da9184326a820f8..d8cc48939b0fe616cf0b17c7b405c578ee6f6984 100644 (file)
@@ -2,6 +2,7 @@ import logging
 import logging.handlers
 
 import sqlalchemy as sa
+from sqlalchemy import column
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
@@ -9,6 +10,7 @@ from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import String
+from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import util
 from sqlalchemy.engine import default
@@ -132,6 +134,22 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         ):
             self.mapper(User, users)
 
+    def test_no_table(self):
+        """test new error condition raised for table=None
+
+        found_during_type_annotation
+
+        """
+
+        User = self.classes.User
+
+        with expect_raises_message(
+            sa.exc.ArgumentError,
+            r"Mapper\[User\(None\)\] has None for a primary table "
+            r"argument and does not specify 'inherits'",
+        ):
+            self.mapper(User, None)
+
     def test_cant_call_legacy_constructor_directly(self):
         users, User = (
             self.tables.users,
@@ -341,6 +359,34 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             s,
         )
 
+    def test_no_tableclause(self):
+        """It's not tested for a Mapper to have lower-case table() objects
+        as part of its collection of tables, and in particular these objects
+        won't report on constraints or primary keys, which while this doesn't
+        necessarily disqualify them from being part of a mapper, we don't
+        have assumptions figured out right now to accommodate them.
+
+        found_during_type_annotation
+
+        """
+        User = self.classes.User
+        users = self.tables.users
+
+        address = table(
+            "address",
+            column("address_id", Integer),
+            column("user_id", Integer),
+        )
+
+        with expect_raises_message(
+            sa.exc.ArgumentError,
+            "ORM mappings can only be made against schema-level Table "
+            "objects, not TableClause; got tableclause 'address'",
+        ):
+            self.mapper_registry.map_imperatively(
+                User, users.join(address, users.c.id == address.c.user_id)
+            )
+
     def test_reconfigure_on_other_mapper(self):
         """A configure trigger on an already-configured mapper
         still triggers a check against all mappers."""
@@ -666,7 +712,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             (column_property, (users.c.name,)),
             (relationship, (Address,)),
             (composite, (MyComposite, "id", "name")),
-            (synonym, "foo"),
+            (synonym, ("foo",)),
         ]:
             obj = constructor(info={"x": "y"}, *args)
             eq_(obj.info, {"x": "y"})
index 96759e3889e859aa394b4e368dbc56f36af15fcc..d6fadc449cd73262dfecf44a74c9891ce8011908 100644 (file)
@@ -11,7 +11,6 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import attributes
-from sqlalchemy.orm import class_mapper
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm import defaultload
@@ -82,8 +81,7 @@ class PathTest:
         r = []
         for i, item in enumerate(path):
             if i % 2 == 0:
-                if isinstance(item, type):
-                    item = class_mapper(item)
+                item = inspect(item)
             else:
                 if isinstance(item, str):
                     item = inspect(r[-1]).mapper.attrs[item]
index d0c8f41084f9cb5cfefd6ccb1d88b6f71e375b3e..55414364c5439b357071ed5690ca1f4722cc007f 100644 (file)
@@ -6389,6 +6389,23 @@ class ParentTest(QueryTest, AssertsCompiledSQL):
             Order(description="order 5"),
         ] == o
 
+    def test_invalid_property(self):
+        """Test if with_parent is passed a non-relationship
+
+        found_during_type_annotation
+
+        """
+        User, Address = self.classes.User, self.classes.Address
+
+        sess = fixture_session()
+        u1 = sess.get(User, 7)
+        with expect_raises_message(
+            sa_exc.ArgumentError,
+            r"Expected relationship property for with_parent\(\), "
+            "got User.name",
+        ):
+            with_parent(u1, User.name)
+
     def test_select_from(self):
         User, Address = self.classes.User, self.classes.Address
 
index 03c31dc0ffb2d42ac47747e2782ddffe6e1b62ea..c829582fdfb579fc4f34e501e28d67ad3bf3de75 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import Table
 from sqlalchemy import testing
+from sqlalchemy.engine import result
 from sqlalchemy.ext.hybrid import hybrid_method
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import aliased
@@ -465,8 +466,7 @@ class IdentityKeyTest(_fixtures.FixtureTest):
 
     def _cases():
         return testing.combinations(
-            (orm_util,),
-            (Session,),
+            (orm_util,), (Session,), argnames="ormutil"
         )
 
     @_cases()
@@ -504,12 +504,29 @@ class IdentityKeyTest(_fixtures.FixtureTest):
         eq_(key, (User, (u.id,), None))
 
     @_cases()
-    def test_identity_key_3(self, ormutil):
+    @testing.combinations("dict", "row", "mapping", argnames="rowtype")
+    def test_identity_key_3(self, ormutil, rowtype):
+        """test a real Row works with identity_key.
+
+        this was broken w/ 1.4 future mode as we are assuming a mapping
+        here.  to prevent regressions, identity_key now accepts any of
+        dict, RowMapping, Row for the "row".
+
+        found_during_type_annotation
+
+
+        """
         User, users = self.classes.User, self.tables.users
 
         self.mapper_registry.map_imperatively(User, users)
 
-        row = {users.c.id: 1, users.c.name: "Frank"}
+        if rowtype == "dict":
+            row = {users.c.id: 1, users.c.name: "Frank"}
+        elif rowtype in ("mapping", "row"):
+            row = result.result_tuple([users.c.id, users.c.name])((1, "Frank"))
+            if rowtype == "mapping":
+                row = row._mapping
+
         key = ormutil.identity_key(User, row=row)
         eq_(key, (User, (1,), None))
 
index 9fdc519389fac09b7ce35cbfd5581fbd7e45757e..7fa39825c4911df1991d28965fddfff403bfd55b 100644 (file)
@@ -776,6 +776,28 @@ class SelectableTest(
             "table1.col3, table1.colx FROM table1) AS anon_1",
         )
 
+    def test_reduce_cols_odd_expressions(self):
+        """test util.reduce_columns() works with text, non-col expressions
+        in a SELECT.
+
+        found_during_type_annotation
+
+        """
+
+        stmt = select(
+            table1.c.col1,
+            table1.c.col3 * 5,
+            text("some_expr"),
+            table2.c.col2,
+            func.foo(),
+        ).join(table2)
+        self.assert_compile(
+            stmt.reduce_columns(only_synonyms=False),
+            "SELECT table1.col1, table1.col3 * :col3_1 AS anon_1, "
+            "some_expr, foo() AS foo_1 FROM table1 JOIN table2 "
+            "ON table1.col1 = table2.col2",
+        )
+
     def test_with_only_generative_no_list(self):
         s1 = table1.select().scalar_subquery()