]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep-484: session, instancestate, etc
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 7 Apr 2022 16:37:23 +0000 (12:37 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Tue, 12 Apr 2022 02:09:50 +0000 (02:09 +0000)
Also adds some fixes to annotation-based mapping
that have come up, as well as starts to add more
pep-484 test cases

Change-Id: Ia722bbbc7967a11b23b66c8084eb61df9d233fee

47 files changed:
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/util.py
lib/sqlalchemy/event/__init__.py
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/_typing.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/exc.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/state_changes.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/_py_collections.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/preloaded.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/base/test_utils.py
test/ext/mypy/plain_files/session.py [new file with mode: 0644]
test/orm/declarative/test_tm_future_annotations.py
test/orm/test_core_compilation.py
test/orm/test_events.py
test/orm/test_pickled.py
test/orm/test_scoping.py
tools/generate_proxy_methods.py

index 07783ced78a95e1a8922637279b70ee4e13fc8fa..c0dc54fabe75c480ac3fbf42c2e246828d569dbc 100644 (file)
@@ -456,6 +456,8 @@ from .json import JSONB
 from ... import types as sqltypes
 from ... import util
 from ...engine import cursor as _cursor
+from ...util import FastIntFlag
+from ...util import parse_user_argument_for_enum
 
 
 logger = logging.getLogger("sqlalchemy.dialects.postgresql")
@@ -519,13 +521,19 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
     pass
 
 
-EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0)
-EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1)
-EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2)
-EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol(
-    "executemany_values_plus_batch",
-    canonical=EXECUTEMANY_BATCH | EXECUTEMANY_VALUES,
-)
+class ExecutemanyMode(FastIntFlag):
+    EXECUTEMANY_PLAIN = 0
+    EXECUTEMANY_BATCH = 1
+    EXECUTEMANY_VALUES = 2
+    EXECUTEMANY_VALUES_PLUS_BATCH = EXECUTEMANY_BATCH | EXECUTEMANY_VALUES
+
+
+(
+    EXECUTEMANY_PLAIN,
+    EXECUTEMANY_BATCH,
+    EXECUTEMANY_VALUES,
+    EXECUTEMANY_VALUES_PLUS_BATCH,
+) = tuple(ExecutemanyMode)
 
 
 class PGDialect_psycopg2(_PGDialect_common_psycopg):
@@ -564,7 +572,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
 
         # Parse executemany_mode argument, allowing it to be only one of the
         # symbol names
-        self.executemany_mode = util.symbol.parse_user_argument(
+        self.executemany_mode = parse_user_argument_for_enum(
             executemany_mode,
             {
                 EXECUTEMANY_PLAIN: [None],
index 5c446a91dc071666bdc0e04da06a6b74b7be6e8e..8bcc7e2587482da37f1f1724ea698b25b48f7b86 100644 (file)
@@ -79,8 +79,8 @@ if typing.TYPE_CHECKING:
 
 """
 
-_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.immutabledict()
-NO_OPTIONS: Mapping[str, Any] = util.immutabledict()
+_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT
+NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT
 
 
 class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
@@ -936,6 +936,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             )
         )
 
+    def _get_required_transaction(self) -> RootTransaction:
+        trans = self._transaction
+        if trans is None:
+            raise exc.InvalidRequestError("connection is not in a transaction")
+        return trans
+
+    def _get_required_nested_transaction(self) -> NestedTransaction:
+        trans = self._nested_transaction
+        if trans is None:
+            raise exc.InvalidRequestError(
+                "connection is not in a nested transaction"
+            )
+        return trans
+
     def get_transaction(self) -> Optional[RootTransaction]:
         """Return the current root transaction in progress, if any.
 
@@ -1220,7 +1234,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         self,
         func: FunctionElement[Any],
         distilled_parameters: _CoreMultiExecuteParams,
-        execution_options: _ExecuteOptions,
+        execution_options: _ExecuteOptionsParameter,
     ) -> Result:
         """Execute a sql.FunctionElement object."""
 
@@ -1232,7 +1246,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         self,
         default: ColumnDefault,
         distilled_parameters: _CoreMultiExecuteParams,
-        execution_options: _ExecuteOptions,
+        execution_options: _ExecuteOptionsParameter,
     ) -> Any:
         """Execute a schema.ColumnDefault object."""
 
@@ -1291,7 +1305,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         self,
         ddl: DDLElement,
         distilled_parameters: _CoreMultiExecuteParams,
-        execution_options: _ExecuteOptions,
+        execution_options: _ExecuteOptionsParameter,
     ) -> Result:
         """Execute a schema.DDL object."""
 
@@ -1388,7 +1402,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         self,
         elem: Executable,
         distilled_parameters: _CoreMultiExecuteParams,
-        execution_options: _ExecuteOptions,
+        execution_options: _ExecuteOptionsParameter,
     ) -> Result:
         """Execute a sql.ClauseElement object."""
 
@@ -1511,7 +1525,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         self,
         statement: str,
         parameters: Optional[_DBAPIAnyExecuteParams] = None,
-        execution_options: Optional[_ExecuteOptions] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
     ) -> Result:
         r"""Executes a SQL statement construct and returns a
         :class:`_engine.CursorResult`.
index 213485cc92c486c58d93365e6b6757b1e2bc7ce3..529b2ca73b8ea9431ceb9b003adbc19fbd338e0c 100644 (file)
@@ -10,11 +10,13 @@ from __future__ import annotations
 import typing
 from typing import Any
 from typing import Callable
+from typing import Optional
 from typing import TypeVar
 
 from .. import exc
 from .. import util
 from ..util._has_cy import HAS_CYEXTENSION
+from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
     from ._py_util import _distill_params_20 as _distill_params_20
@@ -49,6 +51,10 @@ def connection_memoize(key: str) -> Callable[[_C], _C]:
     return decorated  # type: ignore[return-value]
 
 
+class _TConsSubject(Protocol):
+    _trans_context_manager: Optional[TransactionalContext]
+
+
 class TransactionalContext:
     """Apply Python context manager behavior to transaction objects.
 
@@ -59,6 +65,8 @@ class TransactionalContext:
 
     __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__")
 
+    _trans_subject: Optional[_TConsSubject]
+
     def _transaction_is_active(self) -> bool:
         raise NotImplementedError()
 
@@ -82,7 +90,7 @@ class TransactionalContext:
         """
         raise NotImplementedError()
 
-    def _get_subject(self) -> Any:
+    def _get_subject(self) -> _TConsSubject:
         raise NotImplementedError()
 
     def commit(self) -> None:
@@ -95,7 +103,7 @@ class TransactionalContext:
         raise NotImplementedError()
 
     @classmethod
-    def _trans_ctx_check(cls, subject: Any) -> None:
+    def _trans_ctx_check(cls, subject: _TConsSubject) -> None:
         trans_context = subject._trans_context_manager
         if trans_context:
             if not trans_context._transaction_is_active():
index e1c94968139633b7b0cac89a9e9132e19dc70dc2..7e6d2a397949e31f3b64f433c2766ebc84403815 100644 (file)
@@ -13,6 +13,7 @@ from .api import listen as listen
 from .api import listens_for as listens_for
 from .api import NO_RETVAL as NO_RETVAL
 from .api import remove as remove
+from .attr import _InstanceLevelDispatch as _InstanceLevelDispatch
 from .attr import RefCollection as RefCollection
 from .base import _Dispatch as _Dispatch
 from .base import _DispatchCommon as _DispatchCommon
index cc572512553b65930836590c4e10889894d4836f..0503076aaf7d903fd00c0c461cb409ddc64abb9a 100644 (file)
@@ -116,7 +116,7 @@ class async_scoped_session(ScopedSessionMixin):
     # code within this block is **programmatically,
     # statically generated** by tools/generate_proxy_methods.py
 
-    def __contains__(self, instance):
+    def __contains__(self, instance: object) -> bool:
         r"""Return True if the instance is associated with this session.
 
         .. container:: class_bases
@@ -138,7 +138,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.__contains__(instance)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[object]:
         r"""Iterate over all pending or persistent instances within this
         Session.
 
@@ -156,7 +156,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.__iter__()
 
-    def add(self, instance: Any, _warn: bool = True) -> None:
+    def add(self, instance: object, _warn: bool = True) -> None:
         r"""Place an object in the ``Session``.
 
         .. container:: class_bases
@@ -181,7 +181,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.add(instance, _warn=_warn)
 
-    def add_all(self, instances):
+    def add_all(self, instances: Iterable[object]) -> None:
         r"""Add the given collection of instances to this ``Session``.
 
         .. container:: class_bases
@@ -374,7 +374,9 @@ class async_scoped_session(ScopedSessionMixin):
             **kw,
         )
 
-    def expire(self, instance, attribute_names=None):
+    def expire(
+        self, instance: object, attribute_names: Optional[Iterable[str]] = None
+    ) -> None:
         r"""Expire the attributes on an instance.
 
         .. container:: class_bases
@@ -426,7 +428,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.expire(instance, attribute_names=attribute_names)
 
-    def expire_all(self):
+    def expire_all(self) -> None:
         r"""Expires all persistent instances within this Session.
 
         .. container:: class_bases
@@ -473,7 +475,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.expire_all()
 
-    def expunge(self, instance):
+    def expunge(self, instance: object) -> None:
         r"""Remove the `instance` from this ``Session``.
 
         .. container:: class_bases
@@ -495,7 +497,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.expunge(instance)
 
-    def expunge_all(self):
+    def expunge_all(self) -> None:
         r"""Remove all object instances from this ``Session``.
 
         .. container:: class_bases
@@ -652,7 +654,9 @@ class async_scoped_session(ScopedSessionMixin):
             mapper=mapper, clause=clause, bind=bind, **kw
         )
 
-    def is_modified(self, instance, include_collections=True):
+    def is_modified(
+        self, instance: object, include_collections: bool = True
+    ) -> bool:
         r"""Return ``True`` if the given instance has locally
         modified attributes.
 
@@ -1168,7 +1172,7 @@ class async_scoped_session(ScopedSessionMixin):
         return await AsyncSession.close_all()
 
     @classmethod
-    def object_session(cls, instance: Any) -> "Session":
+    def object_session(cls, instance: object) -> Optional[Session]:
         r"""Return the :class:`.Session` to which an object belongs.
 
         .. container:: class_bases
@@ -1192,13 +1196,13 @@ class async_scoped_session(ScopedSessionMixin):
     @classmethod
     def identity_key(
         cls,
-        class_=None,
-        ident=None,
+        class_: Optional[Type[Any]] = None,
+        ident: Union[Any, Tuple[Any, ...]] = None,
         *,
-        instance=None,
-        row=None,
-        identity_token=None,
-    ) -> _IdentityKeyType:
+        instance: Optional[Any] = None,
+        row: Optional[Row] = None,
+        identity_token: Optional[Any] = None,
+    ) -> _IdentityKeyType[Any]:
         r"""Return an identity key.
 
         .. container:: class_bases
index 0bd2530b20e2e486d7421f6c0e0890c92c4487b0..769fe05bdb8d76b69d96c24d73a2f15b45ee337a 100644 (file)
@@ -640,7 +640,7 @@ class AsyncSession(ReversibleProxy):
     # code within this block is **programmatically,
     # statically generated** by tools/generate_proxy_methods.py
 
-    def __contains__(self, instance):
+    def __contains__(self, instance: object) -> bool:
         r"""Return True if the instance is associated with this session.
 
         .. container:: class_bases
@@ -656,7 +656,7 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.__contains__(instance)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[object]:
         r"""Iterate over all pending or persistent instances within this
         Session.
 
@@ -670,7 +670,7 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.__iter__()
 
-    def add(self, instance: Any, _warn: bool = True) -> None:
+    def add(self, instance: object, _warn: bool = True) -> None:
         r"""Place an object in the ``Session``.
 
         .. container:: class_bases
@@ -689,7 +689,7 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.add(instance, _warn=_warn)
 
-    def add_all(self, instances):
+    def add_all(self, instances: Iterable[object]) -> None:
         r"""Add the given collection of instances to this ``Session``.
 
         .. container:: class_bases
@@ -701,7 +701,9 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.add_all(instances)
 
-    def expire(self, instance, attribute_names=None):
+    def expire(
+        self, instance: object, attribute_names: Optional[Iterable[str]] = None
+    ) -> None:
         r"""Expire the attributes on an instance.
 
         .. container:: class_bases
@@ -747,7 +749,7 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.expire(instance, attribute_names=attribute_names)
 
-    def expire_all(self):
+    def expire_all(self) -> None:
         r"""Expires all persistent instances within this Session.
 
         .. container:: class_bases
@@ -788,7 +790,7 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.expire_all()
 
-    def expunge(self, instance):
+    def expunge(self, instance: object) -> None:
         r"""Remove the `instance` from this ``Session``.
 
         .. container:: class_bases
@@ -804,7 +806,7 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.expunge(instance)
 
-    def expunge_all(self):
+    def expunge_all(self) -> None:
         r"""Remove all object instances from this ``Session``.
 
         .. container:: class_bases
@@ -820,7 +822,9 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.expunge_all()
 
-    def is_modified(self, instance, include_collections=True):
+    def is_modified(
+        self, instance: object, include_collections: bool = True
+    ) -> bool:
         r"""Return ``True`` if the given instance has locally
         modified attributes.
 
@@ -882,7 +886,7 @@ class AsyncSession(ReversibleProxy):
             instance, include_collections=include_collections
         )
 
-    def in_transaction(self):
+    def in_transaction(self) -> bool:
         r"""Return True if this :class:`_orm.Session` has begun a transaction.
 
         .. container:: class_bases
@@ -902,7 +906,7 @@ class AsyncSession(ReversibleProxy):
 
         return self._proxied.in_transaction()
 
-    def in_nested_transaction(self):
+    def in_nested_transaction(self) -> bool:
         r"""Return True if this :class:`_orm.Session` has begun a nested
         transaction, e.g. SAVEPOINT.
 
@@ -978,7 +982,7 @@ class AsyncSession(ReversibleProxy):
         return self._proxied.new
 
     @property
-    def identity_map(self) -> identity.IdentityMap:
+    def identity_map(self) -> IdentityMap:
         r"""Proxy for the :attr:`_orm.Session.identity_map` attribute
         on behalf of the :class:`_asyncio.AsyncSession` class.
 
@@ -987,7 +991,7 @@ class AsyncSession(ReversibleProxy):
         return self._proxied.identity_map
 
     @identity_map.setter
-    def identity_map(self, attr: identity.IdentityMap) -> None:
+    def identity_map(self, attr: IdentityMap) -> None:
         self._proxied.identity_map = attr
 
     @property
@@ -1090,7 +1094,7 @@ class AsyncSession(ReversibleProxy):
         return self._proxied.info
 
     @classmethod
-    def object_session(cls, instance: Any) -> "Session":
+    def object_session(cls, instance: object) -> Optional[Session]:
         r"""Return the :class:`.Session` to which an object belongs.
 
         .. container:: class_bases
@@ -1108,13 +1112,13 @@ class AsyncSession(ReversibleProxy):
     @classmethod
     def identity_key(
         cls,
-        class_=None,
-        ident=None,
+        class_: Optional[Type[Any]] = None,
+        ident: Union[Any, Tuple[Any, ...]] = None,
         *,
-        instance=None,
-        row=None,
-        identity_token=None,
-    ) -> _IdentityKeyType:
+        instance: Optional[Any] = None,
+        row: Optional[Row] = None,
+        identity_token: Optional[Any] = None,
+    ) -> _IdentityKeyType[Any]:
         r"""Return an identity key.
 
         .. container:: class_bases
@@ -1243,7 +1247,7 @@ def async_object_session(instance):
         return None
 
 
-def async_session(session):
+def async_session(session: Session) -> AsyncSession:
     """Return the :class:`_asyncio.AsyncSession` which is proxying the given
     :class:`_orm.Session` object, if any.
 
index 5ca8b03dd69188b13de5098c54cb83f4e011e3a7..a0c7905d8453d225e9352187d3bc217752f6704b 100644 (file)
@@ -1303,7 +1303,9 @@ class Comparator(interfaces.PropComparator[_T]):
     def property(self) -> Any:
         return None
 
-    def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]:
+    def adapt_to_entity(
+        self, adapt_to_entity: AliasedInsp[Any]
+    ) -> Comparator[_T]:
         # interesting....
         return self
 
index c5b0affd2e7c49d06896636263ad46a3b6048734..6e8a0e77133de5fcd1c9c5e0e0f8f4b3ffba36c0 100644 (file)
@@ -460,7 +460,7 @@ def composite(
     class_: Type[_T],
     *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
     **kwargs: Any,
-) -> "Composite[_T]":
+) -> Composite[_T]:
     ...
 
 
@@ -468,7 +468,7 @@ def composite(
 def composite(
     *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
     **kwargs: Any,
-) -> "Composite[Any]":
+) -> Composite[Any]:
     ...
 
 
@@ -476,7 +476,7 @@ def composite(
     class_: Any = None,
     *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
     **kwargs: Any,
-) -> "Composite[Any]":
+) -> Composite[Any]:
     r"""Return a composite column-based property for use with a Mapper.
 
     See the mapping documentation section :ref:`mapper_composite` for a
index e9ddf6d1586650818bfc865f6b248497b72d3aef..4250cdbe1f6c532ba74648313af73f874237ec38 100644 (file)
@@ -1,11 +1,69 @@
 from __future__ import annotations
 
+import operator
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
+from sqlalchemy.orm.interfaces import UserDefinedOption
+from ..util.typing import Protocol
+from ..util.typing import TypeGuard
 
 if TYPE_CHECKING:
+    from .attributes import AttributeImpl
+    from .attributes import CollectionAttributeImpl
+    from .base import PassiveFlag
+    from .descriptor_props import _CompositeClassProto
     from .mapper import Mapper
+    from .state import InstanceState
+    from .util import AliasedClass
     from .util import AliasedInsp
+    from ..sql.base import ExecutableOption
 
-_EntityType = Union[Mapper, AliasedInsp]
+_T = TypeVar("_T", bound=Any)
+
+_O = TypeVar("_O", bound=Any)
+"""The 'ORM mapped object' type.
+I would have preferred this were bound=object however it seems
+to not travel in all situations when defined in that way.
+"""
+
+_InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"]
+
+_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"]
+
+
+_InstanceDict = Dict[str, Any]
+
+_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
+
+
+class _LoaderCallable(Protocol):
+    def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any:
+        ...
+
+
+def is_user_defined_option(
+    opt: ExecutableOption,
+) -> TypeGuard[UserDefinedOption]:
+    return not opt._is_core and opt._is_user_defined  # type: ignore
+
+
+def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]:
+    return hasattr(obj, "__composite_values__")
+
+
+if TYPE_CHECKING:
+
+    def is_collection_impl(
+        impl: AttributeImpl,
+    ) -> TypeGuard[CollectionAttributeImpl]:
+        ...
+
+else:
+    is_collection_impl = operator.attrgetter("collection")
index 3d34927105278b0e566ccedad811fa11a5940f8b..33ce96a192b1bf803b2e739ea8457ce66e7581df 100644 (file)
@@ -18,12 +18,17 @@ from __future__ import annotations
 
 from collections import namedtuple
 import operator
-import typing
 from typing import Any
 from typing import Callable
+from typing import Collection
+from typing import Dict
 from typing import List
 from typing import NamedTuple
+from typing import Optional
+from typing import overload
 from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -35,8 +40,8 @@ from .base import ATTR_WAS_SET
 from .base import CALLABLES_OK
 from .base import DEFERRED_HISTORY_LOAD
 from .base import INIT_OK
-from .base import instance_dict
-from .base import instance_state
+from .base import instance_dict as instance_dict
+from .base import instance_state as instance_state
 from .base import instance_str
 from .base import LOAD_AGAINST_COMMITTED
 from .base import manager_of_class
@@ -55,6 +60,7 @@ from .base import PASSIVE_NO_RESULT
 from .base import PASSIVE_OFF
 from .base import PASSIVE_ONLY_PERSISTENT
 from .base import PASSIVE_RETURN_NO_VALUE
+from .base import PassiveFlag
 from .base import RELATED_OBJECT_OK  # noqa
 from .base import SQL_OK  # noqa
 from .base import state_str
@@ -67,7 +73,8 @@ from ..sql import roles
 from ..sql import traversals
 from ..sql import visitors
 
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+    from .state import InstanceState
     from ..sql.dml import _DMLColumnElement
     from ..sql.elements import ColumnElement
     from ..sql.elements import SQLCoreOperations
@@ -115,6 +122,8 @@ class QueryableAttribute(
 
     is_attribute = True
 
+    impl: AttributeImpl
+
     # PropComparator has a __visit_name__ to participate within
     # traversals.   Disambiguate the attribute vs. a comparator.
     __visit_name__ = "orm_instrumented_attribute"
@@ -402,7 +411,19 @@ class InstrumentedAttribute(QueryableAttribute[_T]):
     def __delete__(self, instance):
         self.impl.delete(instance_state(instance), instance_dict(instance))
 
-    def __get__(self, instance, owner):
+    @overload
+    def __get__(
+        self, instance: None, owner: Type[Any]
+    ) -> InstrumentedAttribute:
+        ...
+
+    @overload
+    def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]:
+        ...
+
+    def __get__(
+        self, instance: Optional[object], owner: Type[Any]
+    ) -> Union[InstrumentedAttribute, Optional[_T]]:
         if instance is None:
             return self
 
@@ -636,6 +657,8 @@ Event = AttributeEvent
 class AttributeImpl:
     """internal implementation for instrumented attributes."""
 
+    collection: bool
+
     def __init__(
         self,
         class_,
@@ -811,7 +834,12 @@ class AttributeImpl:
 
             state.parents[id_] = False
 
-    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+    def get_history(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive=PASSIVE_OFF,
+    ) -> History:
         raise NotImplementedError()
 
     def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
@@ -989,7 +1017,12 @@ class ScalarAttributeImpl(AttributeImpl):
         ):
             raise AttributeError("%s object does not have a value" % self)
 
-    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+    def get_history(
+        self,
+        state: InstanceState[Any],
+        dict_: Dict[str, Any],
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> History:
         if self.key in dict_:
             return History.from_scalar_attribute(self, state, dict_[self.key])
         elif self.key in state.committed_state:
@@ -1005,13 +1038,13 @@ class ScalarAttributeImpl(AttributeImpl):
 
     def set(
         self,
-        state,
-        dict_,
-        value,
-        initiator,
-        passive=PASSIVE_OFF,
-        check_old=None,
-        pop=False,
+        state: InstanceState[Any],
+        dict_: Dict[str, Any],
+        value: Any,
+        initiator: Optional[Event],
+        passive: PassiveFlag = PASSIVE_OFF,
+        check_old: Optional[object] = None,
+        pop: bool = False,
     ):
         if self.dispatch._active_history:
             old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
@@ -1536,7 +1569,7 @@ class CollectionAttributeImpl(AttributeImpl):
         if fire_event:
             self.dispatch.dispose_collection(state, collection, adapter)
 
-    def _invalidate_collection(self, collection):
+    def _invalidate_collection(self, collection: Collection) -> None:
         adapter = getattr(collection, "_sa_adapter")
         adapter.invalidated = True
 
index a1a9442dcaf2e814cd6001fbc96b45e3541d1b08..d8f57e14981faa9d55a5c9adefe13384f0bb9830 100644 (file)
@@ -20,8 +20,8 @@ from typing import Dict
 from typing import Generic
 from typing import Optional
 from typing import overload
-from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -30,6 +30,7 @@ from .. import exc as sa_exc
 from .. import inspection
 from .. import util
 from ..sql.elements import SQLCoreOperations
+from ..util import FastIntFlag
 from ..util.langhelpers import TypingOnly
 from ..util.typing import Concatenate
 from ..util.typing import Literal
@@ -37,159 +38,147 @@ from ..util.typing import ParamSpec
 from ..util.typing import Self
 
 if typing.TYPE_CHECKING:
+    from ._typing import _InternalEntityType
     from .attributes import InstrumentedAttribute
     from .mapper import Mapper
+    from .state import InstanceState
 
 _T = TypeVar("_T", bound=Any)
 
+_O = TypeVar("_O", bound=object)
 
-_IdentityKeyType = Tuple[type, Tuple[Any, ...], Optional[str]]
 
-
-PASSIVE_NO_RESULT = util.symbol(
-    "PASSIVE_NO_RESULT",
+class LoaderCallableStatus(Enum):
+    PASSIVE_NO_RESULT = 0
     """Symbol returned by a loader callable or other attribute/history
     retrieval operation when a value could not be determined, based
     on loader callable flags.
-    """,
-)
+    """
 
-PASSIVE_CLASS_MISMATCH = util.symbol(
-    "PASSIVE_CLASS_MISMATCH",
+    PASSIVE_CLASS_MISMATCH = 1
     """Symbol indicating that an object is locally present for a given
     primary key identity but it is not of the requested class.  The
-    return value is therefore None and no SQL should be emitted.""",
-)
+    return value is therefore None and no SQL should be emitted."""
 
-ATTR_WAS_SET = util.symbol(
-    "ATTR_WAS_SET",
+    ATTR_WAS_SET = 2
     """Symbol returned by a loader callable to indicate the
     retrieved value, or values, were assigned to their attributes
     on the target object.
-    """,
-)
+    """
 
-ATTR_EMPTY = util.symbol(
-    "ATTR_EMPTY",
+    ATTR_EMPTY = 3
     """Symbol used internally to indicate an attribute had no callable.""",
-)
 
-NO_VALUE = util.symbol(
-    "NO_VALUE",
+    NO_VALUE = 4
     """Symbol which may be placed as the 'previous' value of an attribute,
     indicating no value was loaded for an attribute when it was modified,
     and flags indicated we were not to load it.
-    """,
-)
+    """
+
+    NEVER_SET = NO_VALUE
+    """
+    Synonymous with NO_VALUE
+
+    .. versionchanged:: 1.4   NEVER_SET was merged with NO_VALUE
+
+    """
+
+
+(
+    PASSIVE_NO_RESULT,
+    PASSIVE_CLASS_MISMATCH,
+    ATTR_WAS_SET,
+    ATTR_EMPTY,
+    NO_VALUE,
+) = tuple(LoaderCallableStatus)
+
 NEVER_SET = NO_VALUE
-"""
-Synonymous with NO_VALUE
 
-.. versionchanged:: 1.4   NEVER_SET was merged with NO_VALUE
-"""
 
-NO_CHANGE = util.symbol(
-    "NO_CHANGE",
+class PassiveFlag(FastIntFlag):
+    """Bitflag interface that passes options onto loader callables"""
+
+    NO_CHANGE = 0
     """No callables or SQL should be emitted on attribute access
     and no state should change
-    """,
-    canonical=0,
-)
+    """
 
-CALLABLES_OK = util.symbol(
-    "CALLABLES_OK",
+    CALLABLES_OK = 1
     """Loader callables can be fired off if a value
     is not present.
-    """,
-    canonical=1,
-)
+    """
 
-SQL_OK = util.symbol(
-    "SQL_OK",
-    """Loader callables can emit SQL at least on scalar value attributes.""",
-    canonical=2,
-)
+    SQL_OK = 2
+    """Loader callables can emit SQL at least on scalar value attributes."""
 
-RELATED_OBJECT_OK = util.symbol(
-    "RELATED_OBJECT_OK",
+    RELATED_OBJECT_OK = 4
     """Callables can use SQL to load related objects as well
     as scalar value attributes.
-    """,
-    canonical=4,
-)
+    """
 
-INIT_OK = util.symbol(
-    "INIT_OK",
+    INIT_OK = 8
     """Attributes should be initialized with a blank
     value (None or an empty collection) upon get, if no other
     value can be obtained.
-    """,
-    canonical=8,
-)
+    """
 
-NON_PERSISTENT_OK = util.symbol(
-    "NON_PERSISTENT_OK",
-    """Callables can be emitted if the parent is not persistent.""",
-    canonical=16,
-)
+    NON_PERSISTENT_OK = 16
+    """Callables can be emitted if the parent is not persistent."""
 
-LOAD_AGAINST_COMMITTED = util.symbol(
-    "LOAD_AGAINST_COMMITTED",
+    LOAD_AGAINST_COMMITTED = 32
     """Callables should use committed values as primary/foreign keys during a
     load.
-    """,
-    canonical=32,
-)
+    """
 
-NO_AUTOFLUSH = util.symbol(
-    "NO_AUTOFLUSH",
+    NO_AUTOFLUSH = 64
     """Loader callables should disable autoflush.""",
-    canonical=64,
-)
 
-NO_RAISE = util.symbol(
-    "NO_RAISE",
-    """Loader callables should not raise any assertions""",
-    canonical=128,
-)
+    NO_RAISE = 128
+    """Loader callables should not raise any assertions"""
 
-DEFERRED_HISTORY_LOAD = util.symbol(
-    "DEFERRED_HISTORY_LOAD",
-    """indicates special load of the previous value of an attribute""",
-    canonical=256,
-)
+    DEFERRED_HISTORY_LOAD = 256
+    """indicates special load of the previous value of an attribute"""
 
-# pre-packaged sets of flags used as inputs
-PASSIVE_OFF = util.symbol(
-    "PASSIVE_OFF",
-    "Callables can be emitted in all cases.",
-    canonical=(
+    # pre-packaged sets of flags used as inputs
+    PASSIVE_OFF = (
         RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK
-    ),
-)
-PASSIVE_RETURN_NO_VALUE = util.symbol(
-    "PASSIVE_RETURN_NO_VALUE",
-    """PASSIVE_OFF ^ INIT_OK""",
-    canonical=PASSIVE_OFF ^ INIT_OK,
-)
-PASSIVE_NO_INITIALIZE = util.symbol(
-    "PASSIVE_NO_INITIALIZE",
-    "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK",
-    canonical=PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK,
-)
-PASSIVE_NO_FETCH = util.symbol(
-    "PASSIVE_NO_FETCH", "PASSIVE_OFF ^ SQL_OK", canonical=PASSIVE_OFF ^ SQL_OK
-)
-PASSIVE_NO_FETCH_RELATED = util.symbol(
-    "PASSIVE_NO_FETCH_RELATED",
-    "PASSIVE_OFF ^ RELATED_OBJECT_OK",
-    canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK,
-)
-PASSIVE_ONLY_PERSISTENT = util.symbol(
-    "PASSIVE_ONLY_PERSISTENT",
-    "PASSIVE_OFF ^ NON_PERSISTENT_OK",
-    canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK,
-)
+    )
+    "Callables can be emitted in all cases."
+
+    PASSIVE_RETURN_NO_VALUE = PASSIVE_OFF ^ INIT_OK
+    """PASSIVE_OFF ^ INIT_OK"""
+
+    PASSIVE_NO_INITIALIZE = PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK
+    "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK"
+
+    PASSIVE_NO_FETCH = PASSIVE_OFF ^ SQL_OK
+    "PASSIVE_OFF ^ SQL_OK"
+
+    PASSIVE_NO_FETCH_RELATED = PASSIVE_OFF ^ RELATED_OBJECT_OK
+    "PASSIVE_OFF ^ RELATED_OBJECT_OK"
+
+    PASSIVE_ONLY_PERSISTENT = PASSIVE_OFF ^ NON_PERSISTENT_OK
+    "PASSIVE_OFF ^ NON_PERSISTENT_OK"
+
+
+(
+    NO_CHANGE,
+    CALLABLES_OK,
+    SQL_OK,
+    RELATED_OBJECT_OK,
+    INIT_OK,
+    NON_PERSISTENT_OK,
+    LOAD_AGAINST_COMMITTED,
+    NO_AUTOFLUSH,
+    NO_RAISE,
+    DEFERRED_HISTORY_LOAD,
+    PASSIVE_OFF,
+    PASSIVE_RETURN_NO_VALUE,
+    PASSIVE_NO_INITIALIZE,
+    PASSIVE_NO_FETCH,
+    PASSIVE_NO_FETCH_RELATED,
+    PASSIVE_ONLY_PERSISTENT,
+) = tuple(PassiveFlag)
 
 DEFAULT_MANAGER_ATTR = "_sa_class_manager"
 DEFAULT_STATE_ATTR = "_sa_instance_state"
@@ -285,18 +274,27 @@ def manager_of_class(cls):
     return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
 
 
-instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
+if TYPE_CHECKING:
+
+    def instance_state(instance: _O) -> InstanceState[_O]:
+        ...
+
+    def instance_dict(instance: object) -> Dict[str, Any]:
+        ...
 
-instance_dict = operator.attrgetter("__dict__")
+else:
+    instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
 
+    instance_dict = operator.attrgetter("__dict__")
 
-def instance_str(instance):
+
+def instance_str(instance: object) -> str:
     """Return a string describing an instance."""
 
     return state_str(instance_state(instance))
 
 
-def state_str(state):
+def state_str(state: InstanceState[Any]) -> str:
     """Return a string describing an instance via its InstanceState."""
 
     if state is None:
@@ -305,7 +303,7 @@ def state_str(state):
         return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj()))
 
 
-def state_class_str(state):
+def state_class_str(state: InstanceState[Any]) -> str:
     """Return a string describing an instance's class via its
     InstanceState.
     """
@@ -316,15 +314,15 @@ def state_class_str(state):
         return "<%s>" % (state.class_.__name__,)
 
 
-def attribute_str(instance, attribute):
+def attribute_str(instance: object, attribute: str) -> str:
     return instance_str(instance) + "." + attribute
 
 
-def state_attribute_str(state, attribute):
+def state_attribute_str(state: InstanceState[Any], attribute: str) -> str:
     return state_str(state) + "." + attribute
 
 
-def object_mapper(instance):
+def object_mapper(instance: _T) -> Mapper[_T]:
     """Given an object, return the primary Mapper associated with the object
     instance.
 
@@ -343,7 +341,7 @@ def object_mapper(instance):
     return object_state(instance).mapper
 
 
-def object_state(instance):
+def object_state(instance: _T) -> InstanceState[_T]:
     """Given an object, return the :class:`.InstanceState`
     associated with the object.
 
@@ -368,14 +366,14 @@ def object_state(instance):
 
 
 @inspection._inspects(object)
-def _inspect_mapped_object(instance):
+def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]:
     try:
         return instance_state(instance)
     except (exc.UnmappedClassError,) + exc.NO_STATE:
         return None
 
 
-def _class_to_mapper(class_or_mapper):
+def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]:
     insp = inspection.inspect(class_or_mapper, False)
     if insp is not None:
         return insp.mapper
@@ -383,7 +381,9 @@ def _class_to_mapper(class_or_mapper):
         raise exc.UnmappedClassError(class_or_mapper)
 
 
-def _mapper_or_none(entity):
+def _mapper_or_none(
+    entity: Union[_T, _InternalEntityType[_T]]
+) -> Optional[Mapper[_T]]:
     """Return the :class:`_orm.Mapper` for the given class or None if the
     class is not mapped.
     """
@@ -579,6 +579,8 @@ class InspectionAttrInfo(InspectionAttr):
 
     """
 
+    __slots__ = ()
+
     @util.memoized_property
     def info(self) -> Dict[Any, Any]:
         """Info dictionary associated with the object, allowing user-defined
index edd3fb56bfab1b8c37af920d7b2a93463e5937c1..419da65f7fc2829369a46ee47bb0f801c38d1604 100644 (file)
@@ -9,6 +9,7 @@ from __future__ import annotations
 
 import itertools
 from typing import Any
+from typing import cast
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -61,7 +62,7 @@ from ..sql.selectable import SelectState
 from ..sql.visitors import InternalTraversal
 
 if TYPE_CHECKING:
-    from ._typing import _EntityType
+    from ._typing import _InternalEntityType
     from ..sql.compiler import _CompilerStackEntry
     from ..sql.dml import _DMLTableElement
     from ..sql.elements import ColumnElement
@@ -213,7 +214,7 @@ class ORMCompileState(CompileState):
     statement: Union[Select, FromStatement]
     select_statement: Union[Select, FromStatement]
     _entities: List[_QueryEntity]
-    _polymorphic_adapters: Dict[_EntityType, ORMAdapter]
+    _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
     compile_options: Union[
         Type[default_compile_options], default_compile_options
     ]
@@ -630,6 +631,24 @@ class FromStatement(GroupedElement, ReturnsRows, Executable):
 
         return compiler.process(compile_state.statement, **kw)
 
+    @property
+    def column_descriptions(self):
+        """Return a :term:`plugin-enabled` 'column descriptions' structure
+        referring to the columns which are SELECTed by this statement.
+
+        See the section :ref:`queryguide_inspection` for an overview
+        of this feature.
+
+        .. seealso::
+
+            :ref:`queryguide_inspection` - ORM background
+
+        """
+        meth = cast(
+            ORMSelectCompileState, SelectState.get_plugin_class(self)
+        ).get_column_descriptions
+        return meth(self)
+
     def _ensure_disambiguated_names(self):
         return self
 
index dd3931faf61af9c678e18504c8cb8d008697582d..32c69a7446d009ee611a3d6d0abb5bd487239a30 100644 (file)
@@ -19,9 +19,11 @@ import operator
 import typing
 from typing import Any
 from typing import Callable
+from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import Type
 from typing import TypeVar
 from typing import Union
 
@@ -41,14 +43,23 @@ from .. import sql
 from .. import util
 from ..sql import expression
 from ..sql import operators
+from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from .attributes import InstrumentedAttribute
     from .properties import MappedColumn
+    from ..sql._typing import _ColumnExpressionArgument
+    from ..sql.schema import Column
 
 _T = TypeVar("_T", bound=Any)
 _PT = TypeVar("_PT", bound=Any)
 
 
+class _CompositeClassProto(Protocol):
+    def __composite_values__(self) -> Tuple[Any, ...]:
+        ...
+
+
 class DescriptorProperty(MapperProperty[_T]):
     """:class:`.MapperProperty` which proxies access to a
     user-defined descriptor."""
@@ -110,6 +121,11 @@ class DescriptorProperty(MapperProperty[_T]):
         mapper.class_manager.instrument_attribute(self.key, proxy_attr)
 
 
+_CompositeAttrType = Union[
+    str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]"
+]
+
+
 class Composite(
     _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T]
 ):
@@ -129,12 +145,21 @@ class Composite(
 
     """
 
-    composite_class: Union[type, Callable[..., type]]
-    attrs: Tuple[
-        Union[sql.ColumnElement[Any], "MappedColumn", str, Mapped[Any]], ...
+    composite_class: Union[
+        Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]]
     ]
+    attrs: Tuple[_CompositeAttrType, ...]
 
-    def __init__(self, class_=None, *attrs, **kwargs):
+    def __init__(
+        self,
+        class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None,
+        *attrs: _CompositeAttrType,
+        active_history: bool = False,
+        deferred: bool = False,
+        group: Optional[str] = None,
+        comparator_factory: Optional[Type[Comparator]] = None,
+        info: Optional[Dict[Any, Any]] = None,
+    ):
         super().__init__()
 
         if isinstance(class_, (Mapped, str, sql.ColumnElement)):
@@ -145,15 +170,17 @@ class Composite(
             self.composite_class = class_
             self.attrs = attrs
 
-        self.active_history = kwargs.get("active_history", False)
-        self.deferred = kwargs.get("deferred", False)
-        self.group = kwargs.get("group", None)
-        self.comparator_factory = kwargs.pop(
-            "comparator_factory", self.__class__.Comparator
+        self.active_history = active_history
+        self.deferred = deferred
+        self.group = group
+        self.comparator_factory = (
+            comparator_factory
+            if comparator_factory is not None
+            else self.__class__.Comparator
         )
         self._generated_composite_accessor = None
-        if "info" in kwargs:
-            self.info = kwargs.pop("info")
+        if info is not None:
+            self.info = info
 
         util.set_creation_order(self)
         self._create_descriptor()
@@ -162,7 +189,9 @@ class Composite(
         super().instrument_class(mapper)
         self._setup_event_handlers()
 
-    def _composite_values_from_instance(self, value):
+    def _composite_values_from_instance(
+        self, value: _CompositeClassProto
+    ) -> Tuple[Any, ...]:
         if self._generated_composite_accessor:
             return self._generated_composite_accessor(value)
         else:
index f70ea783739509cd8b2366a0314a6db0627abca1..00829ecbb7c9870ba0ea1abc5a5a59d56471e183 100644 (file)
@@ -9,6 +9,10 @@
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Optional
+from typing import Type
+
 from .. import exc as sa_exc
 from .. import util
 from ..exc import MultipleResultsFound  # noqa
@@ -73,7 +77,7 @@ class UnmappedInstanceError(UnmappedError):
     """An mapping operation was requested for an unknown instance."""
 
     @util.preload_module("sqlalchemy.orm.base")
-    def __init__(self, obj, msg=None):
+    def __init__(self, obj: object, msg: Optional[str] = None):
         base = util.preloaded.orm_base
 
         if not msg:
@@ -87,7 +91,7 @@ class UnmappedInstanceError(UnmappedError):
                     "was called." % (name, name)
                 )
             except UnmappedClassError:
-                msg = _default_unmapped(type(obj))
+                msg = f"Class '{_safe_cls_name(type(obj))}' is not mapped"
                 if isinstance(obj, type):
                     msg += (
                         "; was a class (%s) supplied where an instance was "
@@ -102,12 +106,12 @@ class UnmappedInstanceError(UnmappedError):
 class UnmappedClassError(UnmappedError):
     """An mapping operation was requested for an unknown class."""
 
-    def __init__(self, cls, msg=None):
+    def __init__(self, cls: Type[object], msg: Optional[str] = None):
         if not msg:
             msg = _default_unmapped(cls)
         UnmappedError.__init__(self, msg)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         return self.__class__, (None, self.args[0])
 
 
@@ -194,7 +198,7 @@ def _safe_cls_name(cls):
 
 
 @util.preload_module("sqlalchemy.orm.base")
-def _default_unmapped(cls):
+def _default_unmapped(cls) -> Optional[str]:
     base = util.preloaded.orm_base
 
     try:
@@ -204,4 +208,6 @@ def _default_unmapped(cls):
     name = _safe_cls_name(cls)
 
     if not mappers:
-        return "Class '%s' is not mapped" % name
+        return f"Class '{name}' is not mapped"
+    else:
+        return None
index 3caf0b22fbc4ea0ec376ef9dc7b88945212a149c..d13265c56072c1df138cc005054b6ca69c1258a1 100644 (file)
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import TYPE_CHECKING
+from typing import TypeVar
 import weakref
 
 from . import util as orm_util
 from .. import exc as sa_exc
 
+if TYPE_CHECKING:
+    from ._typing import _IdentityKeyType
+    from .state import InstanceState
+
+
+_T = TypeVar("_T", bound=Any)
+
+_O = TypeVar("_O", bound=object)
+
 
 class IdentityMap:
-    def __init__(self):
+    _wr: weakref.ref[IdentityMap]
+
+    _dict: Dict[_IdentityKeyType[Any], Any]
+    _modified: Set[InstanceState[Any]]
+
+    def __init__(self) -> None:
         self._dict = {}
         self._modified = set()
         self._wr = weakref.ref(self)
 
-    def _kill(self):
-        self._add_unpresent = _killed
+    def _kill(self) -> None:
+        self._add_unpresent = _killed  # type: ignore
+
+    def all_states(self) -> List[InstanceState[Any]]:
+        raise NotImplementedError()
+
+    def contains_state(self, state: InstanceState[Any]) -> bool:
+        raise NotImplementedError()
+
+    def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
+        raise NotImplementedError()
+
+    def safe_discard(self, state: InstanceState[Any]) -> None:
+        raise NotImplementedError()
+
+    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
+        raise NotImplementedError()
+
+    def get(
+        self, key: _IdentityKeyType[_O], default: Optional[_O] = None
+    ) -> Optional[_O]:
+        raise NotImplementedError()
 
     def keys(self):
         return self._dict.keys()
 
-    def replace(self, state):
+    def values(self) -> Iterable[object]:
+        raise NotImplementedError()
+
+    def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]:
+        raise NotImplementedError()
+
+    def add(self, state: InstanceState[Any]) -> bool:
         raise NotImplementedError()
 
-    def add(self, state):
+    def _fast_discard(self, state: InstanceState[Any]) -> None:
         raise NotImplementedError()
 
-    def _add_unpresent(self, state, key):
+    def _add_unpresent(
+        self, state: InstanceState[Any], key: _IdentityKeyType[Any]
+    ) -> None:
         """optional inlined form of add() which can assume item isn't present
         in the map"""
         self.add(state)
 
-    def update(self, dict_):
-        raise NotImplementedError("IdentityMap uses add() to insert data")
-
-    def clear(self):
-        raise NotImplementedError("IdentityMap uses remove() to remove data")
-
-    def _manage_incoming_state(self, state):
+    def _manage_incoming_state(self, state: InstanceState[Any]) -> None:
         state._instance_dict = self._wr
 
         if state.modified:
             self._modified.add(state)
 
-    def _manage_removed_state(self, state):
+    def _manage_removed_state(self, state: InstanceState[Any]) -> None:
         del state._instance_dict
         if state.modified:
             self._modified.discard(state)
 
-    def _dirty_states(self):
+    def _dirty_states(self) -> Set[InstanceState[Any]]:
         return self._modified
 
-    def check_modified(self):
+    def check_modified(self) -> bool:
         """return True if any InstanceStates present have been marked
         as 'modified'.
 
         """
         return bool(self._modified)
 
-    def has_key(self, key):
+    def has_key(self, key: _IdentityKeyType[Any]) -> bool:
         return key in self
 
-    def popitem(self):
-        raise NotImplementedError("IdentityMap uses remove() to remove data")
-
-    def pop(self, key, *args):
-        raise NotImplementedError("IdentityMap uses remove() to remove data")
-
-    def setdefault(self, key, default=None):
-        raise NotImplementedError("IdentityMap uses add() to insert data")
-
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._dict)
 
-    def copy(self):
-        raise NotImplementedError()
-
-    def __setitem__(self, key, value):
-        raise NotImplementedError("IdentityMap uses add() to insert data")
-
-    def __delitem__(self, key):
-        raise NotImplementedError("IdentityMap uses remove() to remove data")
-
 
 class WeakInstanceDict(IdentityMap):
-    def __getitem__(self, key):
+    _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]]
+
+    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
         state = self._dict[key]
         o = state.obj()
         if o is None:
             raise KeyError(key)
         return o
 
-    def __contains__(self, key):
+    def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
         try:
             if key in self._dict:
                 state = self._dict[key]
@@ -108,7 +138,7 @@ class WeakInstanceDict(IdentityMap):
         else:
             return o is not None
 
-    def contains_state(self, state):
+    def contains_state(self, state: InstanceState[Any]) -> bool:
         if state.key in self._dict:
             try:
                 return self._dict[state.key] is state
@@ -117,13 +147,15 @@ class WeakInstanceDict(IdentityMap):
         else:
             return False
 
-    def replace(self, state):
+    def replace(
+        self, state: InstanceState[Any]
+    ) -> Optional[InstanceState[Any]]:
         if state.key in self._dict:
             try:
                 existing = self._dict[state.key]
             except KeyError:
                 # catch gc removed the key after we just checked for it
-                pass
+                existing = None
             else:
                 if existing is not state:
                     self._manage_removed_state(existing)
@@ -136,7 +168,7 @@ class WeakInstanceDict(IdentityMap):
         self._manage_incoming_state(state)
         return existing
 
-    def add(self, state):
+    def add(self, state: InstanceState[Any]) -> bool:
         key = state.key
         # inline of self.__contains__
         if key in self._dict:
@@ -161,12 +193,16 @@ class WeakInstanceDict(IdentityMap):
         self._manage_incoming_state(state)
         return True
 
-    def _add_unpresent(self, state, key):
+    def _add_unpresent(
+        self, state: InstanceState[Any], key: _IdentityKeyType[Any]
+    ) -> None:
         # inlined form of add() called by loading.py
         self._dict[key] = state
         state._instance_dict = self._wr
 
-    def get(self, key, default=None):
+    def get(
+        self, key: _IdentityKeyType[_O], default: Optional[_O] = None
+    ) -> Optional[_O]:
         if key not in self._dict:
             return default
         try:
@@ -180,7 +216,7 @@ class WeakInstanceDict(IdentityMap):
                 return default
             return o
 
-    def items(self):
+    def items(self) -> List[InstanceState[Any]]:
         values = self.all_states()
         result = []
         for state in values:
@@ -189,7 +225,7 @@ class WeakInstanceDict(IdentityMap):
                 result.append((state.key, value))
         return result
 
-    def values(self):
+    def values(self) -> List[object]:
         values = self.all_states()
         result = []
         for state in values:
@@ -199,13 +235,13 @@ class WeakInstanceDict(IdentityMap):
 
         return result
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_IdentityKeyType[Any]]:
         return iter(self.keys())
 
-    def all_states(self):
+    def all_states(self) -> List[InstanceState[Any]]:
         return list(self._dict.values())
 
-    def _fast_discard(self, state):
+    def _fast_discard(self, state: InstanceState[Any]) -> None:
         # used by InstanceState for state being
         # GC'ed, inlines _managed_removed_state
         try:
@@ -217,10 +253,10 @@ class WeakInstanceDict(IdentityMap):
             if st is state:
                 self._dict.pop(state.key, None)
 
-    def discard(self, state):
+    def discard(self, state: InstanceState[Any]) -> None:
         self.safe_discard(state)
 
-    def safe_discard(self, state):
+    def safe_discard(self, state: InstanceState[Any]) -> None:
         if state.key in self._dict:
             try:
                 st = self._dict[state.key]
@@ -233,7 +269,7 @@ class WeakInstanceDict(IdentityMap):
                     self._manage_removed_state(state)
 
 
-def _killed(state, key):
+def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn:
     # external function to avoid creating cycles when assigned to
     # the IdentityMap
     raise sa_exc.InvalidRequestError(
index a050c533a51b1c6705bf49ed5c9dc25ec3874648..030d1595b2f5208b3c4d9ea6bdabb9f81d688e43 100644 (file)
@@ -32,33 +32,64 @@ alternate instrumentation forms.
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Dict
+from typing import Generic
+from typing import Set
+from typing import TYPE_CHECKING
+from typing import TypeVar
+
 from . import base
 from . import collections
 from . import exc
 from . import interfaces
 from . import state
 from .. import util
+from ..event import EventTarget
 from ..util import HasMemoized
+from ..util.typing import Protocol
 
+if TYPE_CHECKING:
+    from .attributes import InstrumentedAttribute
+    from .mapper import Mapper
+    from ..event import dispatcher
 
+_T = TypeVar("_T", bound=Any)
 DEL_ATTR = util.symbol("DEL_ATTR")
 
 
-class ClassManager(HasMemoized, dict):
+class _ExpiredAttributeLoaderProto(Protocol):
+    def __call__(
+        self,
+        state: state.InstanceState[Any],
+        toload: Set[str],
+        passive: base.PassiveFlag,
+    ):
+        ...
+
+
+class ClassManager(
+    HasMemoized,
+    Dict[str, "InstrumentedAttribute[Any]"],
+    Generic[_T],
+    EventTarget,
+):
     """Tracks state information at the class level."""
 
+    dispatch: dispatcher[ClassManager]
+
     MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
     STATE_ATTR = base.DEFAULT_STATE_ATTR
 
     _state_setter = staticmethod(util.attrsetter(STATE_ATTR))
 
-    expired_attribute_loader = None
+    expired_attribute_loader: _ExpiredAttributeLoaderProto
     "previously known as deferred_scalar_loader"
 
     init_method = None
 
     factory = None
-    mapper = None
+
     declarative_scan = None
     registry = None
 
@@ -199,7 +230,7 @@ class ClassManager(HasMemoized, dict):
         return frozenset([attr.impl for attr in self.values()])
 
     @util.memoized_property
-    def mapper(self):
+    def mapper(self) -> Mapper[_T]:
         # raises unless self.mapper has been assigned
         raise exc.UnmappedClassError(self.class_)
 
@@ -426,7 +457,9 @@ class ClassManager(HasMemoized, dict):
     def teardown_instance(self, instance):
         delattr(instance, self.STATE_ATTR)
 
-    def _serialize(self, state, state_dict):
+    def _serialize(
+        self, state: state.InstanceState, state_dict: Dict[str, Any]
+    ) -> _SerializeManager:
         return _SerializeManager(state, state_dict)
 
     def _new_state_if_none(self, instance):
@@ -480,7 +513,7 @@ class _SerializeManager:
 
     """
 
-    def __init__(self, state, d):
+    def __init__(self, state: state.InstanceState[Any], d: Dict[str, Any]):
         self.class_ = state.class_
         manager = state.manager
         manager.dispatch.pickle(state, d)
index b4228323b47bbd2e2f0b7528c1571fd2348440d2..7be7ce32b4d49430277024b609e985711385c9c7 100644 (file)
@@ -838,6 +838,10 @@ class ORMOption(ExecutableOption):
 
     """
 
+    _is_core = False
+
+    _is_user_defined = False
+
     _is_compile_state = False
 
     _is_criteria_option = False
@@ -942,6 +946,8 @@ class UserDefinedOption(ORMOption):
 
     _is_legacy_option = False
 
+    _is_user_defined = True
+
     propagate_to_loaders = False
     """if True, indicate this option should be carried along
     to "secondary" Query objects produced during lazy loads
index 6f4c654ce4ee9543d8e65505018462ef6a86289e..ae083054cd17f04cf94560b1b272190cda8e3ac9 100644 (file)
@@ -15,12 +15,24 @@ as well as some of the attribute loading strategies.
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
+from sqlalchemy.orm.context import FromStatement
 from . import attributes
 from . import exc as orm_exc
 from . import path_registry
 from .base import _DEFER_FOR_STATE
 from .base import _RAISE_FOR_STATE
 from .base import _SET_DEFERRED_EXPIRED
+from .base import PassiveFlag
 from .util import _none_set
 from .util import state_str
 from .. import exc as sa_exc
@@ -31,9 +43,25 @@ from ..engine.result import ChunkedIteratorResult
 from ..engine.result import FrozenResult
 from ..engine.result import SimpleResultMetaData
 from ..sql import util as sql_util
+from ..sql.selectable import ForUpdateArg
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..sql.selectable import SelectState
 
+if TYPE_CHECKING:
+    from ._typing import _IdentityKeyType
+    from .base import LoaderCallableStatus
+    from .context import FromStatement
+    from .interfaces import ORMOption
+    from .mapper import Mapper
+    from .session import Session
+    from .state import InstanceState
+    from ..engine.interfaces import _ExecuteOptions
+    from ..sql import Select
+    from ..sql.base import Executable
+    from ..sql.selectable import ForUpdateArg
+
+_T = TypeVar("_T", bound=Any)
+_O = TypeVar("_O", bound=object)
 _new_runid = util.counter()
 
 
@@ -350,7 +378,12 @@ def merge_result(query, iterator, load=True):
         session.autoflush = autoflush
 
 
-def get_from_identity(session, mapper, key, passive):
+def get_from_identity(
+    session: Session,
+    mapper: Mapper[_O],
+    key: _IdentityKeyType[_O],
+    passive: PassiveFlag,
+) -> Union[Optional[_O], LoaderCallableStatus]:
     """Look up the given key in the given session's identity map,
     check the object for expired state if found.
 
@@ -385,16 +418,17 @@ def get_from_identity(session, mapper, key, passive):
 
 
 def load_on_ident(
-    session,
-    statement,
-    key,
-    load_options=None,
-    refresh_state=None,
-    with_for_update=None,
-    only_load_props=None,
-    no_autoflush=False,
-    bind_arguments=util.EMPTY_DICT,
-    execution_options=util.EMPTY_DICT,
+    session: Session,
+    statement: Union[Select, FromStatement],
+    key: Optional[_IdentityKeyType],
+    *,
+    load_options: Optional[Sequence[ORMOption]] = None,
+    refresh_state: Optional[InstanceState[Any]] = None,
+    with_for_update: Optional[ForUpdateArg] = None,
+    only_load_props: Optional[Iterable[str]] = None,
+    no_autoflush: bool = False,
+    bind_arguments: Mapping[str, Any] = util.EMPTY_DICT,
+    execution_options: _ExecuteOptions = util.EMPTY_DICT,
 ):
     """Load the given identity key from the database."""
     if key is not None:
@@ -419,17 +453,18 @@ def load_on_ident(
 
 
 def load_on_pk_identity(
-    session,
-    statement,
-    primary_key_identity,
-    load_options=None,
-    refresh_state=None,
-    with_for_update=None,
-    only_load_props=None,
-    identity_token=None,
-    no_autoflush=False,
-    bind_arguments=util.EMPTY_DICT,
-    execution_options=util.EMPTY_DICT,
+    session: Session,
+    statement: Union[Select, FromStatement],
+    primary_key_identity: Optional[Tuple[Any, ...]],
+    *,
+    load_options: Optional[Sequence[ORMOption]] = None,
+    refresh_state: Optional[InstanceState[Any]] = None,
+    with_for_update: Optional[ForUpdateArg] = None,
+    only_load_props: Optional[Iterable[str]] = None,
+    identity_token: Optional[Any] = None,
+    no_autoflush: bool = False,
+    bind_arguments: Mapping[str, Any] = util.EMPTY_DICT,
+    execution_options: _ExecuteOptions = util.EMPTY_DICT,
 ):
 
     """Load the given primary key identity from the database."""
index 982b4b6d9ca27df8d0efe6df3a102e82cc20ee56..c85861a594346665fe53bca25bf288aa23fc3e70 100644 (file)
@@ -22,9 +22,13 @@ from itertools import chain
 import sys
 import threading
 from typing import Any
+from typing import Callable
 from typing import Generic
+from typing import Iterator
+from typing import Optional
+from typing import Tuple
 from typing import Type
-from typing import TypeVar
+from typing import TYPE_CHECKING
 import weakref
 
 from . import attributes
@@ -33,9 +37,11 @@ from . import instrumentation
 from . import loading
 from . import properties
 from . import util as orm_util
+from ._typing import _O
 from .base import _class_to_mapper
 from .base import _state_mapper
 from .base import class_mapper
+from .base import PassiveFlag
 from .base import state_str
 from .interfaces import _MappedAttribute
 from .interfaces import EXT_SKIP
@@ -62,10 +68,15 @@ from ..sql import visitors
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..util import HasMemoized
 
-_mapper_registries = weakref.WeakKeyDictionary()
-
+if TYPE_CHECKING:
+    from ._typing import _IdentityKeyType
+    from ._typing import _InstanceDict
+    from .instrumentation import ClassManager
+    from .state import InstanceState
+    from ..sql.elements import ColumnElement
+    from ..sql.schema import Column
 
-_MC = TypeVar("_MC")
+_mapper_registries = weakref.WeakKeyDictionary()
 
 
 def _all_registries():
@@ -99,7 +110,7 @@ class Mapper(
     sql_base.MemoizedHasCacheKey,
     InspectionAttr,
     log.Identified,
-    Generic[_MC],
+    Generic[_O],
 ):
     """Defines an association between a Python class and a database table or
     other relational structure, so that ORM operations against the class may
@@ -115,9 +126,13 @@ class Mapper(
     _dispose_called = False
     _ready_for_configure = False
 
-    class_: Type[_MC]
+    class_: Type[_O]
     """The class to which this :class:`_orm.Mapper` is mapped."""
 
+    _identity_class: Type[_O]
+
+    always_refresh: bool
+
     @util.deprecated_params(
         non_primary=(
             "1.3",
@@ -130,7 +145,7 @@ class Mapper(
     )
     def __init__(
         self,
-        class_: Type[_MC],
+        class_: Type[_O],
         local_table=None,
         properties=None,
         primary_key=None,
@@ -813,7 +828,7 @@ class Mapper(
 
     """
 
-    primary_key = None
+    primary_key: Tuple[Column[Any], ...]
     """An iterable containing the collection of :class:`_schema.Column`
     objects
     which comprise the 'primary key' of the mapped table, from the
@@ -837,7 +852,7 @@ class Mapper(
 
     """
 
-    class_ = None
+    class_: Type[_O]
     """The Python class which this :class:`_orm.Mapper` maps.
 
     This is a *read only* attribute determined during mapper construction.
@@ -845,7 +860,7 @@ class Mapper(
 
     """
 
-    class_manager = None
+    class_manager: ClassManager[_O]
     """The :class:`.ClassManager` which maintains event listeners
     and class-bound descriptors for this :class:`_orm.Mapper`.
 
@@ -1965,7 +1980,7 @@ class Mapper(
             else self.persist_selectable.description,
         )
 
-    def _is_orphan(self, state):
+    def _is_orphan(self, state: InstanceState[_O]) -> bool:
         orphan_possible = False
         for mapper in self.iterate_to_root():
             for (key, cls) in mapper._delete_orphans:
@@ -2804,16 +2819,24 @@ class Mapper(
             identity_token,
         )
 
-    def identity_key_from_primary_key(self, primary_key, identity_token=None):
+    def identity_key_from_primary_key(
+        self,
+        primary_key: Tuple[Any, ...],
+        identity_token: Optional[Any] = None,
+    ) -> _IdentityKeyType[_O]:
         """Return an identity-map key for use in storing/retrieving an
         item from an identity map.
 
         :param primary_key: A list of values indicating the identifier.
 
         """
-        return self._identity_class, tuple(primary_key), identity_token
+        return (
+            self._identity_class,
+            tuple(primary_key),
+            identity_token,
+        )
 
-    def identity_key_from_instance(self, instance):
+    def identity_key_from_instance(self, instance: _O) -> _IdentityKeyType[_O]:
         """Return the identity key for the given instance, based on
         its primary key attributes.
 
@@ -2830,8 +2853,10 @@ class Mapper(
         return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
 
     def _identity_key_from_state(
-        self, state, passive=attributes.PASSIVE_RETURN_NO_VALUE
-    ):
+        self,
+        state: InstanceState[_O],
+        passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE,
+    ) -> _IdentityKeyType[_O]:
         dict_ = state.dict
         manager = state.manager
         return (
@@ -2845,7 +2870,7 @@ class Mapper(
             state.identity_token,
         )
 
-    def primary_key_from_instance(self, instance):
+    def primary_key_from_instance(self, instance: _O) -> Tuple[Any, ...]:
         """Return the list of primary key values for the given
         instance.
 
@@ -2903,8 +2928,12 @@ class Mapper(
         return {self._columntoproperty[col].key for col in self._all_pk_cols}
 
     def _get_state_attr_by_column(
-        self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
-    ):
+        self,
+        state: InstanceState[_O],
+        dict_: _InstanceDict,
+        column: Column[Any],
+        passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE,
+    ) -> Any:
         prop = self._columntoproperty[column]
         return state.manager[prop.key].impl.get(state, dict_, passive=passive)
 
@@ -3146,7 +3175,14 @@ class Mapper(
     def _subclass_load_via_in_mapper(self):
         return self._subclass_load_via_in(self)
 
-    def cascade_iterator(self, type_, state, halt_on=None):
+    def cascade_iterator(
+        self,
+        type_: str,
+        state: InstanceState[_O],
+        halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+    ) -> Iterator[
+        Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict]
+    ]:
         r"""Iterate each element and its mapper in an object graph,
         for all relationships that meet the given cascade rule.
 
index 9a7aa91a03bd58adb322d9378d2d86beb30581f1..e2cf1d5b04298c7f8ab6d76fa76f0eef3611d70f 100644 (file)
@@ -14,6 +14,7 @@ from functools import reduce
 from itertools import chain
 import logging
 from typing import Any
+from typing import Sequence
 from typing import Tuple
 from typing import Union
 
@@ -198,12 +199,12 @@ class PathRegistry(HasCacheKey):
             p = p[0:-1]
         return p
 
-    def serialize(self):
+    def serialize(self) -> Sequence[Any]:
         path = self.path
         return self._serialize_path(path)
 
     @classmethod
-    def deserialize(cls, path: Tuple) -> "PathRegistry":
+    def deserialize(cls, path: Sequence[Any]) -> PathRegistry:
         assert path is not None
         p = cls._deserialize_path(path)
         return cls.coerce(p)
index 355ddc922d6367ba21715c941a7e239827eb8615..93b49ab254895280ead2d258e32e2f5924928f44 100644 (file)
@@ -19,6 +19,12 @@ from itertools import chain
 from itertools import groupby
 from itertools import zip_longest
 import operator
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 
 from . import attributes
 from . import evaluator
@@ -47,15 +53,22 @@ from ..sql.dml import UpdateDMLState
 from ..sql.elements import BooleanClauseList
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 
+if TYPE_CHECKING:
+    from .mapper import Mapper
+    from .session import SessionTransaction
+    from .state import InstanceState
+
+_O = TypeVar("_O", bound=object)
+
 
 def _bulk_insert(
-    mapper,
-    mappings,
-    session_transaction,
-    isstates,
-    return_defaults,
-    render_nulls,
-):
+    mapper: Mapper[_O],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    return_defaults: bool,
+    render_nulls: bool,
+) -> None:
     base_mapper = mapper.base_mapper
 
     if session_transaction.session.connection_callable:
@@ -126,8 +139,12 @@ def _bulk_insert(
 
 
 def _bulk_update(
-    mapper, mappings, session_transaction, isstates, update_changed_only
-):
+    mapper: Mapper[Any],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    update_changed_only: bool,
+) -> None:
     base_mapper = mapper.base_mapper
 
     search_keys = mapper._primary_key_propkeys
index cc9d5a23b39a742ec36aedba571065aa0048215c..e498b17b4d31c60471c528b639eeeccf736bad2f 100644 (file)
@@ -7,6 +7,19 @@
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
 from . import exc as orm_exc
 from .base import class_mapper
 from .session import Session
@@ -17,16 +30,54 @@ from ..util import ScopedRegistry
 from ..util import ThreadLocalRegistry
 from ..util import warn
 from ..util import warn_deprecated
+from ..util.typing import Protocol
+
+if TYPE_CHECKING:
+    from ._typing import _IdentityKeyType
+    from .identity import IdentityMap
+    from .interfaces import ORMOption
+    from .mapper import Mapper
+    from .query import Query
+    from .session import _EntityBindKey
+    from .session import _PKIdentityArgument
+    from .session import _SessionBind
+    from .session import sessionmaker
+    from .session import SessionTransaction
+    from ..engine import Connection
+    from ..engine import Engine
+    from ..engine import Result
+    from ..engine import Row
+    from ..engine.interfaces import _CoreAnyExecuteParams
+    from ..engine.interfaces import _CoreSingleExecuteParams
+    from ..engine.interfaces import _ExecuteOptions
+    from ..engine.interfaces import _ExecuteOptionsParameter
+    from ..engine.result import ScalarResult
+    from ..sql._typing import _ColumnsClauseArgument
+    from ..sql.base import Executable
+    from ..sql.elements import ClauseElement
+    from ..sql.selectable import ForUpdateArg
+
+
+class _QueryDescriptorType(Protocol):
+    def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]:
+        ...
+
+
+_O = TypeVar("_O", bound=object)
 
 __all__ = ["scoped_session", "ScopedSessionMixin"]
 
 
 class ScopedSessionMixin:
+    session_factory: sessionmaker
+    _support_async: bool
+    registry: ScopedRegistry[Session]
+
     @property
-    def _proxied(self):
-        return self.registry()
+    def _proxied(self) -> Session:
+        return self.registry()  # type: ignore
 
-    def __call__(self, **kw):
+    def __call__(self, **kw: Any) -> Session:
         r"""Return the current :class:`.Session`, creating it
         using the :attr:`.scoped_session.session_factory` if not present.
 
@@ -57,7 +108,7 @@ class ScopedSessionMixin:
             )
         return sess
 
-    def configure(self, **kwargs):
+    def configure(self, **kwargs: Any) -> None:
         """reconfigure the :class:`.sessionmaker` used by this
         :class:`.scoped_session`.
 
@@ -120,7 +171,6 @@ class ScopedSessionMixin:
         "autoflush",
         "no_autoflush",
         "info",
-        "autocommit",
     ],
 )
 class scoped_session(ScopedSessionMixin):
@@ -136,15 +186,20 @@ class scoped_session(ScopedSessionMixin):
 
     """
 
-    _support_async = False
+    _support_async: bool = False
 
-    session_factory = None
+    session_factory: sessionmaker
     """The `session_factory` provided to `__init__` is stored in this
     attribute and may be accessed at a later time.  This can be useful when
     a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the
     database is needed."""
 
-    def __init__(self, session_factory, scopefunc=None):
+    def __init__(
+        self,
+        session_factory: sessionmaker,
+        scopefunc: Optional[Callable[[], Any]] = None,
+    ):
+
         """Construct a new :class:`.scoped_session`.
 
         :param session_factory: a factory to create new :class:`.Session`
@@ -167,7 +222,7 @@ class scoped_session(ScopedSessionMixin):
         else:
             self.registry = ThreadLocalRegistry(session_factory)
 
-    def remove(self):
+    def remove(self) -> None:
         """Dispose of the current :class:`.Session`, if present.
 
         This will first call :meth:`.Session.close` method
@@ -184,7 +239,9 @@ class scoped_session(ScopedSessionMixin):
             self.registry().close()
         self.registry.clear()
 
-    def query_property(self, query_cls=None):
+    def query_property(
+        self, query_cls: Optional[Type[Query[Any]]] = None
+    ) -> _QueryDescriptorType:
         """return a class property which produces a :class:`_query.Query`
         object
         against the class and the current :class:`.Session` when called.
@@ -211,16 +268,18 @@ class scoped_session(ScopedSessionMixin):
         """
 
         class query:
-            def __get__(s, instance, owner):
+            def __get__(
+                s, instance: Any, owner: Type[Any]
+            ) -> Optional[Query[Any]]:
                 try:
                     mapper = class_mapper(owner)
-                    if mapper:
-                        if query_cls:
-                            # custom query class
-                            return query_cls(mapper, session=self.registry())
-                        else:
-                            # session's configured query class
-                            return self.registry().query(mapper)
+                    assert mapper is not None
+                    if query_cls:
+                        # custom query class
+                        return query_cls(mapper, session=self.registry())
+                    else:
+                        # session's configured query class
+                        return self.registry().query(mapper)
                 except orm_exc.UnmappedClassError:
                     return None
 
@@ -231,7 +290,7 @@ class scoped_session(ScopedSessionMixin):
     # code within this block is **programmatically,
     # statically generated** by tools/generate_proxy_methods.py
 
-    def __contains__(self, instance):
+    def __contains__(self, instance: object) -> bool:
         r"""Return True if the instance is associated with this session.
 
         .. container:: class_bases
@@ -247,7 +306,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.__contains__(instance)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[object]:
         r"""Iterate over all pending or persistent instances within this
         Session.
 
@@ -261,7 +320,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.__iter__()
 
-    def add(self, instance: Any, _warn: bool = True) -> None:
+    def add(self, instance: object, _warn: bool = True) -> None:
         r"""Place an object in the ``Session``.
 
         .. container:: class_bases
@@ -280,7 +339,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.add(instance, _warn=_warn)
 
-    def add_all(self, instances):
+    def add_all(self, instances: Iterable[object]) -> None:
         r"""Add the given collection of instances to this ``Session``.
 
         .. container:: class_bases
@@ -292,7 +351,9 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.add_all(instances)
 
-    def begin(self, nested=False, _subtrans=False):
+    def begin(
+        self, nested: bool = False, _subtrans: bool = False
+    ) -> SessionTransaction:
         r"""Begin a transaction, or nested transaction,
         on this :class:`.Session`, if one is not already begun.
 
@@ -335,7 +396,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.begin(nested=nested, _subtrans=_subtrans)
 
-    def begin_nested(self):
+    def begin_nested(self) -> SessionTransaction:
         r"""Begin a "nested" transaction on this Session, e.g. SAVEPOINT.
 
         .. container:: class_bases
@@ -367,7 +428,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.begin_nested()
 
-    def close(self):
+    def close(self) -> None:
         r"""Close out the transactional resources and ORM objects used by this
         :class:`_orm.Session`.
 
@@ -434,7 +495,7 @@ class scoped_session(ScopedSessionMixin):
     def connection(
         self,
         bind_arguments: Optional[Dict[str, Any]] = None,
-        execution_options: Optional["_ExecuteOptions"] = None,
+        execution_options: Optional[_ExecuteOptions] = None,
     ) -> "Connection":
         r"""Return a :class:`_engine.Connection` object corresponding to this
         :class:`.Session` object's transactional state.
@@ -476,7 +537,7 @@ class scoped_session(ScopedSessionMixin):
             bind_arguments=bind_arguments, execution_options=execution_options
         )
 
-    def delete(self, instance):
+    def delete(self, instance: object) -> None:
         r"""Mark an instance as deleted.
 
         .. container:: class_bases
@@ -493,13 +554,13 @@ class scoped_session(ScopedSessionMixin):
 
     def execute(
         self,
-        statement: "Executable",
-        params: Optional["_ExecuteParams"] = None,
-        execution_options: "_ExecuteOptions" = util.EMPTY_DICT,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-    ):
+    ) -> Result:
         r"""Execute a SQL expression construct.
 
         .. container:: class_bases
@@ -567,7 +628,9 @@ class scoped_session(ScopedSessionMixin):
             _add_event=_add_event,
         )
 
-    def expire(self, instance, attribute_names=None):
+    def expire(
+        self, instance: object, attribute_names: Optional[Iterable[str]] = None
+    ) -> None:
         r"""Expire the attributes on an instance.
 
         .. container:: class_bases
@@ -613,7 +676,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.expire(instance, attribute_names=attribute_names)
 
-    def expire_all(self):
+    def expire_all(self) -> None:
         r"""Expires all persistent instances within this Session.
 
         .. container:: class_bases
@@ -654,7 +717,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.expire_all()
 
-    def expunge(self, instance):
+    def expunge(self, instance: object) -> None:
         r"""Remove the `instance` from this ``Session``.
 
         .. container:: class_bases
@@ -670,7 +733,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.expunge(instance)
 
-    def expunge_all(self):
+    def expunge_all(self) -> None:
         r"""Remove all object instances from this ``Session``.
 
         .. container:: class_bases
@@ -686,7 +749,7 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.expunge_all()
 
-    def flush(self, objects=None):
+    def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
         r"""Flush all the object changes to the database.
 
         .. container:: class_bases
@@ -719,14 +782,15 @@ class scoped_session(ScopedSessionMixin):
 
     def get(
         self,
-        entity,
-        ident,
-        options=None,
-        populate_existing=False,
-        with_for_update=None,
-        identity_token=None,
-        execution_options=None,
-    ):
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: Optional[ForUpdateArg] = None,
+        identity_token: Optional[Any] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> Optional[_O]:
         r"""Return an instance based on the given primary key identifier,
         or ``None`` if not found.
 
@@ -841,12 +905,12 @@ class scoped_session(ScopedSessionMixin):
 
     def get_bind(
         self,
-        mapper=None,
-        clause=None,
-        bind=None,
-        _sa_skip_events=None,
-        _sa_skip_for_implicit_returning=False,
-    ):
+        mapper: Optional[_EntityBindKey[_O]] = None,
+        clause: Optional[ClauseElement] = None,
+        bind: Optional[_SessionBind] = None,
+        _sa_skip_events: Optional[bool] = None,
+        _sa_skip_for_implicit_returning: bool = False,
+    ) -> Union[Engine, Connection]:
         r"""Return a "bind" to which this :class:`.Session` is bound.
 
         .. container:: class_bases
@@ -933,7 +997,9 @@ class scoped_session(ScopedSessionMixin):
             _sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning,
         )
 
-    def is_modified(self, instance, include_collections=True):
+    def is_modified(
+        self, instance: object, include_collections: bool = True
+    ) -> bool:
         r"""Return ``True`` if the given instance has locally
         modified attributes.
 
@@ -997,11 +1063,11 @@ class scoped_session(ScopedSessionMixin):
 
     def bulk_save_objects(
         self,
-        objects,
-        return_defaults=False,
-        update_changed_only=True,
-        preserve_order=True,
-    ):
+        objects: Iterable[object],
+        return_defaults: bool = False,
+        update_changed_only: bool = True,
+        preserve_order: bool = True,
+    ) -> None:
         r"""Perform a bulk save of the given list of objects.
 
         .. container:: class_bases
@@ -1109,8 +1175,12 @@ class scoped_session(ScopedSessionMixin):
         )
 
     def bulk_insert_mappings(
-        self, mapper, mappings, return_defaults=False, render_nulls=False
-    ):
+        self,
+        mapper: Mapper[Any],
+        mappings: Iterable[Dict[str, Any]],
+        return_defaults: bool = False,
+        render_nulls: bool = False,
+    ) -> None:
         r"""Perform a bulk insert of the given list of mapping dictionaries.
 
         .. container:: class_bases
@@ -1221,7 +1291,9 @@ class scoped_session(ScopedSessionMixin):
             render_nulls=render_nulls,
         )
 
-    def bulk_update_mappings(self, mapper, mappings):
+    def bulk_update_mappings(
+        self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]]
+    ) -> None:
         r"""Perform a bulk update of the given list of mapping dictionaries.
 
         .. container:: class_bases
@@ -1287,7 +1359,13 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.bulk_update_mappings(mapper, mappings)
 
-    def merge(self, instance, load=True, options=None):
+    def merge(
+        self,
+        instance: _O,
+        *,
+        load: bool = True,
+        options: Optional[Sequence[ORMOption]] = None,
+    ) -> _O:
         r"""Copy the state of a given instance into a corresponding instance
         within this :class:`.Session`.
 
@@ -1355,7 +1433,9 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.merge(instance, load=load, options=options)
 
-    def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query:
+    def query(
+        self, *entities: _ColumnsClauseArgument, **kwargs: Any
+    ) -> Query[Any]:
         r"""Return a new :class:`_query.Query` object corresponding to this
         :class:`_orm.Session`.
 
@@ -1381,7 +1461,12 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.query(*entities, **kwargs)
 
-    def refresh(self, instance, attribute_names=None, with_for_update=None):
+    def refresh(
+        self,
+        instance: object,
+        attribute_names: Optional[Iterable[str]] = None,
+        with_for_update: Optional[ForUpdateArg] = None,
+    ) -> None:
         r"""Expire and refresh attributes on the given instance.
 
         .. container:: class_bases
@@ -1452,7 +1537,7 @@ class scoped_session(ScopedSessionMixin):
             with_for_update=with_for_update,
         )
 
-    def rollback(self):
+    def rollback(self) -> None:
         r"""Rollback the current transaction in progress.
 
         .. container:: class_bases
@@ -1479,12 +1564,12 @@ class scoped_session(ScopedSessionMixin):
 
     def scalar(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[Dict[str, Any]] = None,
+        **kw: Any,
+    ) -> Any:
         r"""Execute a statement and return a scalar result.
 
         .. container:: class_bases
@@ -1509,12 +1594,12 @@ class scoped_session(ScopedSessionMixin):
 
     def scalars(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[Dict[str, Any]] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
         r"""Execute a statement and return the results as scalars.
 
         .. container:: class_bases
@@ -1615,7 +1700,7 @@ class scoped_session(ScopedSessionMixin):
         return self._proxied.new
 
     @property
-    def identity_map(self) -> identity.IdentityMap:
+    def identity_map(self) -> IdentityMap:
         r"""Proxy for the :attr:`_orm.Session.identity_map` attribute
         on behalf of the :class:`_orm.scoping.scoped_session` class.
 
@@ -1624,7 +1709,7 @@ class scoped_session(ScopedSessionMixin):
         return self._proxied.identity_map
 
     @identity_map.setter
-    def identity_map(self, attr: identity.IdentityMap) -> None:
+    def identity_map(self, attr: IdentityMap) -> None:
         self._proxied.identity_map = attr
 
     @property
@@ -1726,19 +1811,6 @@ class scoped_session(ScopedSessionMixin):
 
         return self._proxied.info
 
-    @property
-    def autocommit(self) -> Any:
-        r"""Proxy for the :attr:`_orm.Session.autocommit` attribute
-        on behalf of the :class:`_orm.scoping.scoped_session` class.
-
-        """  # noqa: E501
-
-        return self._proxied.autocommit
-
-    @autocommit.setter
-    def autocommit(self, attr: Any) -> None:
-        self._proxied.autocommit = attr
-
     @classmethod
     def close_all(cls) -> None:
         r"""Close *all* sessions in memory.
@@ -1755,7 +1827,7 @@ class scoped_session(ScopedSessionMixin):
         return Session.close_all()
 
     @classmethod
-    def object_session(cls, instance: Any) -> "Session":
+    def object_session(cls, instance: object) -> Optional[Session]:
         r"""Return the :class:`.Session` to which an object belongs.
 
         .. container:: class_bases
@@ -1773,13 +1845,13 @@ class scoped_session(ScopedSessionMixin):
     @classmethod
     def identity_key(
         cls,
-        class_=None,
-        ident=None,
+        class_: Optional[Type[Any]] = None,
+        ident: Union[Any, Tuple[Any, ...]] = None,
         *,
-        instance=None,
-        row=None,
-        identity_token=None,
-    ) -> _IdentityKeyType:
+        instance: Optional[Any] = None,
+        row: Optional[Row] = None,
+        identity_token: Optional[Any] = None,
+    ) -> _IdentityKeyType[Any]:
         r"""Return an identity key.
 
         .. container:: class_bases
index 77a97936b2be1631de614df9a189ebe50fd37778..55ce73cf54ca190c58c75e2eece549fa9ba2fd43 100644 (file)
@@ -13,12 +13,20 @@ import itertools
 import sys
 import typing
 from typing import Any
+from typing import Callable
+from typing import cast
 from typing import Dict
+from typing import Iterable
+from typing import Iterator
 from typing import List
+from typing import NoReturn
 from typing import Optional
-from typing import overload
+from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 import weakref
 
@@ -30,14 +38,20 @@ from . import loading
 from . import persistence
 from . import query
 from . import state as statelib
+from ._typing import is_composite_class
+from ._typing import is_user_defined_option
 from .base import _class_to_mapper
-from .base import _IdentityKeyType
 from .base import _none_set
 from .base import _state_mapper
 from .base import instance_str
+from .base import LoaderCallableStatus
 from .base import object_mapper
 from .base import object_state
+from .base import PassiveFlag
 from .base import state_str
+from .context import FromStatement
+from .context import ORMCompileState
+from .identity import IdentityMap
 from .query import Query
 from .state import InstanceState
 from .state_changes import _StateChange
@@ -51,22 +65,41 @@ from .. import util
 from ..engine import Connection
 from ..engine import Engine
 from ..engine.util import TransactionalContext
+from ..event import dispatcher
+from ..event import EventTarget
 from ..inspection import inspect
 from ..sql import coercions
 from ..sql import dml
 from ..sql import roles
+from ..sql import Select
 from ..sql import visitors
 from ..sql.base import CompileState
+from ..sql.selectable import ForUpdateArg
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import IdentitySet
 from ..util.typing import Literal
+from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from ._typing import _IdentityKeyType
+    from ._typing import _InstanceDict
+    from .interfaces import ORMOption
+    from .interfaces import UserDefinedOption
     from .mapper import Mapper
+    from .path_registry import PathRegistry
+    from ..engine import Result
     from ..engine import Row
+    from ..engine.base import Transaction
+    from ..engine.base import TwoPhaseTransaction
+    from ..engine.interfaces import _CoreAnyExecuteParams
+    from ..engine.interfaces import _CoreSingleExecuteParams
+    from ..engine.interfaces import _ExecuteOptions
+    from ..engine.interfaces import _ExecuteOptionsParameter
+    from ..engine.result import ScalarResult
+    from ..event import _InstanceLevelDispatch
     from ..sql._typing import _ColumnsClauseArgument
-    from ..sql._typing import _ExecuteOptions
-    from ..sql._typing import _ExecuteParams
     from ..sql.base import Executable
+    from ..sql.elements import ClauseElement
     from ..sql.schema import Table
 
 __all__ = [
@@ -80,14 +113,45 @@ __all__ = [
     "object_session",
 ]
 
-_sessions = weakref.WeakValueDictionary()
+_sessions: weakref.WeakValueDictionary[
+    int, Session
+] = weakref.WeakValueDictionary()
 """Weak-referencing dictionary of :class:`.Session` objects.
 """
 
+_O = TypeVar("_O", bound=object)
 statelib._sessions = _sessions
 
+_PKIdentityArgument = Union[Any, Tuple[Any, ...]]
 
-def _state_session(state):
+_EntityBindKey = Union[Type[_O], "Mapper[_O]"]
+_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"]
+_SessionBind = Union["Engine", "Connection"]
+
+
+class _ConnectionCallableProto(Protocol):
+    """a callable that returns a :class:`.Connection` given an instance.
+
+    This callable, when present on a :class:`.Session`, is called only from the
+    ORM's persistence mechanism (i.e. the unit of work flush process) to allow
+    for connection-per-instance schemes (i.e. horizontal sharding) to be used
+    as persistence time.
+
+    This callable is not present on a plain :class:`.Session`, however
+    is established when using the horizontal sharding extension.
+
+    """
+
+    def __call__(
+        self,
+        mapper: Optional[Mapper[Any]] = None,
+        instance: Optional[object] = None,
+        **kw: Any,
+    ) -> Connection:
+        ...
+
+
+def _state_session(state: InstanceState[Any]) -> Optional[Session]:
     """Given an :class:`.InstanceState`, return the :class:`.Session`
     associated, if any.
     """
@@ -109,40 +173,17 @@ class _SessionClassMethods:
 
         close_all_sessions()
 
-    @classmethod
-    @overload
-    def identity_key(
-        cls,
-        class_: type,
-        ident: Tuple[Any, ...],
-        *,
-        identity_token: Optional[str],
-    ) -> _IdentityKeyType:
-        ...
-
-    @classmethod
-    @overload
-    def identity_key(cls, *, instance: Any) -> _IdentityKeyType:
-        ...
-
-    @classmethod
-    @overload
-    def identity_key(
-        cls, class_: type, *, row: "Row", identity_token: Optional[str]
-    ) -> _IdentityKeyType:
-        ...
-
     @classmethod
     @util.preload_module("sqlalchemy.orm.util")
     def identity_key(
         cls,
-        class_=None,
-        ident=None,
+        class_: Optional[Type[Any]] = None,
+        ident: Union[Any, Tuple[Any, ...]] = None,
         *,
-        instance=None,
-        row=None,
-        identity_token=None,
-    ) -> _IdentityKeyType:
+        instance: Optional[Any] = None,
+        row: Optional[Row] = None,
+        identity_token: Optional[Any] = None,
+    ) -> _IdentityKeyType[Any]:
         """Return an identity key.
 
         This is an alias of :func:`.util.identity_key`.
@@ -157,7 +198,7 @@ class _SessionClassMethods:
         )
 
     @classmethod
-    def object_session(cls, instance: Any) -> "Session":
+    def object_session(cls, instance: object) -> Optional[Session]:
         """Return the :class:`.Session` to which an object belongs.
 
         This is an alias of :func:`.object_session`.
@@ -205,26 +246,26 @@ class ORMExecuteState(util.MemoizedSlots):
         "_update_execution_options",
     )
 
-    session: "Session"
-    statement: "Executable"
-    parameters: "_ExecuteParams"
-    execution_options: "_ExecuteOptions"
-    local_execution_options: "_ExecuteOptions"
+    session: Session
+    statement: Executable
+    parameters: Optional[_CoreAnyExecuteParams]
+    execution_options: _ExecuteOptions
+    local_execution_options: _ExecuteOptions
     bind_arguments: Dict[str, Any]
-    _compile_state_cls: Type[context.ORMCompileState]
-    _starting_event_idx: Optional[int]
+    _compile_state_cls: Optional[Type[ORMCompileState]]
+    _starting_event_idx: int
     _events_todo: List[Any]
-    _update_execution_options: Optional["_ExecuteOptions"]
+    _update_execution_options: Optional[_ExecuteOptions]
 
     def __init__(
         self,
-        session: "Session",
-        statement: "Executable",
-        parameters: "_ExecuteParams",
-        execution_options: "_ExecuteOptions",
+        session: Session,
+        statement: Executable,
+        parameters: Optional[_CoreAnyExecuteParams],
+        execution_options: _ExecuteOptions,
         bind_arguments: Dict[str, Any],
-        compile_state_cls: Type[context.ORMCompileState],
-        events_todo: List[Any],
+        compile_state_cls: Optional[Type[ORMCompileState]],
+        events_todo: List[_InstanceLevelDispatch[Session]],
     ):
         self.session = session
         self.statement = statement
@@ -237,16 +278,16 @@ class ORMExecuteState(util.MemoizedSlots):
         self._compile_state_cls = compile_state_cls
         self._events_todo = list(events_todo)
 
-    def _remaining_events(self):
+    def _remaining_events(self) -> List[_InstanceLevelDispatch[Session]]:
         return self._events_todo[self._starting_event_idx + 1 :]
 
     def invoke_statement(
         self,
-        statement=None,
-        params=None,
-        execution_options=None,
-        bind_arguments=None,
-    ):
+        statement: Optional[Executable] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+        bind_arguments: Optional[Dict[str, Any]] = None,
+    ) -> Result:
         """Execute the statement represented by this
         :class:`.ORMExecuteState`, without re-invoking events that have
         already proceeded.
@@ -270,9 +311,12 @@ class ORMExecuteState(util.MemoizedSlots):
         :param statement: optional statement to be invoked, in place of the
          statement currently represented by :attr:`.ORMExecuteState.statement`.
 
-        :param params: optional dictionary of parameters which will be merged
-         into the existing :attr:`.ORMExecuteState.parameters` of this
-         :class:`.ORMExecuteState`.
+        :param params: optional dictionary of parameters or list of parameters
+         which will be merged into the existing
+         :attr:`.ORMExecuteState.parameters` of this :class:`.ORMExecuteState`.
+
+         .. versionchanged:: 2.0 a list of parameter dictionaries is accepted
+            for executemany executions.
 
         :param execution_options: optional dictionary of execution options
          will be merged into the existing
@@ -302,9 +346,32 @@ class ORMExecuteState(util.MemoizedSlots):
             _bind_arguments.update(bind_arguments)
         _bind_arguments["_sa_skip_events"] = True
 
+        _params: Optional[_CoreAnyExecuteParams]
         if params:
-            _params = dict(self.parameters)
-            _params.update(params)
+            if self.is_executemany:
+                _params = []
+                exec_many_parameters = cast(
+                    "List[Dict[str, Any]]", self.parameters
+                )
+                for _existing_params, _new_params in itertools.zip_longest(
+                    exec_many_parameters,
+                    cast("List[Dict[str, Any]]", params),
+                ):
+                    if _existing_params is None or _new_params is None:
+                        raise sa_exc.InvalidRequestError(
+                            f"Can't apply executemany parameters to "
+                            f"statement; number of parameter sets passed to "
+                            f"Session.execute() ({len(exec_many_parameters)}) "
+                            f"does not match number of parameter sets given "
+                            f"to ORMExecuteState.invoke_statement() "
+                            f"({len(params)})"
+                        )
+                    _existing_params = dict(_existing_params)
+                    _existing_params.update(_new_params)
+                    _params.append(_existing_params)
+            else:
+                _params = dict(cast("Dict[str, Any]", self.parameters))
+                _params.update(cast("Dict[str, Any]", params))
         else:
             _params = self.parameters
 
@@ -321,7 +388,7 @@ class ORMExecuteState(util.MemoizedSlots):
         )
 
     @property
-    def bind_mapper(self):
+    def bind_mapper(self) -> Optional[Mapper[Any]]:
         """Return the :class:`_orm.Mapper` that is the primary "bind" mapper.
 
         For an :class:`_orm.ORMExecuteState` object invoking an ORM
@@ -349,7 +416,7 @@ class ORMExecuteState(util.MemoizedSlots):
         return self.bind_arguments.get("mapper", None)
 
     @property
-    def all_mappers(self):
+    def all_mappers(self) -> Sequence[Mapper[Any]]:
         """Return a sequence of all :class:`_orm.Mapper` objects that are
         involved at the top level of this statement.
 
@@ -369,7 +436,7 @@ class ORMExecuteState(util.MemoizedSlots):
         """
         if not self.is_orm_statement:
             return []
-        elif self.is_select:
+        elif isinstance(self.statement, (Select, FromStatement)):
             result = []
             seen = set()
             for d in self.statement.column_descriptions:
@@ -380,13 +447,13 @@ class ORMExecuteState(util.MemoizedSlots):
                         seen.add(insp.mapper)
                         result.append(insp.mapper)
             return result
-        elif self.is_update or self.is_delete:
+        elif self.statement.is_dml and self.bind_mapper:
             return [self.bind_mapper]
         else:
             return []
 
     @property
-    def is_orm_statement(self):
+    def is_orm_statement(self) -> bool:
         """return True if the operation is an ORM statement.
 
         This indicates that the select(), update(), or delete() being
@@ -399,44 +466,64 @@ class ORMExecuteState(util.MemoizedSlots):
         return self._compile_state_cls is not None
 
     @property
-    def is_select(self):
+    def is_executemany(self) -> bool:
+        """return True if the parameters are a multi-element list of
+        dictionaries with more than one dictionary.
+
+        .. versionadded:: 2.0
+
+        """
+        return isinstance(self.parameters, list)
+
+    @property
+    def is_select(self) -> bool:
         """return True if this is a SELECT operation."""
         return self.statement.is_select
 
     @property
-    def is_insert(self):
+    def is_insert(self) -> bool:
         """return True if this is an INSERT operation."""
         return self.statement.is_dml and self.statement.is_insert
 
     @property
-    def is_update(self):
+    def is_update(self) -> bool:
         """return True if this is an UPDATE operation."""
         return self.statement.is_dml and self.statement.is_update
 
     @property
-    def is_delete(self):
+    def is_delete(self) -> bool:
         """return True if this is a DELETE operation."""
         return self.statement.is_dml and self.statement.is_delete
 
     @property
-    def _is_crud(self):
+    def _is_crud(self) -> bool:
         return isinstance(self.statement, (dml.Update, dml.Delete))
 
-    def update_execution_options(self, **opts):
+    def update_execution_options(self, **opts: _ExecuteOptions) -> None:
+        """Update the local execution options with new values."""
         # TODO: no coverage
         self.local_execution_options = self.local_execution_options.union(opts)
 
-    def _orm_compile_options(self):
+    def _orm_compile_options(
+        self,
+    ) -> Optional[
+        Union[
+            context.ORMCompileState.default_compile_options,
+            Type[context.ORMCompileState.default_compile_options],
+        ]
+    ]:
         if not self.is_select:
             return None
         opts = self.statement._compile_options
-        if opts.isinstance(context.ORMCompileState.default_compile_options):
-            return opts
+        if opts is not None and opts.isinstance(
+            context.ORMCompileState.default_compile_options
+        ):
+            return opts  # type: ignore
         else:
             return None
 
     @property
-    def lazy_loaded_from(self):
+    def lazy_loaded_from(self) -> Optional[InstanceState[Any]]:
         """An :class:`.InstanceState` that is using this statement execution
         for a lazy load operation.
 
@@ -451,7 +538,7 @@ class ORMExecuteState(util.MemoizedSlots):
         return self.load_options._lazy_loaded_from
 
     @property
-    def loader_strategy_path(self):
+    def loader_strategy_path(self) -> Optional[PathRegistry]:
         """Return the :class:`.PathRegistry` for the current load path.
 
         This object represents the "path" in a query along relationships
@@ -465,7 +552,7 @@ class ORMExecuteState(util.MemoizedSlots):
             return None
 
     @property
-    def is_column_load(self):
+    def is_column_load(self) -> bool:
         """Return True if the operation is refreshing column-oriented
         attributes on an existing ORM object.
 
@@ -492,7 +579,7 @@ class ORMExecuteState(util.MemoizedSlots):
         return opts is not None and opts._for_refresh_state
 
     @property
-    def is_relationship_load(self):
+    def is_relationship_load(self) -> bool:
         """Return True if this load is loading objects on behalf of a
         relationship.
 
@@ -518,7 +605,12 @@ class ORMExecuteState(util.MemoizedSlots):
         return path is not None and not path.is_root
 
     @property
-    def load_options(self):
+    def load_options(
+        self,
+    ) -> Union[
+        context.QueryContext.default_load_options,
+        Type[context.QueryContext.default_load_options],
+    ]:
         """Return the load_options that will be used for this execution."""
 
         if not self.is_select:
@@ -531,7 +623,12 @@ class ORMExecuteState(util.MemoizedSlots):
         )
 
     @property
-    def update_delete_options(self):
+    def update_delete_options(
+        self,
+    ) -> Union[
+        persistence.BulkUDCompileState.default_update_options,
+        Type[persistence.BulkUDCompileState.default_update_options],
+    ]:
         """Return the update_delete_options that will be used for this
         execution."""
 
@@ -546,7 +643,7 @@ class ORMExecuteState(util.MemoizedSlots):
         )
 
     @property
-    def user_defined_options(self):
+    def user_defined_options(self) -> Sequence[UserDefinedOption]:
         """The sequence of :class:`.UserDefinedOptions` that have been
         associated with the statement being invoked.
 
@@ -554,7 +651,7 @@ class ORMExecuteState(util.MemoizedSlots):
         return [
             opt
             for opt in self.statement._with_options
-            if not opt._is_compile_state and not opt._is_legacy_option
+            if is_user_defined_option(opt)
         ]
 
 
@@ -597,14 +694,29 @@ class SessionTransaction(_StateChange, TransactionalContext):
 
     """
 
-    _rollback_exception = None
+    _rollback_exception: Optional[BaseException] = None
+
+    _connections: Dict[
+        Union[Engine, Connection], Tuple[Connection, Transaction, bool, bool]
+    ]
+    session: Session
+    _parent: Optional[SessionTransaction]
+
+    _state: SessionTransactionState
+
+    _new: weakref.WeakKeyDictionary[InstanceState[Any], object]
+    _deleted: weakref.WeakKeyDictionary[InstanceState[Any], object]
+    _dirty: weakref.WeakKeyDictionary[InstanceState[Any], object]
+    _key_switches: weakref.WeakKeyDictionary[
+        InstanceState[Any], Tuple[Any, Any]
+    ]
 
     def __init__(
         self,
-        session,
-        parent=None,
-        nested=False,
-        autobegin=False,
+        session: Session,
+        parent: Optional[SessionTransaction] = None,
+        nested: bool = False,
+        autobegin: bool = False,
     ):
         TransactionalContext._trans_ctx_check(session)
 
@@ -629,7 +741,9 @@ class SessionTransaction(_StateChange, TransactionalContext):
 
         self.session.dispatch.after_transaction_create(self.session, self)
 
-    def _raise_for_prerequisite_state(self, operation_name, state):
+    def _raise_for_prerequisite_state(
+        self, operation_name: str, state: SessionTransactionState
+    ) -> NoReturn:
         if state is SessionTransactionState.DEACTIVE:
             if self._rollback_exception:
                 raise sa_exc.PendingRollbackError(
@@ -655,7 +769,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
             )
 
     @property
-    def parent(self):
+    def parent(self) -> Optional[SessionTransaction]:
         """The parent :class:`.SessionTransaction` of this
         :class:`.SessionTransaction`.
 
@@ -673,7 +787,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
         """
         return self._parent
 
-    nested = False
+    nested: bool = False
     """Indicates if this is a nested, or SAVEPOINT, transaction.
 
     When :attr:`.SessionTransaction.nested` is True, it is expected
@@ -682,33 +796,40 @@ class SessionTransaction(_StateChange, TransactionalContext):
     """
 
     @property
-    def is_active(self):
+    def is_active(self) -> bool:
         return (
             self.session is not None
             and self._state is SessionTransactionState.ACTIVE
         )
 
     @property
-    def _is_transaction_boundary(self):
+    def _is_transaction_boundary(self) -> bool:
         return self.nested or not self._parent
 
     @_StateChange.declare_states(
         (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
     )
-    def connection(self, bindkey, execution_options=None, **kwargs):
+    def connection(
+        self,
+        bindkey: Optional[Mapper[Any]],
+        execution_options: Optional[_ExecuteOptions] = None,
+        **kwargs: Any,
+    ) -> Connection:
         bind = self.session.get_bind(bindkey, **kwargs)
         return self._connection_for_bind(bind, execution_options)
 
     @_StateChange.declare_states(
         (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
     )
-    def _begin(self, nested=False):
+    def _begin(self, nested: bool = False) -> SessionTransaction:
         return SessionTransaction(self.session, self, nested=nested)
 
-    def _iterate_self_and_parents(self, upto=None):
+    def _iterate_self_and_parents(
+        self, upto: Optional[SessionTransaction] = None
+    ) -> Iterable[SessionTransaction]:
 
         current = self
-        result = ()
+        result: Tuple[SessionTransaction, ...] = ()
         while current:
             result += (current,)
             if current._parent is upto:
@@ -723,12 +844,14 @@ class SessionTransaction(_StateChange, TransactionalContext):
 
         return result
 
-    def _take_snapshot(self, autobegin=False):
+    def _take_snapshot(self, autobegin: bool = False) -> None:
         if not self._is_transaction_boundary:
-            self._new = self._parent._new
-            self._deleted = self._parent._deleted
-            self._dirty = self._parent._dirty
-            self._key_switches = self._parent._key_switches
+            parent = self._parent
+            assert parent is not None
+            self._new = parent._new
+            self._deleted = parent._deleted
+            self._dirty = parent._dirty
+            self._key_switches = parent._key_switches
             return
 
         if not autobegin and not self.session._flushing:
@@ -739,7 +862,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
         self._dirty = weakref.WeakKeyDictionary()
         self._key_switches = weakref.WeakKeyDictionary()
 
-    def _restore_snapshot(self, dirty_only=False):
+    def _restore_snapshot(self, dirty_only: bool = False) -> None:
         """Restore the restoration state taken before a transaction began.
 
         Corresponds to a rollback.
@@ -771,7 +894,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
             if not dirty_only or s.modified or s in self._dirty:
                 s._expire(s.dict, self.session.identity_map._modified)
 
-    def _remove_snapshot(self):
+    def _remove_snapshot(self) -> None:
         """Remove the restoration state taken before a transaction began.
 
         Corresponds to a commit.
@@ -788,15 +911,21 @@ class SessionTransaction(_StateChange, TransactionalContext):
             )
             self._deleted.clear()
         elif self.nested:
-            self._parent._new.update(self._new)
-            self._parent._dirty.update(self._dirty)
-            self._parent._deleted.update(self._deleted)
-            self._parent._key_switches.update(self._key_switches)
+            parent = self._parent
+            assert parent is not None
+            parent._new.update(self._new)
+            parent._dirty.update(self._dirty)
+            parent._deleted.update(self._deleted)
+            parent._key_switches.update(self._key_switches)
 
     @_StateChange.declare_states(
         (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
     )
-    def _connection_for_bind(self, bind, execution_options):
+    def _connection_for_bind(
+        self,
+        bind: _SessionBind,
+        execution_options: Optional[_ExecuteOptions],
+    ) -> Connection:
 
         if bind in self._connections:
             if execution_options:
@@ -829,6 +958,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
             if execution_options:
                 conn = conn.execution_options(**execution_options)
 
+            transaction: Transaction
             if self.session.twophase and self._parent is None:
                 transaction = conn.begin_twophase()
             elif self.nested:
@@ -837,9 +967,9 @@ class SessionTransaction(_StateChange, TransactionalContext):
                 # if given a future connection already in a transaction, don't
                 # commit that transaction unless it is a savepoint
                 if conn.in_nested_transaction():
-                    transaction = conn.get_nested_transaction()
+                    transaction = conn._get_required_nested_transaction()
                 else:
-                    transaction = conn.get_transaction()
+                    transaction = conn._get_required_transaction()
                     should_commit = False
             else:
                 transaction = conn.begin()
@@ -861,7 +991,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
             self.session.dispatch.after_begin(self.session, self, conn)
             return conn
 
-    def prepare(self):
+    def prepare(self) -> None:
         if self._parent is not None or not self.session.twophase:
             raise sa_exc.InvalidRequestError(
                 "'twophase' mode not enabled, or not root transaction; "
@@ -872,12 +1002,13 @@ class SessionTransaction(_StateChange, TransactionalContext):
     @_StateChange.declare_states(
         (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED
     )
-    def _prepare_impl(self):
+    def _prepare_impl(self) -> None:
 
         if self._parent is None or self.nested:
             self.session.dispatch.before_commit(self.session)
 
         stx = self.session._transaction
+        assert stx is not None
         if stx is not self:
             for subtransaction in stx._iterate_self_and_parents(upto=self):
                 subtransaction.commit()
@@ -897,7 +1028,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
         if self._parent is None and self.session.twophase:
             try:
                 for t in set(self._connections.values()):
-                    t[1].prepare()
+                    cast("TwoPhaseTransaction", t[1]).prepare()
             except:
                 with util.safe_reraise():
                     self.rollback()
@@ -929,9 +1060,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
             self.close()
 
         if _to_root and self._parent:
-            return self._parent.commit(_to_root=True)
-
-        return self._parent
+            self._parent.commit(_to_root=True)
 
     @_StateChange.declare_states(
         (
@@ -941,9 +1070,12 @@ class SessionTransaction(_StateChange, TransactionalContext):
         ),
         SessionTransactionState.CLOSED,
     )
-    def rollback(self, _capture_exception=False, _to_root=False):
+    def rollback(
+        self, _capture_exception: bool = False, _to_root: bool = False
+    ) -> None:
 
         stx = self.session._transaction
+        assert stx is not None
         if stx is not self:
             for subtransaction in stx._iterate_self_and_parents(upto=self):
                 subtransaction.close()
@@ -993,19 +1125,18 @@ class SessionTransaction(_StateChange, TransactionalContext):
         if self._parent and _capture_exception:
             self._parent._rollback_exception = sys.exc_info()[1]
 
-        if rollback_err:
+        if rollback_err and rollback_err[1]:
             raise rollback_err[1].with_traceback(rollback_err[2])
 
         sess.dispatch.after_soft_rollback(sess, self)
 
         if _to_root and self._parent:
-            return self._parent.rollback(_to_root=True)
-        return self._parent
+            self._parent.rollback(_to_root=True)
 
     @_StateChange.declare_states(
         _StateChangeStates.ANY, SessionTransactionState.CLOSED
     )
-    def close(self, invalidate=False):
+    def close(self, invalidate: bool = False) -> None:
         if self.nested:
             self.session._nested_transaction = (
                 self._previous_nested_transaction
@@ -1027,25 +1158,30 @@ class SessionTransaction(_StateChange, TransactionalContext):
         self._state = SessionTransactionState.CLOSED
         sess = self.session
 
-        self.session = None
-        self._connections = None
+        # TODO: these two None sets were historically after the
+        # event hook below, and in 2.0 I changed it this way for some reason,
+        # and I remember there being a reason, but not what it was.
+        # Why do we need to get rid of them at all?  test_memusage::CycleTest
+        # passes with these commented out.
+        # self.session = None  # type: ignore
+        # self._connections = None  # type: ignore
 
         sess.dispatch.after_transaction_end(sess, self)
 
-    def _get_subject(self):
+    def _get_subject(self) -> Session:
         return self.session
 
-    def _transaction_is_active(self):
+    def _transaction_is_active(self) -> bool:
         return self._state is SessionTransactionState.ACTIVE
 
-    def _transaction_is_closed(self):
+    def _transaction_is_closed(self) -> bool:
         return self._state is SessionTransactionState.CLOSED
 
-    def _rollback_can_be_called(self):
+    def _rollback_can_be_called(self) -> bool:
         return self._state not in (COMMITTED, CLOSED)
 
 
-class Session(_SessionClassMethods):
+class Session(_SessionClassMethods, EventTarget):
     """Manages persistence operations for ORM-mapped objects.
 
     The Session's usage paradigm is described at :doc:`/orm/session`.
@@ -1055,15 +1191,27 @@ class Session(_SessionClassMethods):
 
     _is_asyncio = False
 
-    identity_map: identity.IdentityMap
-    _new: Dict["InstanceState", Any]
-    _deleted: Dict["InstanceState", Any]
+    dispatch: dispatcher[Session]
+
+    identity_map: IdentityMap
+    """A mapping of object identities to objects themselves.
+
+    Iterating through ``Session.identity_map.values()`` provides
+    access to the full set of persistent objects (i.e., those
+    that have row identity) currently in the session.
+
+    .. seealso::
+
+        :func:`.identity_key` - helper function to produce the keys used
+        in this dictionary.
+
+    """
+
+    _new: Dict[InstanceState[Any], Any]
+    _deleted: Dict[InstanceState[Any], Any]
     bind: Optional[Union[Engine, Connection]]
-    __binds: Dict[
-        Union[type, "Mapper", "Table"],
-        Union[engine.Engine, engine.Connection],
-    ]
-    _flusing: bool
+    __binds: Dict[_SessionBindKey, _SessionBind]
+    _flushing: bool
     _warn_on_events: bool
     _transaction: Optional[SessionTransaction]
     _nested_transaction: Optional[SessionTransaction]
@@ -1072,24 +1220,19 @@ class Session(_SessionClassMethods):
     expire_on_commit: bool
     enable_baked_queries: bool
     twophase: bool
-    _query_cls: Type[Query]
+    _query_cls: Type[Query[Any]]
 
     def __init__(
         self,
-        bind: Optional[Union[engine.Engine, engine.Connection]] = None,
+        bind: Optional[_SessionBind] = None,
         autoflush: bool = True,
         future: Literal[True] = True,
         expire_on_commit: bool = True,
         twophase: bool = False,
-        binds: Optional[
-            Dict[
-                Union[type, "Mapper", "Table"],
-                Union[engine.Engine, engine.Connection],
-            ]
-        ] = None,
+        binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None,
         enable_baked_queries: bool = True,
         info: Optional[Dict[Any, Any]] = None,
-        query_cls: Optional[Type[query.Query]] = None,
+        query_cls: Optional[Type[Query[Any]]] = None,
         autocommit: Literal[False] = False,
     ):
         r"""Construct a new Session.
@@ -1249,23 +1392,23 @@ class Session(_SessionClassMethods):
         _sessions[self.hash_key] = self
 
     # used by sqlalchemy.engine.util.TransactionalContext
-    _trans_context_manager = None
+    _trans_context_manager: Optional[TransactionalContext] = None
 
-    connection_callable = None
+    connection_callable: Optional[_ConnectionCallableProto] = None
 
-    def __enter__(self):
+    def __enter__(self) -> Session:
         return self
 
-    def __exit__(self, type_, value, traceback):
+    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
         self.close()
 
     @contextlib.contextmanager
-    def _maker_context_manager(self):
+    def _maker_context_manager(self) -> Iterator[Session]:
         with self:
             with self.begin():
                 yield self
 
-    def in_transaction(self):
+    def in_transaction(self) -> bool:
         """Return True if this :class:`_orm.Session` has begun a transaction.
 
         .. versionadded:: 1.4
@@ -1278,7 +1421,7 @@ class Session(_SessionClassMethods):
         """
         return self._transaction is not None
 
-    def in_nested_transaction(self):
+    def in_nested_transaction(self) -> bool:
         """Return True if this :class:`_orm.Session` has begun a nested
         transaction, e.g. SAVEPOINT.
 
@@ -1287,7 +1430,7 @@ class Session(_SessionClassMethods):
         """
         return self._nested_transaction is not None
 
-    def get_transaction(self):
+    def get_transaction(self) -> Optional[SessionTransaction]:
         """Return the current root transaction in progress, if any.
 
         .. versionadded:: 1.4
@@ -1298,7 +1441,7 @@ class Session(_SessionClassMethods):
             trans = trans._parent
         return trans
 
-    def get_nested_transaction(self):
+    def get_nested_transaction(self) -> Optional[SessionTransaction]:
         """Return the current nested transaction in progress, if any.
 
         .. versionadded:: 1.4
@@ -1308,7 +1451,7 @@ class Session(_SessionClassMethods):
         return self._nested_transaction
 
     @util.memoized_property
-    def info(self):
+    def info(self) -> Dict[Any, Any]:
         """A user-modifiable dictionary.
 
         The initial value of this dictionary can be populated using the
@@ -1320,16 +1463,18 @@ class Session(_SessionClassMethods):
         """
         return {}
 
-    def _autobegin(self):
+    def _autobegin_t(self) -> SessionTransaction:
         if self._transaction is None:
 
             trans = SessionTransaction(self, autobegin=True)
             assert self._transaction is trans
-            return True
+            return trans
 
-        return False
+        return self._transaction
 
-    def begin(self, nested=False, _subtrans=False):
+    def begin(
+        self, nested: bool = False, _subtrans: bool = False
+    ) -> SessionTransaction:
         """Begin a transaction, or nested transaction,
         on this :class:`.Session`, if one is not already begun.
 
@@ -1364,13 +1509,16 @@ class Session(_SessionClassMethods):
 
         """
 
-        if self._autobegin():
+        trans = self._transaction
+        if trans is None:
+            trans = self._autobegin_t()
+
             if not nested and not _subtrans:
-                return self._transaction
+                return trans
 
-        if self._transaction is not None:
+        if trans is not None:
             if _subtrans or nested:
-                trans = self._transaction._begin(nested=nested)
+                trans = trans._begin(nested=nested)
                 assert self._transaction is trans
                 if nested:
                     self._nested_transaction = trans
@@ -1386,9 +1534,12 @@ class Session(_SessionClassMethods):
             trans = SessionTransaction(self)
             assert self._transaction is trans
 
-        return self._transaction  # needed for __enter__/__exit__ hook
+            if TYPE_CHECKING:
+                assert self._transaction is not None
+
+        return trans  # needed for __enter__/__exit__ hook
 
-    def begin_nested(self):
+    def begin_nested(self) -> SessionTransaction:
         """Begin a "nested" transaction on this Session, e.g. SAVEPOINT.
 
         The target database(s) and associated drivers must support SQL
@@ -1413,7 +1564,7 @@ class Session(_SessionClassMethods):
         """
         return self.begin(nested=True)
 
-    def rollback(self):
+    def rollback(self) -> None:
         """Rollback the current transaction in progress.
 
         If no transaction is in progress, this method is a pass-through.
@@ -1450,11 +1601,11 @@ class Session(_SessionClassMethods):
             :ref:`unitofwork_transaction`
 
         """
-        if self._transaction is None:
-            if not self._autobegin():
-                raise sa_exc.InvalidRequestError("No transaction is begun.")
+        trans = self._transaction
+        if trans is None:
+            trans = self._autobegin_t()
 
-        self._transaction.commit(_to_root=True)
+        trans.commit(_to_root=True)
 
     def prepare(self) -> None:
         """Prepare the current transaction in progress for two phase commit.
@@ -1467,16 +1618,16 @@ class Session(_SessionClassMethods):
         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
 
         """
-        if self._transaction is None:
-            if not self._autobegin():
-                raise sa_exc.InvalidRequestError("No transaction is begun.")
+        trans = self._transaction
+        if trans is None:
+            trans = self._autobegin_t()
 
-        self._transaction.prepare()
+        trans.prepare()
 
     def connection(
         self,
         bind_arguments: Optional[Dict[str, Any]] = None,
-        execution_options: Optional["_ExecuteOptions"] = None,
+        execution_options: Optional[_ExecuteOptions] = None,
     ) -> "Connection":
         r"""Return a :class:`_engine.Connection` object corresponding to this
         :class:`.Session` object's transactional state.
@@ -1521,24 +1672,28 @@ class Session(_SessionClassMethods):
             execution_options=execution_options,
         )
 
-    def _connection_for_bind(self, engine, execution_options=None, **kw):
+    def _connection_for_bind(
+        self,
+        engine: _SessionBind,
+        execution_options: Optional[_ExecuteOptions] = None,
+        **kw: Any,
+    ) -> Connection:
         TransactionalContext._trans_ctx_check(self)
 
-        if self._transaction is None:
-            assert self._autobegin()
-        return self._transaction._connection_for_bind(
-            engine, execution_options
-        )
+        trans = self._transaction
+        if trans is None:
+            trans = self._autobegin_t()
+        return trans._connection_for_bind(engine, execution_options)
 
     def execute(
         self,
-        statement: "Executable",
-        params: Optional["_ExecuteParams"] = None,
-        execution_options: "_ExecuteOptions" = util.EMPTY_DICT,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-    ):
+    ) -> Result:
         r"""Execute a SQL expression construct.
 
         Returns a :class:`_engine.Result` object representing
@@ -1603,6 +1758,8 @@ class Session(_SessionClassMethods):
             compile_state_cls = CompileState._get_plugin_class_for_plugin(
                 statement, "orm"
             )
+            if TYPE_CHECKING:
+                assert isinstance(compile_state_cls, ORMCompileState)
         else:
             compile_state_cls = None
 
@@ -1645,9 +1802,9 @@ class Session(_SessionClassMethods):
             )
             for idx, fn in enumerate(events_todo):
                 orm_exec_state._starting_event_idx = idx
-                result = fn(orm_exec_state)
-                if result:
-                    return result
+                fn_result: Optional[Result] = fn(orm_exec_state)
+                if fn_result:
+                    return fn_result
 
             statement = orm_exec_state.statement
             execution_options = orm_exec_state.local_execution_options
@@ -1655,7 +1812,9 @@ class Session(_SessionClassMethods):
         bind = self.get_bind(**bind_arguments)
 
         conn = self._connection_for_bind(bind)
-        result = conn.execute(statement, params or {}, execution_options)
+        result: Result = conn.execute(
+            statement, params or {}, execution_options
+        )
 
         if compile_state_cls:
             result = compile_state_cls.orm_setup_cursor_result(
@@ -1671,12 +1830,12 @@ class Session(_SessionClassMethods):
 
     def scalar(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[Dict[str, Any]] = None,
+        **kw: Any,
+    ) -> Any:
         """Execute a statement and return a scalar result.
 
         Usage and parameters are the same as that of
@@ -1695,12 +1854,12 @@ class Session(_SessionClassMethods):
 
     def scalars(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[Dict[str, Any]] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
         """Execute a statement and return the results as scalars.
 
         Usage and parameters are the same as that of
@@ -1722,7 +1881,7 @@ class Session(_SessionClassMethods):
             **kw,
         ).scalars()
 
-    def close(self):
+    def close(self) -> None:
         """Close out the transactional resources and ORM objects used by this
         :class:`_orm.Session`.
 
@@ -1754,7 +1913,7 @@ class Session(_SessionClassMethods):
         """
         self._close_impl(invalidate=False)
 
-    def invalidate(self):
+    def invalidate(self) -> None:
         """Close this Session, using connection invalidation.
 
         This is a variant of :meth:`.Session.close` that will additionally
@@ -1790,13 +1949,13 @@ class Session(_SessionClassMethods):
         """
         self._close_impl(invalidate=True)
 
-    def _close_impl(self, invalidate):
+    def _close_impl(self, invalidate: bool) -> None:
         self.expunge_all()
         if self._transaction is not None:
             for transaction in self._transaction._iterate_self_and_parents():
                 transaction.close(invalidate)
 
-    def expunge_all(self):
+    def expunge_all(self) -> None:
         """Remove all object instances from this ``Session``.
 
         This is equivalent to calling ``expunge(obj)`` on all objects in this
@@ -1812,7 +1971,7 @@ class Session(_SessionClassMethods):
 
         statelib.InstanceState._detach_states(all_states, self)
 
-    def _add_bind(self, key, bind):
+    def _add_bind(self, key: _SessionBindKey, bind: _SessionBind) -> None:
         try:
             insp = inspect(key)
         except sa_exc.NoInspectionAvailable as err:
@@ -1834,7 +1993,9 @@ class Session(_SessionClassMethods):
                     "Not an acceptable bind target: %s" % key
                 )
 
-    def bind_mapper(self, mapper, bind):
+    def bind_mapper(
+        self, mapper: _EntityBindKey[_O], bind: _SessionBind
+    ) -> None:
         """Associate a :class:`_orm.Mapper` or arbitrary Python class with a
         "bind", e.g. an :class:`_engine.Engine` or
         :class:`_engine.Connection`.
@@ -1862,7 +2023,7 @@ class Session(_SessionClassMethods):
         """
         self._add_bind(mapper, bind)
 
-    def bind_table(self, table, bind):
+    def bind_table(self, table: Table, bind: _SessionBind) -> None:
         """Associate a :class:`_schema.Table` with a "bind", e.g. an
         :class:`_engine.Engine`
         or :class:`_engine.Connection`.
@@ -1892,12 +2053,12 @@ class Session(_SessionClassMethods):
 
     def get_bind(
         self,
-        mapper=None,
-        clause=None,
-        bind=None,
-        _sa_skip_events=None,
-        _sa_skip_for_implicit_returning=False,
-    ):
+        mapper: Optional[_EntityBindKey[_O]] = None,
+        clause: Optional[ClauseElement] = None,
+        bind: Optional[_SessionBind] = None,
+        _sa_skip_events: Optional[bool] = None,
+        _sa_skip_for_implicit_returning: bool = False,
+    ) -> Union[Engine, Connection]:
         """Return a "bind" to which this :class:`.Session` is bound.
 
         The "bind" is usually an instance of :class:`_engine.Engine`,
@@ -1995,23 +2156,25 @@ class Session(_SessionClassMethods):
         # look more closely at the mapper.
         if mapper is not None:
             try:
-                mapper = inspect(mapper)
+                inspected_mapper = inspect(mapper)
             except sa_exc.NoInspectionAvailable as err:
                 if isinstance(mapper, type):
                     raise exc.UnmappedClassError(mapper) from err
                 else:
                     raise
+        else:
+            inspected_mapper = None
 
         # match up the mapper or clause in the __binds
         if self.__binds:
             # matching mappers and selectables to entries in the
             # binds dictionary; supported use case.
-            if mapper:
-                for cls in mapper.class_.__mro__:
+            if inspected_mapper:
+                for cls in inspected_mapper.class_.__mro__:
                     if cls in self.__binds:
                         return self.__binds[cls]
                 if clause is None:
-                    clause = mapper.persist_selectable
+                    clause = inspected_mapper.persist_selectable
 
             if clause is not None:
                 plugin_subject = clause._propagate_attrs.get(
@@ -2025,6 +2188,8 @@ class Session(_SessionClassMethods):
 
                 for obj in visitors.iterate(clause):
                     if obj in self.__binds:
+                        if TYPE_CHECKING:
+                            assert isinstance(obj, Table)
                         return self.__binds[obj]
 
         # none of the __binds matched, but we have a fallback bind.
@@ -2033,17 +2198,19 @@ class Session(_SessionClassMethods):
             return self.bind
 
         context = []
-        if mapper is not None:
-            context.append("mapper %s" % mapper)
+        if inspected_mapper is not None:
+            context.append(f"mapper {inspected_mapper}")
         if clause is not None:
             context.append("SQL expression")
 
         raise sa_exc.UnboundExecutionError(
-            "Could not locate a bind configured on %s or this Session."
-            % (", ".join(context),),
+            f"Could not locate a bind configured on "
+            f'{", ".join(context)} or this Session.'
         )
 
-    def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query:
+    def query(
+        self, *entities: _ColumnsClauseArgument, **kwargs: Any
+    ) -> Query[Any]:
         """Return a new :class:`_query.Query` object corresponding to this
         :class:`_orm.Session`.
 
@@ -2065,12 +2232,12 @@ class Session(_SessionClassMethods):
 
     def _identity_lookup(
         self,
-        mapper,
-        primary_key_identity,
-        identity_token=None,
-        passive=attributes.PASSIVE_OFF,
-        lazy_loaded_from=None,
-    ):
+        mapper: Mapper[_O],
+        primary_key_identity: Union[Any, Tuple[Any, ...]],
+        identity_token: Any = None,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+        lazy_loaded_from: Optional[InstanceState[Any]] = None,
+    ) -> Union[Optional[_O], LoaderCallableStatus]:
         """Locate an object in the identity map.
 
         Given a primary key identity, constructs an identity key and then
@@ -2117,9 +2284,9 @@ class Session(_SessionClassMethods):
         )
         return loading.get_from_identity(self, mapper, key, passive)
 
-    @property
+    @util.non_memoized_property
     @contextlib.contextmanager
-    def no_autoflush(self):
+    def no_autoflush(self) -> Iterator[Session]:
         """Return a context manager that disables autoflush.
 
         e.g.::
@@ -2145,7 +2312,7 @@ class Session(_SessionClassMethods):
         finally:
             self.autoflush = autoflush
 
-    def _autoflush(self):
+    def _autoflush(self) -> None:
         if self.autoflush and not self._flushing:
             try:
                 self.flush()
@@ -2161,7 +2328,12 @@ class Session(_SessionClassMethods):
                 )
                 raise e.with_traceback(sys.exc_info()[2])
 
-    def refresh(self, instance, attribute_names=None, with_for_update=None):
+    def refresh(
+        self,
+        instance: object,
+        attribute_names: Optional[Iterable[str]] = None,
+        with_for_update: Optional[ForUpdateArg] = None,
+    ) -> None:
         """Expire and refresh attributes on the given instance.
 
         The selected attributes will first be expired as they would when using
@@ -2233,7 +2405,7 @@ class Session(_SessionClassMethods):
                 "A blank dictionary is ambiguous."
             )
 
-        with_for_update = query.ForUpdateArg._from_argument(with_for_update)
+        with_for_update = ForUpdateArg._from_argument(with_for_update)
 
         stmt = sql.select(object_mapper(instance))
         if (
@@ -2251,7 +2423,7 @@ class Session(_SessionClassMethods):
                 "Could not refresh instance '%s'" % instance_str(instance)
             )
 
-    def expire_all(self):
+    def expire_all(self) -> None:
         """Expires all persistent instances within this Session.
 
         When any attributes on a persistent instance is next accessed,
@@ -2286,7 +2458,9 @@ class Session(_SessionClassMethods):
         for state in self.identity_map.all_states():
             state._expire(state.dict, self.identity_map._modified)
 
-    def expire(self, instance, attribute_names=None):
+    def expire(
+        self, instance: object, attribute_names: Optional[Iterable[str]] = None
+    ) -> None:
         """Expire the attributes on an instance.
 
         Marks the attributes of an instance as out of date. When an expired
@@ -2329,7 +2503,11 @@ class Session(_SessionClassMethods):
             raise exc.UnmappedInstanceError(instance) from err
         self._expire_state(state, attribute_names)
 
-    def _expire_state(self, state, attribute_names):
+    def _expire_state(
+        self,
+        state: InstanceState[Any],
+        attribute_names: Optional[Iterable[str]],
+    ) -> None:
         self._validate_persistent(state)
         if attribute_names:
             state._expire_attributes(state.dict, attribute_names)
@@ -2343,7 +2521,9 @@ class Session(_SessionClassMethods):
             for o, m, st_, dct_ in cascaded:
                 self._conditional_expire(st_)
 
-    def _conditional_expire(self, state, autoflush=None):
+    def _conditional_expire(
+        self, state: InstanceState[Any], autoflush: Optional[bool] = None
+    ) -> None:
         """Expire a state if persistent, else expunge if pending"""
 
         if state.key:
@@ -2352,7 +2532,7 @@ class Session(_SessionClassMethods):
             self._new.pop(state)
             state._detach(self)
 
-    def expunge(self, instance):
+    def expunge(self, instance: object) -> None:
         """Remove the `instance` from this ``Session``.
 
         This will free all internal references to the instance.  Cascading
@@ -2373,7 +2553,9 @@ class Session(_SessionClassMethods):
         )
         self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded])
 
-    def _expunge_states(self, states, to_transient=False):
+    def _expunge_states(
+        self, states: Iterable[InstanceState[Any]], to_transient: bool = False
+    ) -> None:
         for state in states:
             if state in self._new:
                 self._new.pop(state)
@@ -2388,7 +2570,7 @@ class Session(_SessionClassMethods):
             states, self, to_transient=to_transient
         )
 
-    def _register_persistent(self, states):
+    def _register_persistent(self, states: Set[InstanceState[Any]]) -> None:
         """Register all persistent objects from a flush.
 
         This is used both for pending objects moving to the persistent
@@ -2429,11 +2611,13 @@ class Session(_SessionClassMethods):
                     # state has already replaced this one in the identity
                     # map (see test/orm/test_naturalpks.py ReversePKsTest)
                     self.identity_map.safe_discard(state)
-                    if state in self._transaction._key_switches:
-                        orig_key = self._transaction._key_switches[state][0]
+                    trans = self._transaction
+                    assert trans is not None
+                    if state in trans._key_switches:
+                        orig_key = trans._key_switches[state][0]
                     else:
                         orig_key = state.key
-                    self._transaction._key_switches[state] = (
+                    trans._key_switches[state] = (
                         orig_key,
                         instance_key,
                     )
@@ -2470,7 +2654,7 @@ class Session(_SessionClassMethods):
         for state in set(states).intersection(self._new):
             self._new.pop(state)
 
-    def _register_altered(self, states):
+    def _register_altered(self, states: Iterable[InstanceState[Any]]) -> None:
         if self._transaction:
             for state in states:
                 if state in self._new:
@@ -2478,7 +2662,9 @@ class Session(_SessionClassMethods):
                 else:
                     self._transaction._dirty[state] = True
 
-    def _remove_newly_deleted(self, states):
+    def _remove_newly_deleted(
+        self, states: Iterable[InstanceState[Any]]
+    ) -> None:
         persistent_to_deleted = self.dispatch.persistent_to_deleted or None
         for state in states:
             if self._transaction:
@@ -2498,7 +2684,7 @@ class Session(_SessionClassMethods):
             if persistent_to_deleted is not None:
                 persistent_to_deleted(self, state)
 
-    def add(self, instance: Any, _warn: bool = True) -> None:
+    def add(self, instance: object, _warn: bool = True) -> None:
         """Place an object in the ``Session``.
 
         Its state will be persisted to the database on the next flush
@@ -2518,7 +2704,7 @@ class Session(_SessionClassMethods):
 
         self._save_or_update_state(state)
 
-    def add_all(self, instances):
+    def add_all(self, instances: Iterable[object]) -> None:
         """Add the given collection of instances to this ``Session``."""
 
         if self._warn_on_events:
@@ -2527,7 +2713,7 @@ class Session(_SessionClassMethods):
         for instance in instances:
             self.add(instance, _warn=False)
 
-    def _save_or_update_state(self, state):
+    def _save_or_update_state(self, state: InstanceState[Any]) -> None:
         state._orphaned_outside_of_session = False
         self._save_or_update_impl(state)
 
@@ -2537,7 +2723,7 @@ class Session(_SessionClassMethods):
         ):
             self._save_or_update_impl(st_)
 
-    def delete(self, instance):
+    def delete(self, instance: object) -> None:
         """Mark an instance as deleted.
 
         The database delete operation occurs upon ``flush()``.
@@ -2553,7 +2739,9 @@ class Session(_SessionClassMethods):
 
         self._delete_impl(state, instance, head=True)
 
-    def _delete_impl(self, state, obj, head):
+    def _delete_impl(
+        self, state: InstanceState[Any], obj: object, head: bool
+    ) -> None:
 
         if state.key is None:
             if head:
@@ -2580,23 +2768,28 @@ class Session(_SessionClassMethods):
             cascade_states = list(
                 state.manager.mapper.cascade_iterator("delete", state)
             )
+        else:
+            cascade_states = None
 
         self._deleted[state] = obj
 
         if head:
+            if TYPE_CHECKING:
+                assert cascade_states is not None
             for o, m, st_, dct_ in cascade_states:
                 self._delete_impl(st_, o, False)
 
     def get(
         self,
-        entity,
-        ident,
-        options=None,
-        populate_existing=False,
-        with_for_update=None,
-        identity_token=None,
-        execution_options=None,
-    ):
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: Optional[ForUpdateArg] = None,
+        identity_token: Optional[Any] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> Optional[_O]:
         """Return an instance based on the given primary key identifier,
         or ``None`` if not found.
 
@@ -2696,7 +2889,7 @@ class Session(_SessionClassMethods):
             entity,
             ident,
             loading.load_on_pk_identity,
-            options,
+            options=options,
             populate_existing=populate_existing,
             with_for_update=with_for_update,
             identity_token=identity_token,
@@ -2705,23 +2898,24 @@ class Session(_SessionClassMethods):
 
     def _get_impl(
         self,
-        entity,
-        primary_key_identity,
-        db_load_fn,
-        options=None,
-        populate_existing=False,
-        with_for_update=None,
-        identity_token=None,
-        execution_options=None,
-    ):
+        entity: _EntityBindKey[_O],
+        primary_key_identity: _PKIdentityArgument,
+        db_load_fn: Callable[..., _O],
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: Optional[ForUpdateArg] = None,
+        identity_token: Optional[Any] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> Optional[_O]:
 
         # convert composite types to individual args
-        if hasattr(primary_key_identity, "__composite_values__"):
+        if is_composite_class(primary_key_identity):
             primary_key_identity = primary_key_identity.__composite_values__()
 
-        mapper = inspect(entity)
+        mapper: Optional[Mapper[_O]] = inspect(entity)
 
-        if not mapper or not mapper.is_mapper:
+        if mapper is None or not mapper.is_mapper:
             raise sa_exc.ArgumentError(
                 "Expected mapped class or mapper, got: %r" % entity
             )
@@ -2729,7 +2923,7 @@ class Session(_SessionClassMethods):
         is_dict = isinstance(primary_key_identity, dict)
         if not is_dict:
             primary_key_identity = util.to_list(
-                primary_key_identity, default=(None,)
+                primary_key_identity, default=[None]
             )
 
         if len(primary_key_identity) != len(mapper.primary_key):
@@ -2770,11 +2964,12 @@ class Session(_SessionClassMethods):
             if instance is not None:
                 # reject calls for id in identity map but class
                 # mismatch.
-                if not issubclass(instance.__class__, mapper.class_):
+                if not isinstance(instance, mapper.class_):
                     return None
                 return instance
-            elif instance is attributes.PASSIVE_CLASS_MISMATCH:
-                return None
+
+            # TODO: this was being tested before, but this is not possible
+            assert instance is not LoaderCallableStatus.PASSIVE_CLASS_MISMATCH
 
         # set_label_style() not strictly necessary, however this will ensure
         # that tablename_colname style is used which at the moment is
@@ -2788,7 +2983,7 @@ class Session(_SessionClassMethods):
             LABEL_STYLE_TABLENAME_PLUS_COL
         )
         if with_for_update is not None:
-            statement._for_update_arg = query.ForUpdateArg._from_argument(
+            statement._for_update_arg = ForUpdateArg._from_argument(
                 with_for_update
             )
 
@@ -2803,7 +2998,13 @@ class Session(_SessionClassMethods):
             load_options=load_options,
         )
 
-    def merge(self, instance, load=True, options=None):
+    def merge(
+        self,
+        instance: _O,
+        *,
+        load: bool = True,
+        options: Optional[Sequence[ORMOption]] = None,
+    ) -> _O:
         """Copy the state of a given instance into a corresponding instance
         within this :class:`.Session`.
 
@@ -2866,8 +3067,8 @@ class Session(_SessionClassMethods):
         if self._warn_on_events:
             self._flush_warning("Session.merge()")
 
-        _recursive = {}
-        _resolve_conflict_map = {}
+        _recursive: Dict[InstanceState[Any], object] = {}
+        _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {}
 
         if load:
             # flush current contents if we expect to load data
@@ -2890,20 +3091,23 @@ class Session(_SessionClassMethods):
 
     def _merge(
         self,
-        state,
-        state_dict,
-        load=True,
-        options=None,
-        _recursive=None,
-        _resolve_conflict_map=None,
-    ):
+        state: InstanceState[_O],
+        state_dict: _InstanceDict,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        load: bool,
+        _recursive: Dict[InstanceState[Any], object],
+        _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+    ) -> _O:
         mapper = _state_mapper(state)
         if state in _recursive:
-            return _recursive[state]
+            return cast(_O, _recursive[state])
 
         new_instance = False
         key = state.key
 
+        merged: Optional[_O]
+
         if key is None:
             if state in self._new:
                 util.warn(
@@ -2920,7 +3124,9 @@ class Session(_SessionClassMethods):
                     "load=False."
                 )
             key = mapper._identity_key_from_state(state)
-            key_is_persistent = attributes.NEVER_SET not in key[1] and (
+            key_is_persistent = LoaderCallableStatus.NEVER_SET not in key[
+                1
+            ] and (
                 not _none_set.intersection(key[1])
                 or (
                     mapper.allow_partial_pks
@@ -2941,7 +3147,7 @@ class Session(_SessionClassMethods):
 
         if merged is None:
             if key_is_persistent and key in _resolve_conflict_map:
-                merged = _resolve_conflict_map[key]
+                merged = cast(_O, _resolve_conflict_map[key])
 
             elif not load:
                 if state.modified:
@@ -2986,19 +3192,21 @@ class Session(_SessionClassMethods):
                     state,
                     state_dict,
                     mapper.version_id_col,
-                    passive=attributes.PASSIVE_NO_INITIALIZE,
+                    passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
                 )
 
                 merged_version = mapper._get_state_attr_by_column(
                     merged_state,
                     merged_dict,
                     mapper.version_id_col,
-                    passive=attributes.PASSIVE_NO_INITIALIZE,
+                    passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
                 )
 
                 if (
-                    existing_version is not attributes.PASSIVE_NO_RESULT
-                    and merged_version is not attributes.PASSIVE_NO_RESULT
+                    existing_version
+                    is not LoaderCallableStatus.PASSIVE_NO_RESULT
+                    and merged_version
+                    is not LoaderCallableStatus.PASSIVE_NO_RESULT
                     and existing_version != merged_version
                 ):
                     raise exc.StaleDataError(
@@ -3043,14 +3251,14 @@ class Session(_SessionClassMethods):
             merged_state.manager.dispatch.load(merged_state, None)
         return merged
 
-    def _validate_persistent(self, state):
+    def _validate_persistent(self, state: InstanceState[Any]) -> None:
         if not self.identity_map.contains_state(state):
             raise sa_exc.InvalidRequestError(
                 "Instance '%s' is not persistent within this Session"
                 % state_str(state)
             )
 
-    def _save_impl(self, state):
+    def _save_impl(self, state: InstanceState[Any]) -> None:
         if state.key is not None:
             raise sa_exc.InvalidRequestError(
                 "Object '%s' already has an identity - "
@@ -3065,7 +3273,9 @@ class Session(_SessionClassMethods):
         if to_attach:
             self._after_attach(state, obj)
 
-    def _update_impl(self, state, revert_deletion=False):
+    def _update_impl(
+        self, state: InstanceState[Any], revert_deletion: bool = False
+    ) -> None:
         if state.key is None:
             raise sa_exc.InvalidRequestError(
                 "Instance '%s' is not persisted" % state_str(state)
@@ -3103,13 +3313,13 @@ class Session(_SessionClassMethods):
         elif revert_deletion:
             self.dispatch.deleted_to_persistent(self, state)
 
-    def _save_or_update_impl(self, state):
+    def _save_or_update_impl(self, state: InstanceState[Any]) -> None:
         if state.key is None:
             self._save_impl(state)
         else:
             self._update_impl(state)
 
-    def enable_relationship_loading(self, obj):
+    def enable_relationship_loading(self, obj: object) -> None:
         """Associate an object with this :class:`.Session` for related
         object loading.
 
@@ -3174,8 +3384,8 @@ class Session(_SessionClassMethods):
         if to_attach:
             self._after_attach(state, obj)
 
-    def _before_attach(self, state, obj):
-        self._autobegin()
+    def _before_attach(self, state: InstanceState[Any], obj: object) -> bool:
+        self._autobegin_t()
 
         if state.session_id == self.hash_key:
             return False
@@ -3191,7 +3401,7 @@ class Session(_SessionClassMethods):
 
         return True
 
-    def _after_attach(self, state, obj):
+    def _after_attach(self, state: InstanceState[Any], obj: object) -> None:
         state.session_id = self.hash_key
         if state.modified and state._strong_obj is None:
             state._strong_obj = obj
@@ -3202,7 +3412,7 @@ class Session(_SessionClassMethods):
         else:
             self.dispatch.transient_to_pending(self, state)
 
-    def __contains__(self, instance):
+    def __contains__(self, instance: object) -> bool:
         """Return True if the instance is associated with this session.
 
         The instance may be pending or persistent within the Session for a
@@ -3215,7 +3425,7 @@ class Session(_SessionClassMethods):
             raise exc.UnmappedInstanceError(instance) from err
         return self._contains_state(state)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[object]:
         """Iterate over all pending or persistent instances within this
         Session.
 
@@ -3224,10 +3434,10 @@ class Session(_SessionClassMethods):
             list(self._new.values()) + list(self.identity_map.values())
         )
 
-    def _contains_state(self, state):
+    def _contains_state(self, state: InstanceState[Any]) -> bool:
         return state in self._new or self.identity_map.contains_state(state)
 
-    def flush(self, objects=None):
+    def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
         """Flush all the object changes to the database.
 
         Writes out all pending object creations, deletions and modifications
@@ -3261,7 +3471,7 @@ class Session(_SessionClassMethods):
         finally:
             self._flushing = False
 
-    def _flush_warning(self, method):
+    def _flush_warning(self, method: Any) -> None:
         util.warn(
             "Usage of the '%s' operation is not currently supported "
             "within the execution stage of the flush process. "
@@ -3269,14 +3479,14 @@ class Session(_SessionClassMethods):
             "event listeners or connection-level operations instead." % method
         )
 
-    def _is_clean(self):
+    def _is_clean(self) -> bool:
         return (
             not self.identity_map.check_modified()
             and not self._deleted
             and not self._new
         )
 
-    def _flush(self, objects=None):
+    def _flush(self, objects: Optional[Sequence[object]] = None) -> None:
 
         dirty = self._dirty_states
         if not dirty and not self._deleted and not self._new:
@@ -3398,11 +3608,11 @@ class Session(_SessionClassMethods):
 
     def bulk_save_objects(
         self,
-        objects,
-        return_defaults=False,
-        update_changed_only=True,
-        preserve_order=True,
-    ):
+        objects: Iterable[object],
+        return_defaults: bool = False,
+        update_changed_only: bool = True,
+        preserve_order: bool = True,
+    ) -> None:
         """Perform a bulk save of the given list of objects.
 
         The bulk save feature allows mapped objects to be used as the
@@ -3496,6 +3706,8 @@ class Session(_SessionClassMethods):
 
         """
 
+        obj_states: Iterable[InstanceState[Any]]
+
         obj_states = (attributes.instance_state(obj) for obj in objects)
 
         if not preserve_order:
@@ -3508,7 +3720,9 @@ class Session(_SessionClassMethods):
                 key=lambda state: (id(state.mapper), state.key is not None),
             )
 
-        def grouping_key(state):
+        def grouping_key(
+            state: InstanceState[_O],
+        ) -> Tuple[Mapper[_O], bool]:
             return (state.mapper, state.key is not None)
 
         for (mapper, isupdate), states in itertools.groupby(
@@ -3525,8 +3739,12 @@ class Session(_SessionClassMethods):
             )
 
     def bulk_insert_mappings(
-        self, mapper, mappings, return_defaults=False, render_nulls=False
-    ):
+        self,
+        mapper: Mapper[Any],
+        mappings: Iterable[Dict[str, Any]],
+        return_defaults: bool = False,
+        render_nulls: bool = False,
+    ) -> None:
         """Perform a bulk insert of the given list of mapping dictionaries.
 
         The bulk insert feature allows plain Python dictionaries to be used as
@@ -3633,7 +3851,9 @@ class Session(_SessionClassMethods):
             render_nulls,
         )
 
-    def bulk_update_mappings(self, mapper, mappings):
+    def bulk_update_mappings(
+        self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]]
+    ) -> None:
         """Perform a bulk update of the given list of mapping dictionaries.
 
         The bulk update feature allows plain Python dictionaries to be used as
@@ -3696,14 +3916,14 @@ class Session(_SessionClassMethods):
 
     def _bulk_save_mappings(
         self,
-        mapper,
-        mappings,
-        isupdate,
-        isstates,
-        return_defaults,
-        update_changed_only,
-        render_nulls,
-    ):
+        mapper: Mapper[_O],
+        mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+        isupdate: bool,
+        isstates: bool,
+        return_defaults: bool,
+        update_changed_only: bool,
+        render_nulls: bool,
+    ) -> None:
         mapper = _class_to_mapper(mapper)
         self._flushing = True
 
@@ -3734,7 +3954,9 @@ class Session(_SessionClassMethods):
         finally:
             self._flushing = False
 
-    def is_modified(self, instance, include_collections=True):
+    def is_modified(
+        self, instance: object, include_collections: bool = True
+    ) -> bool:
         r"""Return ``True`` if the given instance has locally
         modified attributes.
 
@@ -3800,7 +4022,7 @@ class Session(_SessionClassMethods):
                 continue
 
             (added, unchanged, deleted) = attr.impl.get_history(
-                state, dict_, passive=attributes.NO_CHANGE
+                state, dict_, passive=PassiveFlag.NO_CHANGE
             )
 
             if added or deleted:
@@ -3809,7 +4031,7 @@ class Session(_SessionClassMethods):
             return False
 
     @property
-    def is_active(self):
+    def is_active(self) -> bool:
         """True if this :class:`.Session` not in "partial rollback" state.
 
         .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins
@@ -3838,22 +4060,8 @@ class Session(_SessionClassMethods):
         """
         return self._transaction is None or self._transaction.is_active
 
-    identity_map = None
-    """A mapping of object identities to objects themselves.
-
-    Iterating through ``Session.identity_map.values()`` provides
-    access to the full set of persistent objects (i.e., those
-    that have row identity) currently in the session.
-
-    .. seealso::
-
-        :func:`.identity_key` - helper function to produce the keys used
-        in this dictionary.
-
-    """
-
     @property
-    def _dirty_states(self):
+    def _dirty_states(self) -> Iterable[InstanceState[Any]]:
         """The set of all persistent states considered dirty.
 
         This method returns all states that were modified including
@@ -3863,7 +4071,7 @@ class Session(_SessionClassMethods):
         return self.identity_map._dirty_states()
 
     @property
-    def dirty(self):
+    def dirty(self) -> IdentitySet:
         """The set of all persistent instances considered dirty.
 
         E.g.::
@@ -3886,7 +4094,7 @@ class Session(_SessionClassMethods):
         attributes, use the :meth:`.Session.is_modified` method.
 
         """
-        return util.IdentitySet(
+        return IdentitySet(
             [
                 state.obj()
                 for state in self._dirty_states
@@ -3895,13 +4103,13 @@ class Session(_SessionClassMethods):
         )
 
     @property
-    def deleted(self):
+    def deleted(self) -> IdentitySet:
         "The set of all instances marked as 'deleted' within this ``Session``"
 
         return util.IdentitySet(list(self._deleted.values()))
 
     @property
-    def new(self):
+    def new(self) -> IdentitySet:
         "The set of all instances marked as 'new' within this ``Session``."
 
         return util.IdentitySet(list(self._new.values()))
@@ -4002,14 +4210,16 @@ class sessionmaker(_SessionClassMethods):
 
     """
 
+    class_: Type[Session]
+
     def __init__(
         self,
-        bind=None,
-        class_=Session,
-        autoflush=True,
-        expire_on_commit=True,
-        info=None,
-        **kw,
+        bind: Optional[_SessionBind] = None,
+        class_: Type[Session] = Session,
+        autoflush: bool = True,
+        expire_on_commit: bool = True,
+        info: Optional[Dict[Any, Any]] = None,
+        **kw: Any,
     ):
         r"""Construct a new :class:`.sessionmaker`.
 
@@ -4052,7 +4262,7 @@ class sessionmaker(_SessionClassMethods):
         # events can be associated with it specifically.
         self.class_ = type(class_.__name__, (class_,), {})
 
-    def begin(self):
+    def begin(self) -> contextlib.AbstractContextManager[Session]:
         """Produce a context manager that both provides a new
         :class:`_orm.Session` as well as a transaction that commits.
 
@@ -4074,7 +4284,7 @@ class sessionmaker(_SessionClassMethods):
         session = self()
         return session._maker_context_manager()
 
-    def __call__(self, **local_kw):
+    def __call__(self, **local_kw: Any) -> Session:
         """Produce a new :class:`.Session` object using the configuration
         established in this :class:`.sessionmaker`.
 
@@ -4094,7 +4304,7 @@ class sessionmaker(_SessionClassMethods):
                 local_kw.setdefault(k, v)
         return self.class_(**local_kw)
 
-    def configure(self, **new_kw):
+    def configure(self, **new_kw: Any) -> None:
         """(Re)configure the arguments for this sessionmaker.
 
         e.g.::
@@ -4105,7 +4315,7 @@ class sessionmaker(_SessionClassMethods):
         """
         self.kw.update(new_kw)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(class_=%r, %s)" % (
             self.__class__.__name__,
             self.class_.__name__,
@@ -4113,7 +4323,7 @@ class sessionmaker(_SessionClassMethods):
         )
 
 
-def close_all_sessions():
+def close_all_sessions() -> None:
     """Close all sessions in memory.
 
     This function consults a global registry of all :class:`.Session` objects
@@ -4131,7 +4341,7 @@ def close_all_sessions():
         sess.close()
 
 
-def make_transient(instance):
+def make_transient(instance: object) -> None:
     """Alter the state of the given instance so that it is :term:`transient`.
 
     .. note::
@@ -4195,7 +4405,7 @@ def make_transient(instance):
         del state._deleted
 
 
-def make_transient_to_detached(instance):
+def make_transient_to_detached(instance: object) -> None:
     """Make the given transient instance :term:`detached`.
 
     .. note::
@@ -4234,7 +4444,7 @@ def make_transient_to_detached(instance):
     state._expire_attributes(state.dict, state.unloaded_expirable)
 
 
-def object_session(instance):
+def object_session(instance: object) -> Optional[Session]:
     """Return the :class:`.Session` to which the given instance belongs.
 
     This is essentially the same as the :attr:`.InstanceState.session`
index c3e4e299ab66f8bbbf36b0cfc9d90e4a11f25a28..7ccda956598f5645b6ba27d5d6883f56eceb24aa 100644 (file)
@@ -14,13 +14,25 @@ defines a large part of the ORM's interactivity.
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
 import weakref
 
 from . import base
 from . import exc as orm_exc
 from . import interfaces
+from ._typing import is_collection_impl
 from .base import ATTR_WAS_SET
 from .base import INIT_OK
+from .base import LoaderCallableStatus
 from .base import NEVER_SET
 from .base import NO_VALUE
 from .base import PASSIVE_NO_INITIALIZE
@@ -31,17 +43,47 @@ from .path_registry import PathRegistry
 from .. import exc as sa_exc
 from .. import inspection
 from .. import util
+from ..util.typing import Protocol
 
+if TYPE_CHECKING:
+    from ._typing import _IdentityKeyType
+    from ._typing import _InstanceDict
+    from ._typing import _LoaderCallable
+    from .attributes import AttributeImpl
+    from .attributes import History
+    from .base import LoaderCallableStatus
+    from .base import PassiveFlag
+    from .identity import IdentityMap
+    from .instrumentation import ClassManager
+    from .interfaces import ORMOption
+    from .mapper import Mapper
+    from .session import Session
+    from ..engine import Row
+    from ..ext.asyncio.session import async_session as _async_provider
+    from ..ext.asyncio.session import AsyncSession
 
-# late-populated by session.py
-_sessions = None
+_T = TypeVar("_T", bound=Any)
 
-# optionally late-provided by sqlalchemy.ext.asyncio.session
-_async_provider = None
+if TYPE_CHECKING:
+    _sessions: weakref.WeakValueDictionary[int, Session]
+else:
+    # late-populated by session.py
+    _sessions = None
+
+
+if not TYPE_CHECKING:
+    # optionally late-provided by sqlalchemy.ext.asyncio.session
+
+    _async_provider = None  # noqa
+
+
+class _InstanceDictProto(Protocol):
+    def __call__(self) -> Optional[IdentityMap]:
+        ...
 
 
 @inspection._self_inspects
-class InstanceState(interfaces.InspectionAttrInfo):
+class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
     """tracks state information at the instance level.
 
     The :class:`.InstanceState` is a key object used by the
@@ -67,23 +109,57 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
     """
 
-    session_id = None
-    key = None
-    runid = None
-    load_options = ()
-    load_path = PathRegistry.root
-    insert_order = None
-    _strong_obj = None
-    modified = False
-    expired = False
-    _deleted = False
-    _load_pending = False
-    _orphaned_outside_of_session = False
-    is_instance = True
-    identity_token = None
-    _last_known_values = ()
-
-    callables = ()
+    __slots__ = (
+        "__dict__",
+        "__weakref__",
+        "class_",
+        "manager",
+        "obj",
+        "committed_state",
+        "expired_attributes",
+    )
+
+    manager: ClassManager[_T]
+    session_id: Optional[int] = None
+    key: Optional[_IdentityKeyType[_T]] = None
+    runid: Optional[int] = None
+    load_options: Tuple[ORMOption, ...] = ()
+    load_path: PathRegistry = PathRegistry.root
+    insert_order: Optional[int] = None
+    _strong_obj: Optional[object] = None
+    obj: weakref.ref[_T]
+
+    committed_state: Dict[str, Any]
+
+    modified: bool = False
+    expired: bool = False
+    _deleted: bool = False
+    _load_pending: bool = False
+    _orphaned_outside_of_session: bool = False
+    is_instance: bool = True
+    identity_token: object = None
+    _last_known_values: Optional[Dict[str, Any]] = None
+
+    _instance_dict: _InstanceDictProto
+    """A weak reference, or in the default case a plain callable, that
+    returns a reference to the current :class:`.IdentityMap`, if any.
+
+    """
+    if not TYPE_CHECKING:
+
+        def _instance_dict(self):
+            """default 'weak reference' for _instance_dict"""
+            return None
+
+    expired_attributes: Set[str]
+    """The set of keys which are 'expired' to be loaded by
+       the manager's deferred scalar loader, assuming no pending
+       changes.
+
+       see also the ``unmodified`` collection which is intersected
+       against this set when a refresh operation occurs."""
+
+    callables: Dict[str, Callable[[InstanceState[_T], PassiveFlag], Any]]
     """A namespace where a per-state loader callable can be associated.
 
     In SQLAlchemy 1.0, this is only used for lazy loaders / deferred
@@ -95,23 +171,18 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
     """
 
-    def __init__(self, obj, manager):
+    if not TYPE_CHECKING:
+        callables = util.EMPTY_DICT
+
+    def __init__(self, obj: _T, manager: ClassManager[_T]):
         self.class_ = obj.__class__
         self.manager = manager
         self.obj = weakref.ref(obj, self._cleanup)
         self.committed_state = {}
         self.expired_attributes = set()
 
-    expired_attributes = None
-    """The set of keys which are 'expired' to be loaded by
-       the manager's deferred scalar loader, assuming no pending
-       changes.
-
-       see also the ``unmodified`` collection which is intersected
-       against this set when a refresh operation occurs."""
-
     @util.memoized_property
-    def attrs(self):
+    def attrs(self) -> util.ReadOnlyProperties[AttributeState]:
         """Return a namespace representing each attribute on
         the mapped object, including its current value
         and history.
@@ -123,11 +194,11 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
         """
         return util.ReadOnlyProperties(
-            dict((key, AttributeState(self, key)) for key in self.manager)
+            {key: AttributeState(self, key) for key in self.manager}
         )
 
     @property
-    def transient(self):
+    def transient(self) -> bool:
         """Return ``True`` if the object is :term:`transient`.
 
         .. seealso::
@@ -138,7 +209,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return self.key is None and not self._attached
 
     @property
-    def pending(self):
+    def pending(self) -> bool:
         """Return ``True`` if the object is :term:`pending`.
 
 
@@ -150,7 +221,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return self.key is None and self._attached
 
     @property
-    def deleted(self):
+    def deleted(self) -> bool:
         """Return ``True`` if the object is :term:`deleted`.
 
         An object that is in the deleted state is guaranteed to
@@ -180,7 +251,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return self.key is not None and self._attached and self._deleted
 
     @property
-    def was_deleted(self):
+    def was_deleted(self) -> bool:
         """Return True if this object is or was previously in the
         "deleted" state and has not been reverted to persistent.
 
@@ -204,7 +275,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return self._deleted
 
     @property
-    def persistent(self):
+    def persistent(self) -> bool:
         """Return ``True`` if the object is :term:`persistent`.
 
         An object that is in the persistent state is guaranteed to
@@ -225,7 +296,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return self.key is not None and self._attached and not self._deleted
 
     @property
-    def detached(self):
+    def detached(self) -> bool:
         """Return ``True`` if the object is :term:`detached`.
 
         .. seealso::
@@ -235,15 +306,15 @@ class InstanceState(interfaces.InspectionAttrInfo):
         """
         return self.key is not None and not self._attached
 
-    @property
+    @util.non_memoized_property
     @util.preload_module("sqlalchemy.orm.session")
-    def _attached(self):
+    def _attached(self) -> bool:
         return (
             self.session_id is not None
             and self.session_id in util.preloaded.orm_session._sessions
         )
 
-    def _track_last_known_value(self, key):
+    def _track_last_known_value(self, key: str) -> None:
         """Track the last known value of a particular key after expiration
         operations.
 
@@ -251,12 +322,14 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
         """
 
-        if key not in self._last_known_values:
-            self._last_known_values = dict(self._last_known_values)
-            self._last_known_values[key] = NO_VALUE
+        lkv = self._last_known_values
+        if lkv is None:
+            self._last_known_values = lkv = {}
+        if key not in lkv:
+            lkv[key] = NO_VALUE
 
     @property
-    def session(self):
+    def session(self) -> Optional[Session]:
         """Return the owning :class:`.Session` for this instance,
         or ``None`` if none available.
 
@@ -280,7 +353,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return None
 
     @property
-    def async_session(self):
+    def async_session(self) -> Optional[AsyncSession]:
         """Return the owning :class:`_asyncio.AsyncSession` for this instance,
         or ``None`` if none available.
 
@@ -308,13 +381,17 @@ class InstanceState(interfaces.InspectionAttrInfo):
             return None
 
     @property
-    def object(self):
+    def object(self) -> Optional[_T]:
         """Return the mapped object represented by this
-        :class:`.InstanceState`."""
+        :class:`.InstanceState`.
+
+        Returns None if the object has been garbage collected
+
+        """
         return self.obj()
 
     @property
-    def identity(self):
+    def identity(self) -> Optional[Tuple[Any, ...]]:
         """Return the mapped identity of the mapped object.
         This is the primary key identity as persisted by the ORM
         which can always be passed directly to
@@ -334,7 +411,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
             return self.key[1]
 
     @property
-    def identity_key(self):
+    def identity_key(self) -> Optional[_IdentityKeyType[_T]]:
         """Return the identity key for the mapped object.
 
         This is the key used to locate the object within
@@ -343,29 +420,27 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
 
         """
-        # TODO: just change .key to .identity_key across
-        # the board ?  probably
         return self.key
 
     @util.memoized_property
-    def parents(self):
+    def parents(self) -> Dict[int, InstanceState[Any]]:
         return {}
 
     @util.memoized_property
-    def _pending_mutations(self):
+    def _pending_mutations(self) -> Dict[str, PendingCollection]:
         return {}
 
     @util.memoized_property
-    def _empty_collections(self):
+    def _empty_collections(self) -> Dict[Any, Any]:
         return {}
 
     @util.memoized_property
-    def mapper(self):
+    def mapper(self) -> Mapper[_T]:
         """Return the :class:`_orm.Mapper` used for this mapped object."""
         return self.manager.mapper
 
     @property
-    def has_identity(self):
+    def has_identity(self) -> bool:
         """Return ``True`` if this object has an identity key.
 
         This should always have the same value as the
@@ -375,7 +450,12 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return bool(self.key)
 
     @classmethod
-    def _detach_states(self, states, session, to_transient=False):
+    def _detach_states(
+        self,
+        states: Iterable[InstanceState[_T]],
+        session: Session,
+        to_transient: bool = False,
+    ) -> None:
         persistent_to_detached = (
             session.dispatch.persistent_to_detached or None
         )
@@ -407,17 +487,17 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
             state._strong_obj = None
 
-    def _detach(self, session=None):
+    def _detach(self, session: Optional[Session] = None) -> None:
         if session:
             InstanceState._detach_states([self], session)
         else:
             self.session_id = self._strong_obj = None
 
-    def _dispose(self):
+    def _dispose(self) -> None:
+        # used by the test suite, apparently
         self._detach()
-        del self.obj
 
-    def _cleanup(self, ref):
+    def _cleanup(self, ref: weakref.ref[_T]) -> None:
         """Weakref callback cleanup.
 
         This callable cleans out the state when it is being garbage
@@ -445,13 +525,9 @@ class InstanceState(interfaces.InspectionAttrInfo):
             # assert self not in instance_dict._modified
 
         self.session_id = self._strong_obj = None
-        del self.obj
-
-    def obj(self):
-        return None
 
     @property
-    def dict(self):
+    def dict(self) -> _InstanceDict:
         """Return the instance dict used by the object.
 
         Under normal circumstances, this is always synonymous
@@ -469,35 +545,39 @@ class InstanceState(interfaces.InspectionAttrInfo):
         else:
             return {}
 
-    def _initialize_instance(*mixed, **kwargs):
+    def _initialize_instance(*mixed: Any, **kwargs: Any) -> None:
         self, instance, args = mixed[0], mixed[1], mixed[2:]  # noqa
         manager = self.manager
 
         manager.dispatch.init(self, args, kwargs)
 
         try:
-            return manager.original_init(*mixed[1:], **kwargs)
+            manager.original_init(*mixed[1:], **kwargs)
         except:
             with util.safe_reraise():
                 manager.dispatch.init_failure(self, args, kwargs)
 
-    def get_history(self, key, passive):
+    def get_history(self, key: str, passive: PassiveFlag) -> History:
         return self.manager[key].impl.get_history(self, self.dict, passive)
 
-    def get_impl(self, key):
+    def get_impl(self, key: str) -> AttributeImpl:
         return self.manager[key].impl
 
-    def _get_pending_mutation(self, key):
+    def _get_pending_mutation(self, key: str) -> PendingCollection:
         if key not in self._pending_mutations:
             self._pending_mutations[key] = PendingCollection()
         return self._pending_mutations[key]
 
-    def __getstate__(self):
-        state_dict = {"instance": self.obj()}
+    def __getstate__(self) -> Dict[str, Any]:
+        state_dict = {
+            "instance": self.obj(),
+            "class_": self.class_,
+            "committed_state": self.committed_state,
+            "expired_attributes": self.expired_attributes,
+        }
         state_dict.update(
             (k, self.__dict__[k])
             for k in (
-                "committed_state",
                 "_pending_mutations",
                 "modified",
                 "expired",
@@ -518,21 +598,18 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
         return state_dict
 
-    def __setstate__(self, state_dict):
+    def __setstate__(self, state_dict: Dict[str, Any]) -> None:
         inst = state_dict["instance"]
         if inst is not None:
             self.obj = weakref.ref(inst, self._cleanup)
             self.class_ = inst.__class__
         else:
-            # None being possible here generally new as of 0.7.4
-            # due to storage of state in "parents".  "class_"
-            # also new.
-            self.obj = None
+            self.obj = lambda: None  # type: ignore
             self.class_ = state_dict["class_"]
 
         self.committed_state = state_dict.get("committed_state", {})
-        self._pending_mutations = state_dict.get("_pending_mutations", {})
-        self.parents = state_dict.get("parents", {})
+        self._pending_mutations = state_dict.get("_pending_mutations", {})  # type: ignore  # noqa E501
+        self.parents = state_dict.get("parents", {})  # type: ignore
         self.modified = state_dict.get("modified", False)
         self.expired = state_dict.get("expired", False)
         if "info" in state_dict:
@@ -540,15 +617,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         if "callables" in state_dict:
             self.callables = state_dict["callables"]
 
-            try:
-                self.expired_attributes = state_dict["expired_attributes"]
-            except KeyError:
-                self.expired_attributes = set()
-                # 0.9 and earlier compat
-                for k in list(self.callables):
-                    if self.callables[k] is self:
-                        self.expired_attributes.add(k)
-                        del self.callables[k]
+            self.expired_attributes = state_dict["expired_attributes"]
         else:
             if "expired_attributes" in state_dict:
                 self.expired_attributes = state_dict["expired_attributes"]
@@ -563,57 +632,61 @@ class InstanceState(interfaces.InspectionAttrInfo):
             ]
         )
         if self.key:
-            try:
-                self.identity_token = self.key[2]
-            except IndexError:
-                # 1.1 and earlier compat before identity_token
-                assert len(self.key) == 2
-                self.key = self.key + (None,)
-                self.identity_token = None
+            self.identity_token = self.key[2]
 
         if "load_path" in state_dict:
             self.load_path = PathRegistry.deserialize(state_dict["load_path"])
 
         state_dict["manager"](self, inst, state_dict)
 
-    def _reset(self, dict_, key):
+    def _reset(self, dict_: _InstanceDict, key: str) -> None:
         """Remove the given attribute and any
         callables associated with it."""
 
         old = dict_.pop(key, None)
-        if old is not None and self.manager[key].impl.collection:
-            self.manager[key].impl._invalidate_collection(old)
+        manager_impl = self.manager[key].impl
+        if old is not None and is_collection_impl(manager_impl):
+            manager_impl._invalidate_collection(old)
         self.expired_attributes.discard(key)
         if self.callables:
             self.callables.pop(key, None)
 
-    def _copy_callables(self, from_):
+    def _copy_callables(self, from_: InstanceState[Any]) -> None:
         if "callables" in from_.__dict__:
             self.callables = dict(from_.callables)
 
     @classmethod
-    def _instance_level_callable_processor(cls, manager, fn, key):
+    def _instance_level_callable_processor(
+        cls, manager: ClassManager[_T], fn: _LoaderCallable, key: Any
+    ) -> Callable[[InstanceState[_T], _InstanceDict, Row], None]:
         impl = manager[key].impl
-        if impl.collection:
+        if is_collection_impl(impl):
+            fixed_impl = impl
 
-            def _set_callable(state, dict_, row):
+            def _set_callable(
+                state: InstanceState[_T], dict_: _InstanceDict, row: Row
+            ) -> None:
                 if "callables" not in state.__dict__:
                     state.callables = {}
                 old = dict_.pop(key, None)
                 if old is not None:
-                    impl._invalidate_collection(old)
+                    fixed_impl._invalidate_collection(old)
                 state.callables[key] = fn
 
         else:
 
-            def _set_callable(state, dict_, row):
+            def _set_callable(
+                state: InstanceState[_T], dict_: _InstanceDict, row: Row
+            ) -> None:
                 if "callables" not in state.__dict__:
                     state.callables = {}
                 state.callables[key] = fn
 
         return _set_callable
 
-    def _expire(self, dict_, modified_set):
+    def _expire(
+        self, dict_: _InstanceDict, modified_set: Set[InstanceState[Any]]
+    ) -> None:
         self.expired = True
         if self.modified:
             modified_set.discard(self)
@@ -653,7 +726,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
         if self._last_known_values:
             self._last_known_values.update(
-                (k, dict_[k]) for k in self._last_known_values if k in dict_
+                {k: dict_[k] for k in self._last_known_values if k in dict_}
             )
 
         for key in self.manager._all_key_set.intersection(dict_):
@@ -661,7 +734,12 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
         self.manager.dispatch.expire(self, None)
 
-    def _expire_attributes(self, dict_, attribute_names, no_loader=False):
+    def _expire_attributes(
+        self,
+        dict_: _InstanceDict,
+        attribute_names: Iterable[str],
+        no_loader: bool = False,
+    ) -> None:
         pending = self.__dict__.get("_pending_mutations", None)
 
         callables = self.callables
@@ -676,15 +754,12 @@ class InstanceState(interfaces.InspectionAttrInfo):
                 if callables and key in callables:
                     del callables[key]
             old = dict_.pop(key, NO_VALUE)
-            if impl.collection and old is not NO_VALUE:
+            if is_collection_impl(impl) and old is not NO_VALUE:
                 impl._invalidate_collection(old)
 
-            if (
-                self._last_known_values
-                and key in self._last_known_values
-                and old is not NO_VALUE
-            ):
-                self._last_known_values[key] = old
+            lkv = self._last_known_values
+            if lkv is not None and key in lkv and old is not NO_VALUE:
+                lkv[key] = old
 
             self.committed_state.pop(key, None)
             if pending:
@@ -692,7 +767,9 @@ class InstanceState(interfaces.InspectionAttrInfo):
 
         self.manager.dispatch.expire(self, attribute_names)
 
-    def _load_expired(self, state, passive):
+    def _load_expired(
+        self, state: InstanceState[_T], passive: PassiveFlag
+    ) -> LoaderCallableStatus:
         """__call__ allows the InstanceState to act as a deferred
         callable for loading expired attributes, which is also
         serializable (picklable).
@@ -720,12 +797,12 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return ATTR_WAS_SET
 
     @property
-    def unmodified(self):
+    def unmodified(self) -> Set[str]:
         """Return the set of keys which have no uncommitted changes"""
 
         return set(self.manager).difference(self.committed_state)
 
-    def unmodified_intersection(self, keys):
+    def unmodified_intersection(self, keys: Iterable[str]) -> Set[str]:
         """Return self.unmodified.intersection(keys)."""
 
         return (
@@ -735,7 +812,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         )
 
     @property
-    def unloaded(self):
+    def unloaded(self) -> Set[str]:
         """Return the set of keys which do not have a loaded value.
 
         This includes expired attributes and any other attribute that
@@ -749,7 +826,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
         )
 
     @property
-    def unloaded_expirable(self):
+    def unloaded_expirable(self) -> Set[str]:
         """Return the set of keys which do not have a loaded value.
 
         This includes expired attributes and any other attribute that
@@ -759,19 +836,21 @@ class InstanceState(interfaces.InspectionAttrInfo):
         return self.unloaded
 
     @property
-    def _unloaded_non_object(self):
+    def _unloaded_non_object(self) -> Set[str]:
         return self.unloaded.intersection(
             attr
             for attr in self.manager
             if self.manager[attr].impl.accepts_scalar_loader
         )
 
-    def _instance_dict(self):
-        return None
-
     def _modified_event(
-        self, dict_, attr, previous, collection=False, is_userland=False
-    ):
+        self,
+        dict_: _InstanceDict,
+        attr: AttributeImpl,
+        previous: Any,
+        collection: bool = False,
+        is_userland: bool = False,
+    ) -> None:
         if attr:
             if not attr.send_modified_events:
                 return
@@ -782,6 +861,8 @@ class InstanceState(interfaces.InspectionAttrInfo):
                 )
             if attr.key not in self.committed_state or is_userland:
                 if collection:
+                    if TYPE_CHECKING:
+                        assert is_collection_impl(attr)
                     if previous is NEVER_SET:
                         if attr.key in dict_:
                             previous = dict_[attr.key]
@@ -790,8 +871,9 @@ class InstanceState(interfaces.InspectionAttrInfo):
                         previous = attr.copy(previous)
                 self.committed_state[attr.key] = previous
 
-            if attr.key in self._last_known_values:
-                self._last_known_values[attr.key] = NO_VALUE
+            lkv = self._last_known_values
+            if lkv is not None and attr.key in lkv:
+                lkv[attr.key] = NO_VALUE
 
         # assert self._strong_obj is None or self.modified
 
@@ -823,7 +905,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
                         pass
                     else:
                         if session._transaction is None:
-                            session._autobegin()
+                            session._autobegin_t()
 
             if inst is None and attr:
                 raise orm_exc.ObjectDereferencedError(
@@ -833,7 +915,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
                     % (self.manager[attr.key], base.state_class_str(self))
                 )
 
-    def _commit(self, dict_, keys):
+    def _commit(self, dict_: _InstanceDict, keys: Iterable[str]) -> None:
         """Commit attributes.
 
         This is used by a partial-attribute load operation to mark committed
@@ -862,7 +944,9 @@ class InstanceState(interfaces.InspectionAttrInfo):
             ):
                 del self.callables[key]
 
-    def _commit_all(self, dict_, instance_dict=None):
+    def _commit_all(
+        self, dict_: _InstanceDict, instance_dict: Optional[IdentityMap] = None
+    ) -> None:
         """commit all attributes unconditionally.
 
         This is used after a flush() or a full load/refresh
@@ -881,7 +965,11 @@ class InstanceState(interfaces.InspectionAttrInfo):
         self._commit_all_states([(self, dict_)], instance_dict)
 
     @classmethod
-    def _commit_all_states(self, iter_, instance_dict=None):
+    def _commit_all_states(
+        self,
+        iter_: Iterable[Tuple[InstanceState[Any], _InstanceDict]],
+        instance_dict: Optional[IdentityMap] = None,
+    ) -> None:
         """Mass / highly inlined version of commit_all()."""
 
         for state, dict_ in iter_:
@@ -916,12 +1004,17 @@ class AttributeState:
 
     """
 
-    def __init__(self, state, key):
+    __slots__ = ("state", "key")
+
+    state: InstanceState[Any]
+    key: str
+
+    def __init__(self, state: InstanceState[Any], key: str):
         self.state = state
         self.key = key
 
     @property
-    def loaded_value(self):
+    def loaded_value(self) -> Any:
         """The current value of this attribute as loaded from the database.
 
         If the value has not been loaded, or is otherwise not present
@@ -931,7 +1024,7 @@ class AttributeState:
         return self.state.dict.get(self.key, NO_VALUE)
 
     @property
-    def value(self):
+    def value(self) -> Any:
         """Return the value of this attribute.
 
         This operation is equivalent to accessing the object's
@@ -944,7 +1037,7 @@ class AttributeState:
         )
 
     @property
-    def history(self):
+    def history(self) -> History:
         """Return the current **pre-flush** change history for
         this attribute, via the :class:`.History` interface.
 
@@ -971,7 +1064,7 @@ class AttributeState:
         """
         return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE)
 
-    def load_history(self):
+    def load_history(self) -> History:
         """Return the current **pre-flush** change history for
         this attribute, via the :class:`.History` interface.
 
@@ -1008,17 +1101,22 @@ class PendingCollection:
 
     """
 
-    def __init__(self):
+    __slots__ = ("deleted_items", "added_items")
+
+    deleted_items: util.IdentitySet
+    added_items: util.OrderedIdentitySet
+
+    def __init__(self) -> None:
         self.deleted_items = util.IdentitySet()
         self.added_items = util.OrderedIdentitySet()
 
-    def append(self, value):
+    def append(self, value: Any) -> None:
         if value in self.deleted_items:
             self.deleted_items.remove(value)
         else:
             self.added_items.add(value)
 
-    def remove(self, value):
+    def remove(self, value: Any) -> None:
         if value in self.added_items:
             self.added_items.remove(value)
         else:
index 1afeab05bcb20150b902d46e5ae2a7e699deb76b..b7bf96558534e5aa7ad9923e6aab76b6bdeed734 100644 (file)
@@ -16,12 +16,15 @@ from typing import Any
 from typing import Callable
 from typing import Optional
 from typing import Tuple
+from typing import TypeVar
 from typing import Union
 
 from .. import exc as sa_exc
 from .. import util
 from ..util.typing import Literal
 
+_F = TypeVar("_F", bound=Callable[..., Any])
+
 
 class _StateChangeState(Enum):
     pass
@@ -60,7 +63,7 @@ class _StateChange:
             Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
         ],
         moves_to: _StateChangeState,
-    ) -> Callable[..., Any]:
+    ) -> Callable[[_F], _F]:
         """Method decorator declaring valid states.
 
         :param prerequisite_states: sequence of acceptable prerequisite
index da098e8c5f770408cad27ab4f753d415a0670602..9ff284e733d6e280a9c86b1c65f713d803bca0ba 100644 (file)
@@ -15,6 +15,12 @@ organizes them in order of dependency, and executes.
 
 from __future__ import annotations
 
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Set
+from typing import TYPE_CHECKING
+
 from . import attributes
 from . import exc as orm_exc
 from . import util as orm_util
@@ -23,6 +29,15 @@ from .. import util
 from ..util import topological
 
 
+if TYPE_CHECKING:
+    from .dependency import DependencyProcessor
+    from .interfaces import MapperProperty
+    from .mapper import Mapper
+    from .session import Session
+    from .session import SessionTransaction
+    from .state import InstanceState
+
+
 def track_cascade_events(descriptor, prop):
     """Establish event listeners on object attributes which handle
     cascade-on-set/append.
@@ -131,7 +146,13 @@ def track_cascade_events(descriptor, prop):
 
 
 class UOWTransaction:
-    def __init__(self, session):
+    session: Session
+    transaction: SessionTransaction
+    attributes: Dict[str, Any]
+    deps: util.defaultdict[Mapper[Any], Set[DependencyProcessor]]
+    mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]]
+
+    def __init__(self, session: Session):
         self.session = session
 
         # dictionary used by external actors to
@@ -275,13 +296,13 @@ class UOWTransaction:
 
     def register_object(
         self,
-        state,
-        isdelete=False,
-        listonly=False,
-        cancel_delete=False,
-        operation=None,
-        prop=None,
-    ):
+        state: InstanceState[Any],
+        isdelete: bool = False,
+        listonly: bool = False,
+        cancel_delete: bool = False,
+        operation: Optional[str] = None,
+        prop: Optional[MapperProperty] = None,
+    ) -> bool:
         if not self.session._contains_state(state):
             # this condition is normal when objects are registered
             # as part of a relationship cascade operation.  it should
@@ -408,7 +429,7 @@ class UOWTransaction:
             [a for a in self.postsort_actions.values() if not a.disabled]
         ).difference(cycles)
 
-    def execute(self):
+    def execute(self) -> None:
         postsort_actions = self._generate_actions()
 
         postsort_actions = sorted(
@@ -435,7 +456,7 @@ class UOWTransaction:
             for rec in topological.sort(self.dependencies, postsort_actions):
                 rec.execute(self)
 
-    def finalize_flush_changes(self):
+    def finalize_flush_changes(self) -> None:
         """Mark processed objects as clean / deleted after a successful
         flush().
 
index baca8f54766a32d0e3ea7f26f25920aa488f4c6d..233085f305b0a7352ac6d506f803119aab188085 100644 (file)
@@ -13,7 +13,6 @@ import typing
 from typing import Any
 from typing import Generic
 from typing import Optional
-from typing import overload
 from typing import Tuple
 from typing import Type
 from typing import TypeVar
@@ -22,7 +21,6 @@ import weakref
 
 from . import attributes  # noqa
 from .base import _class_to_mapper  # noqa
-from .base import _IdentityKeyType
 from .base import _never_set  # noqa
 from .base import _none_set  # noqa
 from .base import attribute_str  # noqa
@@ -62,6 +60,9 @@ from ..util.typing import de_stringify_annotation
 from ..util.typing import is_origin_of
 
 if typing.TYPE_CHECKING:
+    from ._typing import _EntityType
+    from ._typing import _IdentityKeyType
+    from ._typing import _InternalEntityType
     from .mapper import Mapper
     from ..engine import Row
     from ..sql._typing import _PropagateAttrsType
@@ -297,27 +298,13 @@ def polymorphic_union(
     return sql.union_all(*result).alias(aliasname)
 
 
-@overload
 def identity_key(
-    class_: type, ident: Tuple[Any, ...], *, identity_token: Optional[str]
-) -> _IdentityKeyType:
-    ...
-
-
-@overload
-def identity_key(*, instance: Any) -> _IdentityKeyType:
-    ...
-
-
-@overload
-def identity_key(
-    class_: type, *, row: "Row", identity_token: Optional[str]
-) -> _IdentityKeyType:
-    ...
-
-
-def identity_key(
-    class_=None, ident=None, *, instance=None, row=None, identity_token=None
+    class_: Optional[Type[Any]] = None,
+    ident: Union[Any, Tuple[Any, ...]] = None,
+    *,
+    instance: Optional[Any] = None,
+    row: Optional[Row] = None,
+    identity_token: Optional[Any] = None,
 ) -> _IdentityKeyType:
     r"""Generate "identity key" tuples, as are used as keys in the
     :attr:`.Session.identity_map` dictionary.
@@ -634,6 +621,7 @@ class AliasedInsp(
     sql_base.HasCacheKey,
     InspectionAttr,
     MemoizedSlots,
+    Generic[_T],
 ):
     """Provide an inspection interface for an
     :class:`.AliasedClass` object.
@@ -699,8 +687,8 @@ class AliasedInsp(
 
     def __init__(
         self,
-        entity,
-        inspected,
+        entity: _EntityType,
+        inspected: _InternalEntityType,
         selectable,
         name,
         with_polymorphic_mappers,
@@ -1797,6 +1785,32 @@ def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type):
     return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
 
 
+def _cleanup_mapped_str_annotation(annotation):
+    # fix up an annotation that comes in as the form:
+    # 'Mapped[List[Address]]'  so that it instead looks like:
+    # 'Mapped[List["Address"]]' , which will allow us to get
+    # "Address" as a string
+    mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
+    if mm and mm.group(1) == "Mapped":
+        stack = []
+        inner = mm
+        while True:
+            stack.append(inner.group(1))
+            g2 = inner.group(2)
+            inner = re.match(r"^(.+?)\[(.+)\]$", g2)
+            if inner is None:
+                stack.append(g2)
+                break
+
+        # stack: ['Mapped', 'List', 'Address']
+        if not re.match(r"""^["'].*["']$""", stack[-1]):
+            stack[-1] = f'"{stack[-1]}"'
+            # stack: ['Mapped', 'List', '"Address"']
+
+            annotation = "[".join(stack) + ("]" * (len(stack) - 1))
+    return annotation
+
+
 def _extract_mapped_subtype(
     raw_annotation: Union[type, str],
     cls: type,
@@ -1816,7 +1830,9 @@ def _extract_mapped_subtype(
             )
         return None
 
-    annotated = de_stringify_annotation(cls, raw_annotation)
+    annotated = de_stringify_annotation(
+        cls, raw_annotation, _cleanup_mapped_str_annotation
+    )
 
     if is_dataclass_field:
         return annotated
index bc1e0672c4cfe779e3d0d8399bce0cd06a2976f3..7e3a1c4e8d0251fcaecb3349667f5ddb964d86e0 100644 (file)
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
+from sqlalchemy.sql.base import Executable
 from . import roles
 from .. import util
 from ..inspection import Inspectable
@@ -183,10 +184,14 @@ if TYPE_CHECKING:
     def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
         ...
 
-    def is_select_base(t: ReturnsRows) -> TypeGuard[SelectBase]:
+    def is_select_base(
+        t: Union[Executable, ReturnsRows]
+    ) -> TypeGuard[SelectBase]:
         ...
 
-    def is_select_statement(t: ReturnsRows) -> TypeGuard[Select]:
+    def is_select_statement(
+        t: Union[Executable, ReturnsRows]
+    ) -> TypeGuard[Select]:
         ...
 
     def is_table(t: FromClause) -> TypeGuard[TableClause]:
index 7fb9c260267e438be6cdfaafcc8549dc0e7d4aae..ccd5e8c40e224735a128eae3fc82bf62d1c25d84 100644 (file)
@@ -648,7 +648,9 @@ class CompileState:
             return None
 
     @classmethod
-    def _get_plugin_class_for_plugin(cls, statement, plugin_name):
+    def _get_plugin_class_for_plugin(
+        cls, statement: Executable, plugin_name: str
+    ) -> Optional[Type[CompileState]]:
         try:
             return cls.plugins[
                 (plugin_name, statement._effective_plugin_target)
@@ -790,7 +792,7 @@ class Options(metaclass=_MetaOptions):
         )
 
     @classmethod
-    def isinstance(cls, klass):
+    def isinstance(cls, klass: Type[Any]) -> bool:
         return issubclass(cls, klass)
 
     @hybridmethod
@@ -912,6 +914,8 @@ class ExecutableOption(HasCopyInternals):
 
     _is_has_cache_key = False
 
+    _is_core = True
+
     def _clone(self, **kw):
         """Create a shallow copy of this ExecutableOption."""
         c = self.__class__.__new__(self.__class__)
index 4c71ca38b13cbe8f273764fc65e70768f95a9bd9..623bb0be2ee92a8643bcabfa058323238e5742a6 100644 (file)
@@ -114,7 +114,7 @@ def _deep_is_literal(element):
                 schema.SchemaEventTarget,
                 HasCacheKey,
                 Options,
-                util.langhelpers._symbol,
+                util.langhelpers.symbol,
             ),
         )
         and not hasattr(element, "__clause_element__")
index d7cc327333e0406f41e9f4009f3085e9b299d43c..99a6baa890c3b99019ad71b3682cd25b76b964d4 100644 (file)
@@ -3037,7 +3037,9 @@ class ForUpdateArg(ClauseElement):
     skip_locked: bool
 
     @classmethod
-    def _from_argument(cls, with_for_update):
+    def _from_argument(
+        cls, with_for_update: Union[ForUpdateArg, None, bool, Dict[str, Any]]
+    ) -> Optional[ForUpdateArg]:
         if isinstance(with_for_update, ForUpdateArg):
             return with_for_update
         elif with_for_update in (None, False):
@@ -3045,7 +3047,7 @@ class ForUpdateArg(ClauseElement):
         elif with_for_update is True:
             return ForUpdateArg()
         else:
-            return ForUpdateArg(**with_for_update)
+            return ForUpdateArg(**cast("Dict[str, Any]", with_for_update))
 
     def __eq__(self, other):
         return (
index 4253aa61bcdb01510f1aed3cf5356e8f0b368af1..da6292fcf439073472b60ae2ce15fc54b382a962 100644 (file)
@@ -49,6 +49,7 @@ from .config import combinations_list
 from .config import db
 from .config import fixture
 from .config import requirements as requires
+from .config import skip_test
 from .exclusions import _is_excluded
 from .exclusions import _server_version
 from .exclusions import against as _against
index c0c2e7dfb7db8a4c87e599314f8f979d22f2a374..6d41231d98b533d8bc2bba5c2671b8230e9350aa 100644 (file)
@@ -96,6 +96,7 @@ from .langhelpers import dictlike_iteritems as dictlike_iteritems
 from .langhelpers import duck_type_collection as duck_type_collection
 from .langhelpers import ellipses_string as ellipses_string
 from .langhelpers import EnsureKWArg as EnsureKWArg
+from .langhelpers import FastIntFlag as FastIntFlag
 from .langhelpers import format_argspec_init as format_argspec_init
 from .langhelpers import format_argspec_plus as format_argspec_plus
 from .langhelpers import generic_fn_descriptor as generic_fn_descriptor
index eb5b16b650c75bd53cdc7ff6ca0e566e646d3705..eea76f60b08b518da33c0d3aeee631e9c1d36a30 100644 (file)
@@ -22,6 +22,7 @@ from typing import Generic
 from typing import Iterable
 from typing import Iterator
 from typing import List
+from typing import Mapping
 from typing import Optional
 from typing import overload
 from typing import Set
@@ -123,7 +124,7 @@ def merge_lists_w_ordering(a, b):
     return result
 
 
-def coerce_to_immutabledict(d):
+def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
     if not d:
         return EMPTY_DICT
     elif isinstance(d, immutabledict):
@@ -161,6 +162,8 @@ class FacadeDict(ImmutableDictBase[_KT, _VT]):
 
 _DT = TypeVar("_DT", bound=Any)
 
+_F = TypeVar("_F", bound=Any)
+
 
 class Properties(Generic[_T]):
     """Provide a __getattr__/__setattr__ interface over a dict."""
@@ -169,7 +172,7 @@ class Properties(Generic[_T]):
 
     _data: Dict[str, _T]
 
-    def __init__(self, data):
+    def __init__(self, data: Dict[str, _T]):
         object.__setattr__(self, "_data", data)
 
     def __len__(self) -> int:
@@ -178,30 +181,30 @@ class Properties(Generic[_T]):
     def __iter__(self) -> Iterator[_T]:
         return iter(list(self._data.values()))
 
-    def __dir__(self):
+    def __dir__(self) -> List[str]:
         return dir(super(Properties, self)) + [
             str(k) for k in self._data.keys()
         ]
 
-    def __add__(self, other):
-        return list(self) + list(other)
+    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
+        return list(self) + list(other)  # type: ignore
 
-    def __setitem__(self, key, obj):
+    def __setitem__(self, key: str, obj: _T) -> None:
         self._data[key] = obj
 
     def __getitem__(self, key: str) -> _T:
         return self._data[key]
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: str) -> None:
         del self._data[key]
 
-    def __setattr__(self, key, obj):
+    def __setattr__(self, key: str, obj: _T) -> None:
         self._data[key] = obj
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return {"_data": self._data}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         object.__setattr__(self, "_data", state["_data"])
 
     def __getattr__(self, key: str) -> _T:
@@ -213,12 +216,12 @@ class Properties(Generic[_T]):
     def __contains__(self, key: str) -> bool:
         return key in self._data
 
-    def as_readonly(self) -> "ReadOnlyProperties[_T]":
+    def as_readonly(self) -> ReadOnlyProperties[_T]:
         """Return an immutable proxy for this :class:`.Properties`."""
 
         return ReadOnlyProperties(self._data)
 
-    def update(self, value):
+    def update(self, value: Dict[str, _T]) -> None:
         self._data.update(value)
 
     @overload
@@ -249,7 +252,7 @@ class Properties(Generic[_T]):
     def has_key(self, key: str) -> bool:
         return key in self._data
 
-    def clear(self):
+    def clear(self) -> None:
         self._data.clear()
 
 
@@ -318,7 +321,7 @@ class WeakSequence:
 
 
 class OrderedIdentitySet(IdentitySet):
-    def __init__(self, iterable=None):
+    def __init__(self, iterable: Optional[Iterable[Any]] = None):
         IdentitySet.__init__(self)
         self._members = OrderedDict()
         if iterable:
@@ -615,7 +618,9 @@ class ScopedRegistry(Generic[_T]):
     scopefunc: _ScopeFuncType
     registry: Any
 
-    def __init__(self, createfunc, scopefunc):
+    def __init__(
+        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
+    ):
         """Construct a new :class:`.ScopedRegistry`.
 
         :param createfunc:  A creation function that will generate
index d649a0bea712ed066ba3e450309899a773f76a6a..725f6930eec14b02656047b124ca41ab83616868 100644 (file)
@@ -263,52 +263,54 @@ class IdentitySet:
 
     """
 
-    def __init__(self, iterable=None):
+    _members: Dict[int, Any]
+
+    def __init__(self, iterable: Optional[Iterable[Any]] = None):
         self._members = dict()
         if iterable:
             self.update(iterable)
 
-    def add(self, value):
+    def add(self, value: Any) -> None:
         self._members[id(value)] = value
 
-    def __contains__(self, value):
+    def __contains__(self, value: Any) -> bool:
         return id(value) in self._members
 
-    def remove(self, value):
+    def remove(self, value: Any) -> None:
         del self._members[id(value)]
 
-    def discard(self, value):
+    def discard(self, value: Any) -> None:
         try:
             self.remove(value)
         except KeyError:
             pass
 
-    def pop(self):
+    def pop(self) -> Any:
         try:
             pair = self._members.popitem()
             return pair[1]
         except KeyError:
             raise KeyError("pop from an empty set")
 
-    def clear(self):
+    def clear(self) -> None:
         self._members.clear()
 
-    def __cmp__(self, other):
+    def __cmp__(self, other: Any) -> NoReturn:
         raise TypeError("cannot compare sets using cmp()")
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         if isinstance(other, IdentitySet):
             return self._members == other._members
         else:
             return False
 
-    def __ne__(self, other):
+    def __ne__(self, other: Any) -> bool:
         if isinstance(other, IdentitySet):
             return self._members != other._members
         else:
             return True
 
-    def issubset(self, iterable):
+    def issubset(self, iterable: Iterable[Any]) -> bool:
         if isinstance(iterable, self.__class__):
             other = iterable
         else:
@@ -322,17 +324,17 @@ class IdentitySet:
             return False
         return True
 
-    def __le__(self, other):
+    def __le__(self, other: Any) -> bool:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.issubset(other)
 
-    def __lt__(self, other):
+    def __lt__(self, other: Any) -> bool:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return len(self) < len(other) and self.issubset(other)
 
-    def issuperset(self, iterable):
+    def issuperset(self, iterable: Iterable[Any]) -> bool:
         if isinstance(iterable, self.__class__):
             other = iterable
         else:
@@ -347,38 +349,38 @@ class IdentitySet:
             return False
         return True
 
-    def __ge__(self, other):
+    def __ge__(self, other: Any) -> bool:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.issuperset(other)
 
-    def __gt__(self, other):
+    def __gt__(self, other: Any) -> bool:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return len(self) > len(other) and self.issuperset(other)
 
-    def union(self, iterable):
+    def union(self, iterable: Iterable[Any]) -> IdentitySet:
         result = self.__class__()
         members = self._members
         result._members.update(members)
         result._members.update((id(obj), obj) for obj in iterable)
         return result
 
-    def __or__(self, other):
+    def __or__(self, other: Any) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.union(other)
 
-    def update(self, iterable):
+    def update(self, iterable: Iterable[Any]) -> None:
         self._members.update((id(obj), obj) for obj in iterable)
 
-    def __ior__(self, other):
+    def __ior__(self, other: Any) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         self.update(other)
         return self
 
-    def difference(self, iterable):
+    def difference(self, iterable: Iterable[Any]) -> IdentitySet:
         result = self.__new__(self.__class__)
         other: Collection[Any]
 
@@ -391,21 +393,21 @@ class IdentitySet:
         }
         return result
 
-    def __sub__(self, other):
+    def __sub__(self, other: IdentitySet) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.difference(other)
 
-    def difference_update(self, iterable):
+    def difference_update(self, iterable: Iterable[Any]) -> None:
         self._members = self.difference(iterable)._members
 
-    def __isub__(self, other):
+    def __isub__(self, other: IdentitySet) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         self.difference_update(other)
         return self
 
-    def intersection(self, iterable):
+    def intersection(self, iterable: Iterable[Any]) -> IdentitySet:
         result = self.__new__(self.__class__)
 
         other: Collection[Any]
@@ -419,21 +421,21 @@ class IdentitySet:
         }
         return result
 
-    def __and__(self, other):
+    def __and__(self, other: IdentitySet) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.intersection(other)
 
-    def intersection_update(self, iterable):
+    def intersection_update(self, iterable: Iterable[Any]) -> None:
         self._members = self.intersection(iterable)._members
 
-    def __iand__(self, other):
+    def __iand__(self, other: IdentitySet) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         self.intersection_update(other)
         return self
 
-    def symmetric_difference(self, iterable):
+    def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet:
         result = self.__new__(self.__class__)
         if isinstance(iterable, self.__class__):
             other = iterable._members
@@ -447,37 +449,37 @@ class IdentitySet:
         )
         return result
 
-    def __xor__(self, other):
+    def __xor__(self, other: IdentitySet) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.symmetric_difference(other)
 
-    def symmetric_difference_update(self, iterable):
+    def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
         self._members = self.symmetric_difference(iterable)._members
 
-    def __ixor__(self, other):
+    def __ixor__(self, other: IdentitySet) -> IdentitySet:
         if not isinstance(other, IdentitySet):
             return NotImplemented
         self.symmetric_difference(other)
         return self
 
-    def copy(self):
+    def copy(self) -> IdentitySet:
         result = self.__new__(self.__class__)
         result._members = self._members.copy()
         return result
 
     __copy__ = copy
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._members)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         return iter(self._members.values())
 
-    def __hash__(self):
+    def __hash__(self) -> NoReturn:
         raise TypeError("set objects are unhashable")
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(%r)" % (type(self).__name__, list(self._members.values()))
 
 
index 3e89c72bbbb005d83d3f732d7ac2a39814f559c7..2cb9c45d6bea6a1f6873ba15a4efadcc9a2f1ea6 100644 (file)
@@ -12,6 +12,7 @@ modules, classes, hierarchies, attributes, functions, and methods.
 from __future__ import annotations
 
 import collections
+import enum
 from functools import update_wrapper
 import hashlib
 import inspect
@@ -671,13 +672,13 @@ def format_argspec_init(method, grouped=True):
 
 
 def create_proxy_methods(
-    target_cls,
-    target_cls_sphinx_name,
-    proxy_cls_sphinx_name,
-    classmethods=(),
-    methods=(),
-    attributes=(),
-):
+    target_cls: Type[Any],
+    target_cls_sphinx_name: str,
+    proxy_cls_sphinx_name: str,
+    classmethods: Sequence[str] = (),
+    methods: Sequence[str] = (),
+    attributes: Sequence[str] = (),
+) -> Callable[[_T], _T]:
     """A class decorator indicating attributes should refer to a proxy
     class.
 
@@ -1539,24 +1540,50 @@ class hybridmethod(Generic[_T]):
         return self
 
 
-class _symbol(int):
+class symbol(int):
+    """A constant symbol.
+
+    >>> symbol('foo') is symbol('foo')
+    True
+    >>> symbol('foo')
+    <symbol 'foo>
+
+    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
+    advantage of symbol() is its repr().  They are also singletons.
+
+    Repeated calls of symbol('name') will all return the same instance.
+
+    In SQLAlchemy 2.0, symbol() is used for the implementation of
+    ``_FastIntFlag``, but otherwise should be mostly replaced by
+    ``enum.Enum`` and variants.
+
+
+    """
+
     name: str
 
+    symbols: Dict[str, symbol] = {}
+    _lock = threading.Lock()
+
     def __new__(
         cls,
         name: str,
         doc: Optional[str] = None,
         canonical: Optional[int] = None,
-    ) -> "_symbol":
-        """Construct a new named symbol."""
-        assert isinstance(name, str)
-        if canonical is None:
-            canonical = hash(name)
-        v = int.__new__(_symbol, canonical)
-        v.name = name
-        if doc:
-            v.__doc__ = doc
-        return v
+    ) -> symbol:
+        with cls._lock:
+            sym = cls.symbols.get(name)
+            if sym is None:
+                assert isinstance(name, str)
+                if canonical is None:
+                    canonical = hash(name)
+                sym = int.__new__(symbol, canonical)
+                sym.name = name
+                if doc:
+                    sym.__doc__ = doc
+
+                cls.symbols[name] = sym
+            return sym
 
     def __reduce__(self):
         return symbol, (self.name, "x", int(self))
@@ -1565,90 +1592,60 @@ class _symbol(int):
         return repr(self)
 
     def __repr__(self):
-        return "symbol(%r)" % self.name
+        return f"symbol({self.name!r})"
 
 
-_symbol.__name__ = "symbol"
+class _IntFlagMeta(type):
+    def __init__(
+        cls,
+        classname: str,
+        bases: Tuple[Type[Any], ...],
+        dict_: Dict[str, Any],
+        **kw: Any,
+    ) -> None:
+        items: List[symbol]
+        cls._items = items = []
+        for k, v in dict_.items():
+            if isinstance(v, int):
+                sym = symbol(k, canonical=v)
+            elif not k.startswith("_"):
+                raise TypeError("Expected integer values for IntFlag")
+            else:
+                continue
+            setattr(cls, k, sym)
+            items.append(sym)
 
+    def __iter__(self) -> Iterator[symbol]:
+        return iter(self._items)
 
-class symbol:
-    """A constant symbol.
 
-    >>> symbol('foo') is symbol('foo')
-    True
-    >>> symbol('foo')
-    <symbol 'foo>
+class _FastIntFlag(metaclass=_IntFlagMeta):
+    """An 'IntFlag' copycat that isn't slow when performing bitwise
+    operations.
 
-    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
-    advantage of symbol() is its repr().  They are also singletons.
+    the ``FastIntFlag`` class will return ``enum.IntFlag`` under TYPE_CHECKING
+    and ``_FastIntFlag`` otherwise.
 
-    Repeated calls of symbol('name') will all return the same instance.
+    """
 
-    The optional ``doc`` argument assigns to ``__doc__``.  This
-    is strictly so that Sphinx autoattr picks up the docstring we want
-    (it doesn't appear to pick up the in-module docstring if the datamember
-    is in a different module - autoattribute also blows up completely).
-    If Sphinx fixes/improves this then we would no longer need
-    ``doc`` here.
 
-    """
+if TYPE_CHECKING:
+    from enum import IntFlag
 
-    symbols: Dict[str, "_symbol"] = {}
-    _lock = threading.Lock()
+    FastIntFlag = IntFlag
+else:
+    FastIntFlag = _FastIntFlag
 
-    def __new__(  # type: ignore[misc]
-        cls,
-        name: str,
-        doc: Optional[str] = None,
-        canonical: Optional[int] = None,
-    ) -> _symbol:
-        with cls._lock:
-            sym = cls.symbols.get(name)
-            if sym is None:
-                cls.symbols[name] = sym = _symbol(name, doc, canonical)
-            return sym
 
-    @classmethod
-    def parse_user_argument(
-        cls, arg, choices, name, resolve_symbol_names=False
-    ):
-        """Given a user parameter, parse the parameter into a chosen symbol.
-
-        The user argument can be a string name that matches the name of a
-        symbol, or the symbol object itself, or any number of alternate choices
-        such as True/False/ None etc.
-
-        :param arg: the user argument.
-        :param choices: dictionary of symbol object to list of possible
-         entries.
-        :param name: name of the argument.   Used in an :class:`.ArgumentError`
-         that is raised if the parameter doesn't match any available argument.
-        :param resolve_symbol_names: include the name of each symbol as a valid
-         entry.
-
-        """
-        # note using hash lookup is tricky here because symbol's `__hash__`
-        # is its int value which we don't want included in the lookup
-        # explicitly, so we iterate and compare each.
-        for sym, choice in choices.items():
-            if arg is sym:
-                return sym
-            elif resolve_symbol_names and arg == sym.name:
-                return sym
-            elif arg in choice:
-                return sym
-
-        if arg is None:
-            return None
-
-        raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+_E = TypeVar("_E", bound=enum.Enum)
 
 
 def parse_user_argument_for_enum(
     arg: Any,
-    choices: Dict[_T, List[Any]],
+    choices: Dict[_E, List[Any]],
     name: str,
-) -> Optional[_T]:
+    resolve_symbol_names: bool = False,
+) -> Optional[_E]:
     """Given a user parameter, parse the parameter into a chosen value
     from a list of choice objects, typically Enum values.
 
@@ -1663,18 +1660,18 @@ def parse_user_argument_for_enum(
         that is raised if the parameter doesn't match any available argument.
 
     """
-    # TODO: use whatever built in thing Enum provides for this,
-    # if applicable
     for enum_value, choice in choices.items():
         if arg is enum_value:
             return enum_value
+        elif resolve_symbol_names and arg == enum_value.name:
+            return enum_value
         elif arg in choice:
             return enum_value
 
     if arg is None:
         return None
 
-    raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+    raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
 
 
 _creation_order = 1
index c861c83b3fd3f0810e812a3db3f3acbe71a95675..907c51064923cc5dbc29c3f8eba97ce58755a5fd 100644 (file)
@@ -23,6 +23,8 @@ _FN = TypeVar("_FN", bound=Callable[..., Any])
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import default as engine_default
+    from sqlalchemy.orm import session as orm_session
+    from sqlalchemy.orm import util as orm_util
     from sqlalchemy.sql import dml as sql_dml
     from sqlalchemy.sql import util as sql_util
 
index df54017da78d3c0394b2ff9b75ba39b76e31aa05..dd574f3b0fe3f943f190b2023487d5e4474e9780 100644 (file)
@@ -3,10 +3,12 @@ from __future__ import annotations
 import sys
 import typing
 from typing import Any
+from typing import Callable
 from typing import cast
 from typing import Dict
 from typing import ForwardRef
 from typing import Iterable
+from typing import Optional
 from typing import Tuple
 from typing import Type
 from typing import TypeVar
@@ -82,7 +84,9 @@ else:
 
 
 def de_stringify_annotation(
-    cls: Type[Any], annotation: Union[str, Type[Any]]
+    cls: Type[Any],
+    annotation: Union[str, Type[Any]],
+    str_cleanup_fn: Optional[Callable[[str], str]] = None,
 ) -> Union[str, Type[Any]]:
     """Resolve annotations that may be string based into real objects.
 
@@ -105,9 +109,13 @@ def de_stringify_annotation(
         annotation = cast(ForwardRef, annotation).__forward_arg__
 
     if isinstance(annotation, str):
+        if str_cleanup_fn:
+            annotation = str_cleanup_fn(annotation)
+
         base_globals: "Dict[str, Any]" = getattr(
             sys.modules.get(cls.__module__, None), "__dict__", {}
         )
+
         try:
             annotation = eval(annotation, base_globals, None)
         except NameError:
index 012f1bffa98616bb0f81130f1ea6a83baffbaab9..8f7f50715a040926de88c977ef415a85dda431a5 100644 (file)
@@ -99,6 +99,10 @@ module = [
     "sqlalchemy.engine.*",
     "sqlalchemy.pool.*",
 
+    "sqlalchemy.orm.scoping",
+    "sqlalchemy.orm.session",
+    "sqlalchemy.orm.state",
+
     # modules
     "sqlalchemy.events",
     "sqlalchemy.exc",
index e22340da682903645946a12560cd4c349269415f..c5a47ddf97a8130600cda249de692f24731017e2 100644 (file)
@@ -27,6 +27,7 @@ from sqlalchemy.testing.util import gc_collect
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.util import classproperty
 from sqlalchemy.util import compat
+from sqlalchemy.util import FastIntFlag
 from sqlalchemy.util import get_callable_argspec
 from sqlalchemy.util import langhelpers
 from sqlalchemy.util import preloaded
@@ -2300,6 +2301,20 @@ class SymbolTest(fixtures.TestBase):
         assert sym1 is not sym3
         assert sym1 != sym3
 
+    def test_fast_int_flag(self):
+        class Enum(FastIntFlag):
+            sym1 = 1
+            sym2 = 2
+
+            sym3 = 3
+
+        assert Enum.sym1 is not Enum.sym3
+        assert Enum.sym1 != Enum.sym3
+
+        assert Enum.sym1.name == "sym1"
+
+        eq_(list(Enum), [Enum.sym1, Enum.sym2, Enum.sym3])
+
     def test_pickle(self):
         sym1 = util.symbol("foo")
         sym2 = util.symbol("foo")
@@ -2338,17 +2353,19 @@ class SymbolTest(fixtures.TestBase):
         assert (sym1 | sym2) & (sym2 | sym4)
 
     def test_parser(self):
-        sym1 = util.symbol("sym1", canonical=1)
-        sym2 = util.symbol("sym2", canonical=2)
-        sym3 = util.symbol("sym3", canonical=4)
-        sym4 = util.symbol("sym4", canonical=8)
+        class MyEnum(FastIntFlag):
+            sym1 = 1
+            sym2 = 2
+            sym3 = 4
+            sym4 = 8
 
+        sym1, sym2, sym3, sym4 = tuple(MyEnum)
         lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]}
         lookup_two = {sym1: [], sym2: [True], sym3: [False]}
         lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []}
 
         is_(
-            util.symbol.parse_user_argument(
+            langhelpers.parse_user_argument_for_enum(
                 "sym2", lookup_one, "some_name", resolve_symbol_names=True
             ),
             sym2,
@@ -2357,35 +2374,41 @@ class SymbolTest(fixtures.TestBase):
         assert_raises_message(
             exc.ArgumentError,
             "Invalid value for 'some_name': 'sym2'",
-            util.symbol.parse_user_argument,
+            langhelpers.parse_user_argument_for_enum,
             "sym2",
             lookup_one,
             "some_name",
         )
         is_(
-            util.symbol.parse_user_argument(
+            langhelpers.parse_user_argument_for_enum(
                 True, lookup_one, "some_name", resolve_symbol_names=False
             ),
             sym2,
         )
 
         is_(
-            util.symbol.parse_user_argument(sym2, lookup_one, "some_name"),
+            langhelpers.parse_user_argument_for_enum(
+                sym2, lookup_one, "some_name"
+            ),
             sym2,
         )
 
         is_(
-            util.symbol.parse_user_argument(None, lookup_one, "some_name"),
+            langhelpers.parse_user_argument_for_enum(
+                None, lookup_one, "some_name"
+            ),
             sym4,
         )
 
         is_(
-            util.symbol.parse_user_argument(None, lookup_two, "some_name"),
+            langhelpers.parse_user_argument_for_enum(
+                None, lookup_two, "some_name"
+            ),
             None,
         )
 
         is_(
-            util.symbol.parse_user_argument(
+            langhelpers.parse_user_argument_for_enum(
                 "symbol2", lookup_three, "some_name"
             ),
             sym2,
@@ -2394,7 +2417,7 @@ class SymbolTest(fixtures.TestBase):
         assert_raises_message(
             exc.ArgumentError,
             "Invalid value for 'some_name': 'foo'",
-            util.symbol.parse_user_argument,
+            langhelpers.parse_user_argument_for_enum,
             "foo",
             lookup_three,
             "some_name",
diff --git a/test/ext/mypy/plain_files/session.py b/test/ext/mypy/plain_files/session.py
new file mode 100644 (file)
index 0000000..24c685e
--- /dev/null
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from typing import List
+from typing import Sequence
+
+from sqlalchemy import create_engine
+from sqlalchemy import ForeignKey
+from sqlalchemy import select
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str]
+    addresses: Mapped[List[Address]] = relationship(back_populates="user")
+
+
+class Address(Base):
+    __tablename__ = "address"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    user_id = mapped_column(ForeignKey("user.id"))
+    email: Mapped[str]
+
+    user: Mapped[User] = relationship(back_populates="addresses")
+
+
+e = create_engine("sqlite://")
+Base.metadata.create_all(e)
+
+with Session(e) as sess:
+    u1 = User(name="u1")
+    sess.add(u1)
+    sess.add_all([Address(user=u1, email="e1"), Address(user=u1, email="e2")])
+    sess.commit()
+
+with Session(e) as sess:
+    users: Sequence[User] = sess.scalars(
+        select(User), execution_options={"stream_results": False}
+    ).all()
index c7022dc31cf4a2d963675d28330bd4076a3ec8ad..f8abd686a0a21ff6ed8a06e237be1c50cac1f2b6 100644 (file)
@@ -1,9 +1,56 @@
 from __future__ import annotations
 
+from typing import List
+
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.testing import is_
 from .test_typed_mapping import MappedColumnTest  # noqa
-from .test_typed_mapping import RelationshipLHSTest  # noqa
+from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
 
 """runs the annotation-sensitive tests from test_typed_mappings while
 having ``from __future__ import annotations`` in effect.
 
 """
+
+
+class RelationshipLHSTest(_RelationshipLHSTest):
+    def test_bidirectional_literal_annotations(self, decl_base):
+        """test the 'string cleanup' function in orm/util.py, where
+        we receive a string annotation like::
+
+            "Mapped[List[B]]"
+
+        Which then fails to evaluate because we don't have "B" yet.
+        The annotation is converted on the fly to::
+
+            'Mapped[List["B"]]'
+
+        so that when we evaluated it, we get ``Mapped[List["B"]]`` and
+        can extract "B" as a string.
+
+        """
+
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+            bs: Mapped[List[B]] = relationship(back_populates="a")
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+            a: Mapped[A] = relationship(
+                back_populates="bs", primaryjoin=a_id == A.id
+            )
+
+        a1 = A(data="data")
+        b1 = B()
+        a1.bs.append(b1)
+        is_(a1, b1.a)
index d6d229f792ba90cd4fd370f11de2982a79a6f6de..058e1735b634e3423ecfd21bca5eed51605cbdc2 100644 (file)
@@ -190,6 +190,7 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
                 },
             ],
         ),
+        argnames="cols, expected",
     )
     def test_column_descriptions(self, cols, expected):
         User, Address = self.classes("User", "Address")
@@ -211,8 +212,13 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
         )
 
         stmt = select(*cols)
+
         eq_(stmt.column_descriptions, expected)
 
+        if stmt._propagate_attrs:
+            stmt = select(*cols).from_statement(stmt)
+            eq_(stmt.column_descriptions, expected)
+
     @testing.combinations(insert, update, delete, argnames="dml_construct")
     @testing.combinations(
         (
index 79b20e285a40a045a3fe1e3217a2b6b8ae863d61..4cecac0de438526ba9fa6782da632e50c3837951 100644 (file)
@@ -5,7 +5,9 @@ from unittest.mock import Mock
 import sqlalchemy as sa
 from sqlalchemy import delete
 from sqlalchemy import event
+from sqlalchemy import exc as sa_exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import insert
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
@@ -42,6 +44,7 @@ from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_not
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -236,6 +239,84 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
             ),
         )
 
+    def test_override_parameters_executesingle(self):
+        User = self.classes.User
+
+        sess = Session(testing.db, future=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def one(ctx):
+            return ctx.invoke_statement(params={"name": "overridden"})
+
+        orig_params = {"id": 18, "name": "original"}
+        with self.sql_execution_asserter() as asserter:
+            sess.execute(insert(User), orig_params)
+        asserter.assert_(
+            CompiledSQL(
+                "INSERT INTO users (id, name) VALUES (:id, :name)",
+                [{"id": 18, "name": "overridden"}],
+            )
+        )
+        # orig params weren't mutated
+        eq_(orig_params, {"id": 18, "name": "original"})
+
+    def test_override_parameters_executemany(self):
+        User = self.classes.User
+
+        sess = Session(testing.db, future=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def one(ctx):
+            return ctx.invoke_statement(
+                params=[{"name": "overridden1"}, {"name": "overridden2"}]
+            )
+
+        orig_params = [
+            {"id": 18, "name": "original1"},
+            {"id": 19, "name": "original2"},
+        ]
+        with self.sql_execution_asserter() as asserter:
+            sess.execute(insert(User), orig_params)
+        asserter.assert_(
+            CompiledSQL(
+                "INSERT INTO users (id, name) VALUES (:id, :name)",
+                [
+                    {"id": 18, "name": "overridden1"},
+                    {"id": 19, "name": "overridden2"},
+                ],
+            )
+        )
+        # orig params weren't mutated
+        eq_(
+            orig_params,
+            [{"id": 18, "name": "original1"}, {"id": 19, "name": "original2"}],
+        )
+
+    def test_override_parameters_executemany_mismatch(self):
+        User = self.classes.User
+
+        sess = Session(testing.db, future=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def one(ctx):
+            return ctx.invoke_statement(
+                params=[{"name": "overridden1"}, {"name": "overridden2"}]
+            )
+
+        orig_params = [
+            {"id": 18, "name": "original1"},
+            {"id": 19, "name": "original2"},
+            {"id": 20, "name": "original3"},
+        ]
+        with expect_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Can't apply executemany parameters to statement; number "
+            r"of parameter sets passed to Session.execute\(\) \(3\) does "
+            r"not match number of parameter sets given to "
+            r"ORMExecuteState.invoke_statement\(\) \(2\)",
+        ):
+            sess.execute(insert(User), orig_params)
+
     def test_chained_events_one(self):
 
         sess = Session(testing.db, future=True)
index a4250e375cd4a2b14cf2580becc93746d2e9922e..c006babc841c0e6e0008127b47f8ddf0bdb45b1e 100644 (file)
@@ -11,7 +11,6 @@ from sqlalchemy.orm import aliased
 from sqlalchemy.orm import attributes
 from sqlalchemy.orm import clear_mappers
 from sqlalchemy.orm import exc as orm_exc
-from sqlalchemy.orm import instrumentation
 from sqlalchemy.orm import lazyload
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import state as sa_state
@@ -410,73 +409,6 @@ class PickleTest(fixtures.MappedTest):
             u2 = loads(dumps(u1))
             eq_(u1, u2)
 
-    def test_09_pickle(self):
-        users = self.tables.users
-        self.mapper_registry.map_imperatively(User, users)
-        sess = fixture_session()
-        sess.add(User(id=1, name="ed"))
-        sess.commit()
-        sess.close()
-
-        inst = User(id=1, name="ed")
-        del inst._sa_instance_state
-
-        state = sa_state.InstanceState.__new__(sa_state.InstanceState)
-        state_09 = {
-            "class_": User,
-            "modified": False,
-            "committed_state": {},
-            "instance": inst,
-            "callables": {"name": state, "id": state},
-            "key": (User, (1,)),
-            "expired": True,
-        }
-        manager = instrumentation._SerializeManager.__new__(
-            instrumentation._SerializeManager
-        )
-        manager.class_ = User
-        state_09["manager"] = manager
-        state.__setstate__(state_09)
-        eq_(state.expired_attributes, {"name", "id"})
-
-        sess = fixture_session()
-        sess.add(inst)
-        eq_(inst.name, "ed")
-        # test identity_token expansion
-        eq_(sa.inspect(inst).key, (User, (1,), None))
-
-    def test_11_pickle(self):
-        users = self.tables.users
-        self.mapper_registry.map_imperatively(User, users)
-        sess = fixture_session()
-        u1 = User(id=1, name="ed")
-        sess.add(u1)
-        sess.commit()
-
-        sess.close()
-
-        manager = instrumentation._SerializeManager.__new__(
-            instrumentation._SerializeManager
-        )
-        manager.class_ = User
-
-        state_11 = {
-            "class_": User,
-            "modified": False,
-            "committed_state": {},
-            "instance": u1,
-            "manager": manager,
-            "key": (User, (1,)),
-            "expired_attributes": set(),
-            "expired": True,
-        }
-
-        state = sa_state.InstanceState.__new__(sa_state.InstanceState)
-        state.__setstate__(state_11)
-
-        eq_(state.identity_token, None)
-        eq_(state.identity_key, (User, (1,), None))
-
     def test_state_info_pickle(self):
         users = self.tables.users
         self.mapper_registry.map_imperatively(User, users)
index f2d7d8569a4037b9125736ef1eef608b900ab6dc..33e66d52f6b76ad46bf7e328be1ec74792f16d82 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import util
 from sqlalchemy.orm import query
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import scoped_session
@@ -158,7 +159,7 @@ class ScopedSessionTest(fixtures.MappedTest):
                     populate_existing=False,
                     with_for_update=None,
                     identity_token=None,
-                    execution_options=None,
+                    execution_options=util.EMPTY_DICT,
                 ),
             ],
         )
index eec4d878ac96929bffabf74500eb55ecf8c64fcb..ffc470972f51dc6dfc0cbd79686b5fb2db8245a3 100644 (file)
@@ -149,13 +149,27 @@ def process_class(
 
         iscoroutine = inspect.iscoroutinefunction(fn)
 
-        if spec.defaults:
-            new_defaults = tuple(
-                _repr_sym("util.EMPTY_DICT") if df is util.EMPTY_DICT else df
-                for df in spec.defaults
-            )
+        if spec.defaults or spec.kwonlydefaults:
             elem = list(spec)
-            elem[3] = tuple(new_defaults)
+
+            if spec.defaults:
+                new_defaults = tuple(
+                    _repr_sym("util.EMPTY_DICT")
+                    if df is util.EMPTY_DICT
+                    else df
+                    for df in spec.defaults
+                )
+                elem[3] = new_defaults
+
+            if spec.kwonlydefaults:
+                new_kwonlydefaults = {
+                    name: _repr_sym("util.EMPTY_DICT")
+                    if df is util.EMPTY_DICT
+                    else df
+                    for name, df in spec.kwonlydefaults.items()
+                }
+                elem[5] = new_kwonlydefaults
+
             spec = compat.FullArgSpec(*elem)
 
         caller_argspec = format_argspec_plus(spec, grouped=False)