From: Mike Bayer Date: Thu, 7 Apr 2022 16:37:23 +0000 (-0400) Subject: pep-484: session, instancestate, etc X-Git-Tag: rel_2_0_0b1~356 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aa9cd878e8249a4a758c7f968e929e92fede42a5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484: session, instancestate, etc Also adds some fixes to annotation-based mapping that have come up, as well as starts to add more pep-484 test cases Change-Id: Ia722bbbc7967a11b23b66c8084eb61df9d233fee --- diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 07783ced78..c0dc54fabe 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -456,6 +456,8 @@ from .json import JSONB from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor +from ...util import FastIntFlag +from ...util import parse_user_argument_for_enum logger = logging.getLogger("sqlalchemy.dialects.postgresql") @@ -519,13 +521,19 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): pass -EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0) -EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1) -EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2) -EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol( - "executemany_values_plus_batch", - canonical=EXECUTEMANY_BATCH | EXECUTEMANY_VALUES, -) +class ExecutemanyMode(FastIntFlag): + EXECUTEMANY_PLAIN = 0 + EXECUTEMANY_BATCH = 1 + EXECUTEMANY_VALUES = 2 + EXECUTEMANY_VALUES_PLUS_BATCH = EXECUTEMANY_BATCH | EXECUTEMANY_VALUES + + +( + EXECUTEMANY_PLAIN, + EXECUTEMANY_BATCH, + EXECUTEMANY_VALUES, + EXECUTEMANY_VALUES_PLUS_BATCH, +) = tuple(ExecutemanyMode) class PGDialect_psycopg2(_PGDialect_common_psycopg): @@ -564,7 +572,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): # Parse executemany_mode argument, allowing it to be only one of the # symbol names - self.executemany_mode = util.symbol.parse_user_argument( + self.executemany_mode = parse_user_argument_for_enum( executemany_mode, { EXECUTEMANY_PLAIN: [None], diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 5c446a91dc..8bcc7e2587 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -79,8 +79,8 @@ if typing.TYPE_CHECKING: """ -_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.immutabledict() -NO_OPTIONS: Mapping[str, Any] = util.immutabledict() +_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT +NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @@ -936,6 +936,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) ) + def _get_required_transaction(self) -> RootTransaction: + trans = self._transaction + if trans is None: + raise exc.InvalidRequestError("connection is not in a transaction") + return trans + + def _get_required_nested_transaction(self) -> NestedTransaction: + trans = self._nested_transaction + if trans is None: + raise exc.InvalidRequestError( + "connection is not in a nested transaction" + ) + return trans + def get_transaction(self) -> Optional[RootTransaction]: """Return the current root transaction in progress, if any. @@ -1220,7 +1234,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, func: FunctionElement[Any], distilled_parameters: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: _ExecuteOptionsParameter, ) -> Result: """Execute a sql.FunctionElement object.""" @@ -1232,7 +1246,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, default: ColumnDefault, distilled_parameters: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: _ExecuteOptionsParameter, ) -> Any: """Execute a schema.ColumnDefault object.""" @@ -1291,7 +1305,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, ddl: DDLElement, distilled_parameters: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: _ExecuteOptionsParameter, ) -> Result: """Execute a schema.DDL object.""" @@ -1388,7 +1402,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, elem: Executable, distilled_parameters: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: _ExecuteOptionsParameter, ) -> Result: """Execute a sql.ClauseElement object.""" @@ -1511,7 +1525,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, statement: str, parameters: Optional[_DBAPIAnyExecuteParams] = None, - execution_options: Optional[_ExecuteOptions] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> Result: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 213485cc92..529b2ca73b 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -10,11 +10,13 @@ from __future__ import annotations import typing from typing import Any from typing import Callable +from typing import Optional from typing import TypeVar from .. import exc from .. import util from ..util._has_cy import HAS_CYEXTENSION +from ..util.typing import Protocol if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import _distill_params_20 as _distill_params_20 @@ -49,6 +51,10 @@ def connection_memoize(key: str) -> Callable[[_C], _C]: return decorated # type: ignore[return-value] +class _TConsSubject(Protocol): + _trans_context_manager: Optional[TransactionalContext] + + class TransactionalContext: """Apply Python context manager behavior to transaction objects. @@ -59,6 +65,8 @@ class TransactionalContext: __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__") + _trans_subject: Optional[_TConsSubject] + def _transaction_is_active(self) -> bool: raise NotImplementedError() @@ -82,7 +90,7 @@ class TransactionalContext: """ raise NotImplementedError() - def _get_subject(self) -> Any: + def _get_subject(self) -> _TConsSubject: raise NotImplementedError() def commit(self) -> None: @@ -95,7 +103,7 @@ class TransactionalContext: raise NotImplementedError() @classmethod - def _trans_ctx_check(cls, subject: Any) -> None: + def _trans_ctx_check(cls, subject: _TConsSubject) -> None: trans_context = subject._trans_context_manager if trans_context: if not trans_context._transaction_is_active(): diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index e1c9496813..7e6d2a3979 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -13,6 +13,7 @@ from .api import listen as listen from .api import listens_for as listens_for from .api import NO_RETVAL as NO_RETVAL from .api import remove as remove +from .attr import _InstanceLevelDispatch as _InstanceLevelDispatch from .attr import RefCollection as RefCollection from .base import _Dispatch as _Dispatch from .base import _DispatchCommon as _DispatchCommon diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index cc57251255..0503076aaf 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -116,7 +116,7 @@ class async_scoped_session(ScopedSessionMixin): # code within this block is **programmatically, # statically generated** by tools/generate_proxy_methods.py - def __contains__(self, instance): + def __contains__(self, instance: object) -> bool: r"""Return True if the instance is associated with this session. .. container:: class_bases @@ -138,7 +138,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.__contains__(instance) - def __iter__(self): + def __iter__(self) -> Iterator[object]: r"""Iterate over all pending or persistent instances within this Session. @@ -156,7 +156,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.__iter__() - def add(self, instance: Any, _warn: bool = True) -> None: + def add(self, instance: object, _warn: bool = True) -> None: r"""Place an object in the ``Session``. .. container:: class_bases @@ -181,7 +181,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.add(instance, _warn=_warn) - def add_all(self, instances): + def add_all(self, instances: Iterable[object]) -> None: r"""Add the given collection of instances to this ``Session``. .. container:: class_bases @@ -374,7 +374,9 @@ class async_scoped_session(ScopedSessionMixin): **kw, ) - def expire(self, instance, attribute_names=None): + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: r"""Expire the attributes on an instance. .. container:: class_bases @@ -426,7 +428,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.expire(instance, attribute_names=attribute_names) - def expire_all(self): + def expire_all(self) -> None: r"""Expires all persistent instances within this Session. .. container:: class_bases @@ -473,7 +475,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.expire_all() - def expunge(self, instance): + def expunge(self, instance: object) -> None: r"""Remove the `instance` from this ``Session``. .. container:: class_bases @@ -495,7 +497,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.expunge(instance) - def expunge_all(self): + def expunge_all(self) -> None: r"""Remove all object instances from this ``Session``. .. container:: class_bases @@ -652,7 +654,9 @@ class async_scoped_session(ScopedSessionMixin): mapper=mapper, clause=clause, bind=bind, **kw ) - def is_modified(self, instance, include_collections=True): + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: r"""Return ``True`` if the given instance has locally modified attributes. @@ -1168,7 +1172,7 @@ class async_scoped_session(ScopedSessionMixin): return await AsyncSession.close_all() @classmethod - def object_session(cls, instance: Any) -> "Session": + def object_session(cls, instance: object) -> Optional[Session]: r"""Return the :class:`.Session` to which an object belongs. .. container:: class_bases @@ -1192,13 +1196,13 @@ class async_scoped_session(ScopedSessionMixin): @classmethod def identity_key( cls, - class_=None, - ident=None, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, *, - instance=None, - row=None, - identity_token=None, - ) -> _IdentityKeyType: + instance: Optional[Any] = None, + row: Optional[Row] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: r"""Return an identity key. .. container:: class_bases diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 0bd2530b20..769fe05bdb 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -640,7 +640,7 @@ class AsyncSession(ReversibleProxy): # code within this block is **programmatically, # statically generated** by tools/generate_proxy_methods.py - def __contains__(self, instance): + def __contains__(self, instance: object) -> bool: r"""Return True if the instance is associated with this session. .. container:: class_bases @@ -656,7 +656,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.__contains__(instance) - def __iter__(self): + def __iter__(self) -> Iterator[object]: r"""Iterate over all pending or persistent instances within this Session. @@ -670,7 +670,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.__iter__() - def add(self, instance: Any, _warn: bool = True) -> None: + def add(self, instance: object, _warn: bool = True) -> None: r"""Place an object in the ``Session``. .. container:: class_bases @@ -689,7 +689,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.add(instance, _warn=_warn) - def add_all(self, instances): + def add_all(self, instances: Iterable[object]) -> None: r"""Add the given collection of instances to this ``Session``. .. container:: class_bases @@ -701,7 +701,9 @@ class AsyncSession(ReversibleProxy): return self._proxied.add_all(instances) - def expire(self, instance, attribute_names=None): + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: r"""Expire the attributes on an instance. .. container:: class_bases @@ -747,7 +749,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.expire(instance, attribute_names=attribute_names) - def expire_all(self): + def expire_all(self) -> None: r"""Expires all persistent instances within this Session. .. container:: class_bases @@ -788,7 +790,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.expire_all() - def expunge(self, instance): + def expunge(self, instance: object) -> None: r"""Remove the `instance` from this ``Session``. .. container:: class_bases @@ -804,7 +806,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.expunge(instance) - def expunge_all(self): + def expunge_all(self) -> None: r"""Remove all object instances from this ``Session``. .. container:: class_bases @@ -820,7 +822,9 @@ class AsyncSession(ReversibleProxy): return self._proxied.expunge_all() - def is_modified(self, instance, include_collections=True): + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: r"""Return ``True`` if the given instance has locally modified attributes. @@ -882,7 +886,7 @@ class AsyncSession(ReversibleProxy): instance, include_collections=include_collections ) - def in_transaction(self): + def in_transaction(self) -> bool: r"""Return True if this :class:`_orm.Session` has begun a transaction. .. container:: class_bases @@ -902,7 +906,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.in_transaction() - def in_nested_transaction(self): + def in_nested_transaction(self) -> bool: r"""Return True if this :class:`_orm.Session` has begun a nested transaction, e.g. SAVEPOINT. @@ -978,7 +982,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.new @property - def identity_map(self) -> identity.IdentityMap: + def identity_map(self) -> IdentityMap: r"""Proxy for the :attr:`_orm.Session.identity_map` attribute on behalf of the :class:`_asyncio.AsyncSession` class. @@ -987,7 +991,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.identity_map @identity_map.setter - def identity_map(self, attr: identity.IdentityMap) -> None: + def identity_map(self, attr: IdentityMap) -> None: self._proxied.identity_map = attr @property @@ -1090,7 +1094,7 @@ class AsyncSession(ReversibleProxy): return self._proxied.info @classmethod - def object_session(cls, instance: Any) -> "Session": + def object_session(cls, instance: object) -> Optional[Session]: r"""Return the :class:`.Session` to which an object belongs. .. container:: class_bases @@ -1108,13 +1112,13 @@ class AsyncSession(ReversibleProxy): @classmethod def identity_key( cls, - class_=None, - ident=None, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, *, - instance=None, - row=None, - identity_token=None, - ) -> _IdentityKeyType: + instance: Optional[Any] = None, + row: Optional[Row] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: r"""Return an identity key. .. container:: class_bases @@ -1243,7 +1247,7 @@ def async_object_session(instance): return None -def async_session(session): +def async_session(session: Session) -> AsyncSession: """Return the :class:`_asyncio.AsyncSession` which is proxying the given :class:`_orm.Session` object, if any. diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 5ca8b03dd6..a0c7905d84 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -1303,7 +1303,9 @@ class Comparator(interfaces.PropComparator[_T]): def property(self) -> Any: return None - def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]: + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> Comparator[_T]: # interesting.... return self diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index c5b0affd2e..6e8a0e7713 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -460,7 +460,7 @@ def composite( class_: Type[_T], *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], **kwargs: Any, -) -> "Composite[_T]": +) -> Composite[_T]: ... @@ -468,7 +468,7 @@ def composite( def composite( *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], **kwargs: Any, -) -> "Composite[Any]": +) -> Composite[Any]: ... @@ -476,7 +476,7 @@ def composite( class_: Any = None, *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], **kwargs: Any, -) -> "Composite[Any]": +) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. See the mapping documentation section :ref:`mapper_composite` for a diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index e9ddf6d158..4250cdbe1f 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -1,11 +1,69 @@ from __future__ import annotations +import operator +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union +from sqlalchemy.orm.interfaces import UserDefinedOption +from ..util.typing import Protocol +from ..util.typing import TypeGuard if TYPE_CHECKING: + from .attributes import AttributeImpl + from .attributes import CollectionAttributeImpl + from .base import PassiveFlag + from .descriptor_props import _CompositeClassProto from .mapper import Mapper + from .state import InstanceState + from .util import AliasedClass from .util import AliasedInsp + from ..sql.base import ExecutableOption -_EntityType = Union[Mapper, AliasedInsp] +_T = TypeVar("_T", bound=Any) + +_O = TypeVar("_O", bound=Any) +"""The 'ORM mapped object' type. +I would have preferred this were bound=object however it seems +to not travel in all situations when defined in that way. +""" + +_InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"] + +_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"] + + +_InstanceDict = Dict[str, Any] + +_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]] + + +class _LoaderCallable(Protocol): + def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any: + ... + + +def is_user_defined_option( + opt: ExecutableOption, +) -> TypeGuard[UserDefinedOption]: + return not opt._is_core and opt._is_user_defined # type: ignore + + +def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]: + return hasattr(obj, "__composite_values__") + + +if TYPE_CHECKING: + + def is_collection_impl( + impl: AttributeImpl, + ) -> TypeGuard[CollectionAttributeImpl]: + ... + +else: + is_collection_impl = operator.attrgetter("collection") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 3d34927105..33ce96a192 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -18,12 +18,17 @@ from __future__ import annotations from collections import namedtuple import operator -import typing from typing import Any from typing import Callable +from typing import Collection +from typing import Dict from typing import List from typing import NamedTuple +from typing import Optional +from typing import overload from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -35,8 +40,8 @@ from .base import ATTR_WAS_SET from .base import CALLABLES_OK from .base import DEFERRED_HISTORY_LOAD from .base import INIT_OK -from .base import instance_dict -from .base import instance_state +from .base import instance_dict as instance_dict +from .base import instance_state as instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED from .base import manager_of_class @@ -55,6 +60,7 @@ from .base import PASSIVE_NO_RESULT from .base import PASSIVE_OFF from .base import PASSIVE_ONLY_PERSISTENT from .base import PASSIVE_RETURN_NO_VALUE +from .base import PassiveFlag from .base import RELATED_OBJECT_OK # noqa from .base import SQL_OK # noqa from .base import state_str @@ -67,7 +73,8 @@ from ..sql import roles from ..sql import traversals from ..sql import visitors -if typing.TYPE_CHECKING: +if TYPE_CHECKING: + from .state import InstanceState from ..sql.dml import _DMLColumnElement from ..sql.elements import ColumnElement from ..sql.elements import SQLCoreOperations @@ -115,6 +122,8 @@ class QueryableAttribute( is_attribute = True + impl: AttributeImpl + # PropComparator has a __visit_name__ to participate within # traversals. Disambiguate the attribute vs. a comparator. __visit_name__ = "orm_instrumented_attribute" @@ -402,7 +411,19 @@ class InstrumentedAttribute(QueryableAttribute[_T]): def __delete__(self, instance): self.impl.delete(instance_state(instance), instance_dict(instance)) - def __get__(self, instance, owner): + @overload + def __get__( + self, instance: None, owner: Type[Any] + ) -> InstrumentedAttribute: + ... + + @overload + def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]: + ... + + def __get__( + self, instance: Optional[object], owner: Type[Any] + ) -> Union[InstrumentedAttribute, Optional[_T]]: if instance is None: return self @@ -636,6 +657,8 @@ Event = AttributeEvent class AttributeImpl: """internal implementation for instrumented attributes.""" + collection: bool + def __init__( self, class_, @@ -811,7 +834,12 @@ class AttributeImpl: state.parents[id_] = False - def get_history(self, state, dict_, passive=PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive=PASSIVE_OFF, + ) -> History: raise NotImplementedError() def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): @@ -989,7 +1017,12 @@ class ScalarAttributeImpl(AttributeImpl): ): raise AttributeError("%s object does not have a value" % self) - def get_history(self, state, dict_, passive=PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: Dict[str, Any], + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: if self.key in dict_: return History.from_scalar_attribute(self, state, dict_[self.key]) elif self.key in state.committed_state: @@ -1005,13 +1038,13 @@ class ScalarAttributeImpl(AttributeImpl): def set( self, - state, - dict_, - value, - initiator, - passive=PASSIVE_OFF, - check_old=None, - pop=False, + state: InstanceState[Any], + dict_: Dict[str, Any], + value: Any, + initiator: Optional[Event], + passive: PassiveFlag = PASSIVE_OFF, + check_old: Optional[object] = None, + pop: bool = False, ): if self.dispatch._active_history: old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) @@ -1536,7 +1569,7 @@ class CollectionAttributeImpl(AttributeImpl): if fire_event: self.dispatch.dispose_collection(state, collection, adapter) - def _invalidate_collection(self, collection): + def _invalidate_collection(self, collection: Collection) -> None: adapter = getattr(collection, "_sa_adapter") adapter.invalidated = True diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index a1a9442dca..d8f57e1498 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -20,8 +20,8 @@ from typing import Dict from typing import Generic from typing import Optional from typing import overload -from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -30,6 +30,7 @@ from .. import exc as sa_exc from .. import inspection from .. import util from ..sql.elements import SQLCoreOperations +from ..util import FastIntFlag from ..util.langhelpers import TypingOnly from ..util.typing import Concatenate from ..util.typing import Literal @@ -37,159 +38,147 @@ from ..util.typing import ParamSpec from ..util.typing import Self if typing.TYPE_CHECKING: + from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute from .mapper import Mapper + from .state import InstanceState _T = TypeVar("_T", bound=Any) +_O = TypeVar("_O", bound=object) -_IdentityKeyType = Tuple[type, Tuple[Any, ...], Optional[str]] - -PASSIVE_NO_RESULT = util.symbol( - "PASSIVE_NO_RESULT", +class LoaderCallableStatus(Enum): + PASSIVE_NO_RESULT = 0 """Symbol returned by a loader callable or other attribute/history retrieval operation when a value could not be determined, based on loader callable flags. - """, -) + """ -PASSIVE_CLASS_MISMATCH = util.symbol( - "PASSIVE_CLASS_MISMATCH", + PASSIVE_CLASS_MISMATCH = 1 """Symbol indicating that an object is locally present for a given primary key identity but it is not of the requested class. The - return value is therefore None and no SQL should be emitted.""", -) + return value is therefore None and no SQL should be emitted.""" -ATTR_WAS_SET = util.symbol( - "ATTR_WAS_SET", + ATTR_WAS_SET = 2 """Symbol returned by a loader callable to indicate the retrieved value, or values, were assigned to their attributes on the target object. - """, -) + """ -ATTR_EMPTY = util.symbol( - "ATTR_EMPTY", + ATTR_EMPTY = 3 """Symbol used internally to indicate an attribute had no callable.""", -) -NO_VALUE = util.symbol( - "NO_VALUE", + NO_VALUE = 4 """Symbol which may be placed as the 'previous' value of an attribute, indicating no value was loaded for an attribute when it was modified, and flags indicated we were not to load it. - """, -) + """ + + NEVER_SET = NO_VALUE + """ + Synonymous with NO_VALUE + + .. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE + + """ + + +( + PASSIVE_NO_RESULT, + PASSIVE_CLASS_MISMATCH, + ATTR_WAS_SET, + ATTR_EMPTY, + NO_VALUE, +) = tuple(LoaderCallableStatus) + NEVER_SET = NO_VALUE -""" -Synonymous with NO_VALUE -.. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE -""" -NO_CHANGE = util.symbol( - "NO_CHANGE", +class PassiveFlag(FastIntFlag): + """Bitflag interface that passes options onto loader callables""" + + NO_CHANGE = 0 """No callables or SQL should be emitted on attribute access and no state should change - """, - canonical=0, -) + """ -CALLABLES_OK = util.symbol( - "CALLABLES_OK", + CALLABLES_OK = 1 """Loader callables can be fired off if a value is not present. - """, - canonical=1, -) + """ -SQL_OK = util.symbol( - "SQL_OK", - """Loader callables can emit SQL at least on scalar value attributes.""", - canonical=2, -) + SQL_OK = 2 + """Loader callables can emit SQL at least on scalar value attributes.""" -RELATED_OBJECT_OK = util.symbol( - "RELATED_OBJECT_OK", + RELATED_OBJECT_OK = 4 """Callables can use SQL to load related objects as well as scalar value attributes. - """, - canonical=4, -) + """ -INIT_OK = util.symbol( - "INIT_OK", + INIT_OK = 8 """Attributes should be initialized with a blank value (None or an empty collection) upon get, if no other value can be obtained. - """, - canonical=8, -) + """ -NON_PERSISTENT_OK = util.symbol( - "NON_PERSISTENT_OK", - """Callables can be emitted if the parent is not persistent.""", - canonical=16, -) + NON_PERSISTENT_OK = 16 + """Callables can be emitted if the parent is not persistent.""" -LOAD_AGAINST_COMMITTED = util.symbol( - "LOAD_AGAINST_COMMITTED", + LOAD_AGAINST_COMMITTED = 32 """Callables should use committed values as primary/foreign keys during a load. - """, - canonical=32, -) + """ -NO_AUTOFLUSH = util.symbol( - "NO_AUTOFLUSH", + NO_AUTOFLUSH = 64 """Loader callables should disable autoflush.""", - canonical=64, -) -NO_RAISE = util.symbol( - "NO_RAISE", - """Loader callables should not raise any assertions""", - canonical=128, -) + NO_RAISE = 128 + """Loader callables should not raise any assertions""" -DEFERRED_HISTORY_LOAD = util.symbol( - "DEFERRED_HISTORY_LOAD", - """indicates special load of the previous value of an attribute""", - canonical=256, -) + DEFERRED_HISTORY_LOAD = 256 + """indicates special load of the previous value of an attribute""" -# pre-packaged sets of flags used as inputs -PASSIVE_OFF = util.symbol( - "PASSIVE_OFF", - "Callables can be emitted in all cases.", - canonical=( + # pre-packaged sets of flags used as inputs + PASSIVE_OFF = ( RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK - ), -) -PASSIVE_RETURN_NO_VALUE = util.symbol( - "PASSIVE_RETURN_NO_VALUE", - """PASSIVE_OFF ^ INIT_OK""", - canonical=PASSIVE_OFF ^ INIT_OK, -) -PASSIVE_NO_INITIALIZE = util.symbol( - "PASSIVE_NO_INITIALIZE", - "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK", - canonical=PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK, -) -PASSIVE_NO_FETCH = util.symbol( - "PASSIVE_NO_FETCH", "PASSIVE_OFF ^ SQL_OK", canonical=PASSIVE_OFF ^ SQL_OK -) -PASSIVE_NO_FETCH_RELATED = util.symbol( - "PASSIVE_NO_FETCH_RELATED", - "PASSIVE_OFF ^ RELATED_OBJECT_OK", - canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK, -) -PASSIVE_ONLY_PERSISTENT = util.symbol( - "PASSIVE_ONLY_PERSISTENT", - "PASSIVE_OFF ^ NON_PERSISTENT_OK", - canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK, -) + ) + "Callables can be emitted in all cases." + + PASSIVE_RETURN_NO_VALUE = PASSIVE_OFF ^ INIT_OK + """PASSIVE_OFF ^ INIT_OK""" + + PASSIVE_NO_INITIALIZE = PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK + "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK" + + PASSIVE_NO_FETCH = PASSIVE_OFF ^ SQL_OK + "PASSIVE_OFF ^ SQL_OK" + + PASSIVE_NO_FETCH_RELATED = PASSIVE_OFF ^ RELATED_OBJECT_OK + "PASSIVE_OFF ^ RELATED_OBJECT_OK" + + PASSIVE_ONLY_PERSISTENT = PASSIVE_OFF ^ NON_PERSISTENT_OK + "PASSIVE_OFF ^ NON_PERSISTENT_OK" + + +( + NO_CHANGE, + CALLABLES_OK, + SQL_OK, + RELATED_OBJECT_OK, + INIT_OK, + NON_PERSISTENT_OK, + LOAD_AGAINST_COMMITTED, + NO_AUTOFLUSH, + NO_RAISE, + DEFERRED_HISTORY_LOAD, + PASSIVE_OFF, + PASSIVE_RETURN_NO_VALUE, + PASSIVE_NO_INITIALIZE, + PASSIVE_NO_FETCH, + PASSIVE_NO_FETCH_RELATED, + PASSIVE_ONLY_PERSISTENT, +) = tuple(PassiveFlag) DEFAULT_MANAGER_ATTR = "_sa_class_manager" DEFAULT_STATE_ATTR = "_sa_instance_state" @@ -285,18 +274,27 @@ def manager_of_class(cls): return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None) -instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) +if TYPE_CHECKING: + + def instance_state(instance: _O) -> InstanceState[_O]: + ... + + def instance_dict(instance: object) -> Dict[str, Any]: + ... -instance_dict = operator.attrgetter("__dict__") +else: + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) + instance_dict = operator.attrgetter("__dict__") -def instance_str(instance): + +def instance_str(instance: object) -> str: """Return a string describing an instance.""" return state_str(instance_state(instance)) -def state_str(state): +def state_str(state: InstanceState[Any]) -> str: """Return a string describing an instance via its InstanceState.""" if state is None: @@ -305,7 +303,7 @@ def state_str(state): return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj())) -def state_class_str(state): +def state_class_str(state: InstanceState[Any]) -> str: """Return a string describing an instance's class via its InstanceState. """ @@ -316,15 +314,15 @@ def state_class_str(state): return "<%s>" % (state.class_.__name__,) -def attribute_str(instance, attribute): +def attribute_str(instance: object, attribute: str) -> str: return instance_str(instance) + "." + attribute -def state_attribute_str(state, attribute): +def state_attribute_str(state: InstanceState[Any], attribute: str) -> str: return state_str(state) + "." + attribute -def object_mapper(instance): +def object_mapper(instance: _T) -> Mapper[_T]: """Given an object, return the primary Mapper associated with the object instance. @@ -343,7 +341,7 @@ def object_mapper(instance): return object_state(instance).mapper -def object_state(instance): +def object_state(instance: _T) -> InstanceState[_T]: """Given an object, return the :class:`.InstanceState` associated with the object. @@ -368,14 +366,14 @@ def object_state(instance): @inspection._inspects(object) -def _inspect_mapped_object(instance): +def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]: try: return instance_state(instance) except (exc.UnmappedClassError,) + exc.NO_STATE: return None -def _class_to_mapper(class_or_mapper): +def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]: insp = inspection.inspect(class_or_mapper, False) if insp is not None: return insp.mapper @@ -383,7 +381,9 @@ def _class_to_mapper(class_or_mapper): raise exc.UnmappedClassError(class_or_mapper) -def _mapper_or_none(entity): +def _mapper_or_none( + entity: Union[_T, _InternalEntityType[_T]] +) -> Optional[Mapper[_T]]: """Return the :class:`_orm.Mapper` for the given class or None if the class is not mapped. """ @@ -579,6 +579,8 @@ class InspectionAttrInfo(InspectionAttr): """ + __slots__ = () + @util.memoized_property def info(self) -> Dict[Any, Any]: """Info dictionary associated with the object, allowing user-defined diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index edd3fb56bf..419da65f7f 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -9,6 +9,7 @@ from __future__ import annotations import itertools from typing import Any +from typing import cast from typing import Dict from typing import List from typing import Optional @@ -61,7 +62,7 @@ from ..sql.selectable import SelectState from ..sql.visitors import InternalTraversal if TYPE_CHECKING: - from ._typing import _EntityType + from ._typing import _InternalEntityType from ..sql.compiler import _CompilerStackEntry from ..sql.dml import _DMLTableElement from ..sql.elements import ColumnElement @@ -213,7 +214,7 @@ class ORMCompileState(CompileState): statement: Union[Select, FromStatement] select_statement: Union[Select, FromStatement] _entities: List[_QueryEntity] - _polymorphic_adapters: Dict[_EntityType, ORMAdapter] + _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] compile_options: Union[ Type[default_compile_options], default_compile_options ] @@ -630,6 +631,24 @@ class FromStatement(GroupedElement, ReturnsRows, Executable): return compiler.process(compile_state.statement, **kw) + @property + def column_descriptions(self): + """Return a :term:`plugin-enabled` 'column descriptions' structure + referring to the columns which are SELECTed by this statement. + + See the section :ref:`queryguide_inspection` for an overview + of this feature. + + .. seealso:: + + :ref:`queryguide_inspection` - ORM background + + """ + meth = cast( + ORMSelectCompileState, SelectState.get_plugin_class(self) + ).get_column_descriptions + return meth(self) + def _ensure_disambiguated_names(self): return self diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index dd3931faf6..32c69a7446 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -19,9 +19,11 @@ import operator import typing from typing import Any from typing import Callable +from typing import Dict from typing import List from typing import Optional from typing import Tuple +from typing import Type from typing import TypeVar from typing import Union @@ -41,14 +43,23 @@ from .. import sql from .. import util from ..sql import expression from ..sql import operators +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from .attributes import InstrumentedAttribute from .properties import MappedColumn + from ..sql._typing import _ColumnExpressionArgument + from ..sql.schema import Column _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) +class _CompositeClassProto(Protocol): + def __composite_values__(self) -> Tuple[Any, ...]: + ... + + class DescriptorProperty(MapperProperty[_T]): """:class:`.MapperProperty` which proxies access to a user-defined descriptor.""" @@ -110,6 +121,11 @@ class DescriptorProperty(MapperProperty[_T]): mapper.class_manager.instrument_attribute(self.key, proxy_attr) +_CompositeAttrType = Union[ + str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]" +] + + class Composite( _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T] ): @@ -129,12 +145,21 @@ class Composite( """ - composite_class: Union[type, Callable[..., type]] - attrs: Tuple[ - Union[sql.ColumnElement[Any], "MappedColumn", str, Mapped[Any]], ... + composite_class: Union[ + Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]] ] + attrs: Tuple[_CompositeAttrType, ...] - def __init__(self, class_=None, *attrs, **kwargs): + def __init__( + self, + class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None, + *attrs: _CompositeAttrType, + active_history: bool = False, + deferred: bool = False, + group: Optional[str] = None, + comparator_factory: Optional[Type[Comparator]] = None, + info: Optional[Dict[Any, Any]] = None, + ): super().__init__() if isinstance(class_, (Mapped, str, sql.ColumnElement)): @@ -145,15 +170,17 @@ class Composite( self.composite_class = class_ self.attrs = attrs - self.active_history = kwargs.get("active_history", False) - self.deferred = kwargs.get("deferred", False) - self.group = kwargs.get("group", None) - self.comparator_factory = kwargs.pop( - "comparator_factory", self.__class__.Comparator + self.active_history = active_history + self.deferred = deferred + self.group = group + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator ) self._generated_composite_accessor = None - if "info" in kwargs: - self.info = kwargs.pop("info") + if info is not None: + self.info = info util.set_creation_order(self) self._create_descriptor() @@ -162,7 +189,9 @@ class Composite( super().instrument_class(mapper) self._setup_event_handlers() - def _composite_values_from_instance(self, value): + def _composite_values_from_instance( + self, value: _CompositeClassProto + ) -> Tuple[Any, ...]: if self._generated_composite_accessor: return self._generated_composite_accessor(value) else: diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index f70ea78373..00829ecbb7 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -9,6 +9,10 @@ from __future__ import annotations +from typing import Any +from typing import Optional +from typing import Type + from .. import exc as sa_exc from .. import util from ..exc import MultipleResultsFound # noqa @@ -73,7 +77,7 @@ class UnmappedInstanceError(UnmappedError): """An mapping operation was requested for an unknown instance.""" @util.preload_module("sqlalchemy.orm.base") - def __init__(self, obj, msg=None): + def __init__(self, obj: object, msg: Optional[str] = None): base = util.preloaded.orm_base if not msg: @@ -87,7 +91,7 @@ class UnmappedInstanceError(UnmappedError): "was called." % (name, name) ) except UnmappedClassError: - msg = _default_unmapped(type(obj)) + msg = f"Class '{_safe_cls_name(type(obj))}' is not mapped" if isinstance(obj, type): msg += ( "; was a class (%s) supplied where an instance was " @@ -102,12 +106,12 @@ class UnmappedInstanceError(UnmappedError): class UnmappedClassError(UnmappedError): """An mapping operation was requested for an unknown class.""" - def __init__(self, cls, msg=None): + def __init__(self, cls: Type[object], msg: Optional[str] = None): if not msg: msg = _default_unmapped(cls) UnmappedError.__init__(self, msg) - def __reduce__(self): + def __reduce__(self) -> Any: return self.__class__, (None, self.args[0]) @@ -194,7 +198,7 @@ def _safe_cls_name(cls): @util.preload_module("sqlalchemy.orm.base") -def _default_unmapped(cls): +def _default_unmapped(cls) -> Optional[str]: base = util.preloaded.orm_base try: @@ -204,4 +208,6 @@ def _default_unmapped(cls): name = _safe_cls_name(cls) if not mappers: - return "Class '%s' is not mapped" % name + return f"Class '{name}' is not mapped" + else: + return None diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 3caf0b22fb..d13265c560 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -7,96 +7,126 @@ from __future__ import annotations +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import TYPE_CHECKING +from typing import TypeVar import weakref from . import util as orm_util from .. import exc as sa_exc +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from .state import InstanceState + + +_T = TypeVar("_T", bound=Any) + +_O = TypeVar("_O", bound=object) + class IdentityMap: - def __init__(self): + _wr: weakref.ref[IdentityMap] + + _dict: Dict[_IdentityKeyType[Any], Any] + _modified: Set[InstanceState[Any]] + + def __init__(self) -> None: self._dict = {} self._modified = set() self._wr = weakref.ref(self) - def _kill(self): - self._add_unpresent = _killed + def _kill(self) -> None: + self._add_unpresent = _killed # type: ignore + + def all_states(self) -> List[InstanceState[Any]]: + raise NotImplementedError() + + def contains_state(self, state: InstanceState[Any]) -> bool: + raise NotImplementedError() + + def __contains__(self, key: _IdentityKeyType[Any]) -> bool: + raise NotImplementedError() + + def safe_discard(self, state: InstanceState[Any]) -> None: + raise NotImplementedError() + + def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: + raise NotImplementedError() + + def get( + self, key: _IdentityKeyType[_O], default: Optional[_O] = None + ) -> Optional[_O]: + raise NotImplementedError() def keys(self): return self._dict.keys() - def replace(self, state): + def values(self) -> Iterable[object]: + raise NotImplementedError() + + def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + + def add(self, state: InstanceState[Any]) -> bool: raise NotImplementedError() - def add(self, state): + def _fast_discard(self, state: InstanceState[Any]) -> None: raise NotImplementedError() - def _add_unpresent(self, state, key): + def _add_unpresent( + self, state: InstanceState[Any], key: _IdentityKeyType[Any] + ) -> None: """optional inlined form of add() which can assume item isn't present in the map""" self.add(state) - def update(self, dict_): - raise NotImplementedError("IdentityMap uses add() to insert data") - - def clear(self): - raise NotImplementedError("IdentityMap uses remove() to remove data") - - def _manage_incoming_state(self, state): + def _manage_incoming_state(self, state: InstanceState[Any]) -> None: state._instance_dict = self._wr if state.modified: self._modified.add(state) - def _manage_removed_state(self, state): + def _manage_removed_state(self, state: InstanceState[Any]) -> None: del state._instance_dict if state.modified: self._modified.discard(state) - def _dirty_states(self): + def _dirty_states(self) -> Set[InstanceState[Any]]: return self._modified - def check_modified(self): + def check_modified(self) -> bool: """return True if any InstanceStates present have been marked as 'modified'. """ return bool(self._modified) - def has_key(self, key): + def has_key(self, key: _IdentityKeyType[Any]) -> bool: return key in self - def popitem(self): - raise NotImplementedError("IdentityMap uses remove() to remove data") - - def pop(self, key, *args): - raise NotImplementedError("IdentityMap uses remove() to remove data") - - def setdefault(self, key, default=None): - raise NotImplementedError("IdentityMap uses add() to insert data") - - def __len__(self): + def __len__(self) -> int: return len(self._dict) - def copy(self): - raise NotImplementedError() - - def __setitem__(self, key, value): - raise NotImplementedError("IdentityMap uses add() to insert data") - - def __delitem__(self, key): - raise NotImplementedError("IdentityMap uses remove() to remove data") - class WeakInstanceDict(IdentityMap): - def __getitem__(self, key): + _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]] + + def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: state = self._dict[key] o = state.obj() if o is None: raise KeyError(key) return o - def __contains__(self, key): + def __contains__(self, key: _IdentityKeyType[Any]) -> bool: try: if key in self._dict: state = self._dict[key] @@ -108,7 +138,7 @@ class WeakInstanceDict(IdentityMap): else: return o is not None - def contains_state(self, state): + def contains_state(self, state: InstanceState[Any]) -> bool: if state.key in self._dict: try: return self._dict[state.key] is state @@ -117,13 +147,15 @@ class WeakInstanceDict(IdentityMap): else: return False - def replace(self, state): + def replace( + self, state: InstanceState[Any] + ) -> Optional[InstanceState[Any]]: if state.key in self._dict: try: existing = self._dict[state.key] except KeyError: # catch gc removed the key after we just checked for it - pass + existing = None else: if existing is not state: self._manage_removed_state(existing) @@ -136,7 +168,7 @@ class WeakInstanceDict(IdentityMap): self._manage_incoming_state(state) return existing - def add(self, state): + def add(self, state: InstanceState[Any]) -> bool: key = state.key # inline of self.__contains__ if key in self._dict: @@ -161,12 +193,16 @@ class WeakInstanceDict(IdentityMap): self._manage_incoming_state(state) return True - def _add_unpresent(self, state, key): + def _add_unpresent( + self, state: InstanceState[Any], key: _IdentityKeyType[Any] + ) -> None: # inlined form of add() called by loading.py self._dict[key] = state state._instance_dict = self._wr - def get(self, key, default=None): + def get( + self, key: _IdentityKeyType[_O], default: Optional[_O] = None + ) -> Optional[_O]: if key not in self._dict: return default try: @@ -180,7 +216,7 @@ class WeakInstanceDict(IdentityMap): return default return o - def items(self): + def items(self) -> List[InstanceState[Any]]: values = self.all_states() result = [] for state in values: @@ -189,7 +225,7 @@ class WeakInstanceDict(IdentityMap): result.append((state.key, value)) return result - def values(self): + def values(self) -> List[object]: values = self.all_states() result = [] for state in values: @@ -199,13 +235,13 @@ class WeakInstanceDict(IdentityMap): return result - def __iter__(self): + def __iter__(self) -> Iterator[_IdentityKeyType[Any]]: return iter(self.keys()) - def all_states(self): + def all_states(self) -> List[InstanceState[Any]]: return list(self._dict.values()) - def _fast_discard(self, state): + def _fast_discard(self, state: InstanceState[Any]) -> None: # used by InstanceState for state being # GC'ed, inlines _managed_removed_state try: @@ -217,10 +253,10 @@ class WeakInstanceDict(IdentityMap): if st is state: self._dict.pop(state.key, None) - def discard(self, state): + def discard(self, state: InstanceState[Any]) -> None: self.safe_discard(state) - def safe_discard(self, state): + def safe_discard(self, state: InstanceState[Any]) -> None: if state.key in self._dict: try: st = self._dict[state.key] @@ -233,7 +269,7 @@ class WeakInstanceDict(IdentityMap): self._manage_removed_state(state) -def _killed(state, key): +def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn: # external function to avoid creating cycles when assigned to # the IdentityMap raise sa_exc.InvalidRequestError( diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index a050c533a5..030d1595b2 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -32,33 +32,64 @@ alternate instrumentation forms. from __future__ import annotations +from typing import Any +from typing import Dict +from typing import Generic +from typing import Set +from typing import TYPE_CHECKING +from typing import TypeVar + from . import base from . import collections from . import exc from . import interfaces from . import state from .. import util +from ..event import EventTarget from ..util import HasMemoized +from ..util.typing import Protocol +if TYPE_CHECKING: + from .attributes import InstrumentedAttribute + from .mapper import Mapper + from ..event import dispatcher +_T = TypeVar("_T", bound=Any) DEL_ATTR = util.symbol("DEL_ATTR") -class ClassManager(HasMemoized, dict): +class _ExpiredAttributeLoaderProto(Protocol): + def __call__( + self, + state: state.InstanceState[Any], + toload: Set[str], + passive: base.PassiveFlag, + ): + ... + + +class ClassManager( + HasMemoized, + Dict[str, "InstrumentedAttribute[Any]"], + Generic[_T], + EventTarget, +): """Tracks state information at the class level.""" + dispatch: dispatcher[ClassManager] + MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR STATE_ATTR = base.DEFAULT_STATE_ATTR _state_setter = staticmethod(util.attrsetter(STATE_ATTR)) - expired_attribute_loader = None + expired_attribute_loader: _ExpiredAttributeLoaderProto "previously known as deferred_scalar_loader" init_method = None factory = None - mapper = None + declarative_scan = None registry = None @@ -199,7 +230,7 @@ class ClassManager(HasMemoized, dict): return frozenset([attr.impl for attr in self.values()]) @util.memoized_property - def mapper(self): + def mapper(self) -> Mapper[_T]: # raises unless self.mapper has been assigned raise exc.UnmappedClassError(self.class_) @@ -426,7 +457,9 @@ class ClassManager(HasMemoized, dict): def teardown_instance(self, instance): delattr(instance, self.STATE_ATTR) - def _serialize(self, state, state_dict): + def _serialize( + self, state: state.InstanceState, state_dict: Dict[str, Any] + ) -> _SerializeManager: return _SerializeManager(state, state_dict) def _new_state_if_none(self, instance): @@ -480,7 +513,7 @@ class _SerializeManager: """ - def __init__(self, state, d): + def __init__(self, state: state.InstanceState[Any], d: Dict[str, Any]): self.class_ = state.class_ manager = state.manager manager.dispatch.pickle(state, d) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b4228323b4..7be7ce32b4 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -838,6 +838,10 @@ class ORMOption(ExecutableOption): """ + _is_core = False + + _is_user_defined = False + _is_compile_state = False _is_criteria_option = False @@ -942,6 +946,8 @@ class UserDefinedOption(ORMOption): _is_legacy_option = False + _is_user_defined = True + propagate_to_loaders = False """if True, indicate this option should be carried along to "secondary" Query objects produced during lazy loads diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 6f4c654ce4..ae083054cd 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -15,12 +15,24 @@ as well as some of the attribute loading strategies. from __future__ import annotations +from typing import Any +from typing import Iterable +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from sqlalchemy.orm.context import FromStatement from . import attributes from . import exc as orm_exc from . import path_registry from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED +from .base import PassiveFlag from .util import _none_set from .util import state_str from .. import exc as sa_exc @@ -31,9 +43,25 @@ from ..engine.result import ChunkedIteratorResult from ..engine.result import FrozenResult from ..engine.result import SimpleResultMetaData from ..sql import util as sql_util +from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectState +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from .base import LoaderCallableStatus + from .context import FromStatement + from .interfaces import ORMOption + from .mapper import Mapper + from .session import Session + from .state import InstanceState + from ..engine.interfaces import _ExecuteOptions + from ..sql import Select + from ..sql.base import Executable + from ..sql.selectable import ForUpdateArg + +_T = TypeVar("_T", bound=Any) +_O = TypeVar("_O", bound=object) _new_runid = util.counter() @@ -350,7 +378,12 @@ def merge_result(query, iterator, load=True): session.autoflush = autoflush -def get_from_identity(session, mapper, key, passive): +def get_from_identity( + session: Session, + mapper: Mapper[_O], + key: _IdentityKeyType[_O], + passive: PassiveFlag, +) -> Union[Optional[_O], LoaderCallableStatus]: """Look up the given key in the given session's identity map, check the object for expired state if found. @@ -385,16 +418,17 @@ def get_from_identity(session, mapper, key, passive): def load_on_ident( - session, - statement, - key, - load_options=None, - refresh_state=None, - with_for_update=None, - only_load_props=None, - no_autoflush=False, - bind_arguments=util.EMPTY_DICT, - execution_options=util.EMPTY_DICT, + session: Session, + statement: Union[Select, FromStatement], + key: Optional[_IdentityKeyType], + *, + load_options: Optional[Sequence[ORMOption]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + with_for_update: Optional[ForUpdateArg] = None, + only_load_props: Optional[Iterable[str]] = None, + no_autoflush: bool = False, + bind_arguments: Mapping[str, Any] = util.EMPTY_DICT, + execution_options: _ExecuteOptions = util.EMPTY_DICT, ): """Load the given identity key from the database.""" if key is not None: @@ -419,17 +453,18 @@ def load_on_ident( def load_on_pk_identity( - session, - statement, - primary_key_identity, - load_options=None, - refresh_state=None, - with_for_update=None, - only_load_props=None, - identity_token=None, - no_autoflush=False, - bind_arguments=util.EMPTY_DICT, - execution_options=util.EMPTY_DICT, + session: Session, + statement: Union[Select, FromStatement], + primary_key_identity: Optional[Tuple[Any, ...]], + *, + load_options: Optional[Sequence[ORMOption]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + with_for_update: Optional[ForUpdateArg] = None, + only_load_props: Optional[Iterable[str]] = None, + identity_token: Optional[Any] = None, + no_autoflush: bool = False, + bind_arguments: Mapping[str, Any] = util.EMPTY_DICT, + execution_options: _ExecuteOptions = util.EMPTY_DICT, ): """Load the given primary key identity from the database.""" diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 982b4b6d9c..c85861a594 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -22,9 +22,13 @@ from itertools import chain import sys import threading from typing import Any +from typing import Callable from typing import Generic +from typing import Iterator +from typing import Optional +from typing import Tuple from typing import Type -from typing import TypeVar +from typing import TYPE_CHECKING import weakref from . import attributes @@ -33,9 +37,11 @@ from . import instrumentation from . import loading from . import properties from . import util as orm_util +from ._typing import _O from .base import _class_to_mapper from .base import _state_mapper from .base import class_mapper +from .base import PassiveFlag from .base import state_str from .interfaces import _MappedAttribute from .interfaces import EXT_SKIP @@ -62,10 +68,15 @@ from ..sql import visitors from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import HasMemoized -_mapper_registries = weakref.WeakKeyDictionary() - +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from .instrumentation import ClassManager + from .state import InstanceState + from ..sql.elements import ColumnElement + from ..sql.schema import Column -_MC = TypeVar("_MC") +_mapper_registries = weakref.WeakKeyDictionary() def _all_registries(): @@ -99,7 +110,7 @@ class Mapper( sql_base.MemoizedHasCacheKey, InspectionAttr, log.Identified, - Generic[_MC], + Generic[_O], ): """Defines an association between a Python class and a database table or other relational structure, so that ORM operations against the class may @@ -115,9 +126,13 @@ class Mapper( _dispose_called = False _ready_for_configure = False - class_: Type[_MC] + class_: Type[_O] """The class to which this :class:`_orm.Mapper` is mapped.""" + _identity_class: Type[_O] + + always_refresh: bool + @util.deprecated_params( non_primary=( "1.3", @@ -130,7 +145,7 @@ class Mapper( ) def __init__( self, - class_: Type[_MC], + class_: Type[_O], local_table=None, properties=None, primary_key=None, @@ -813,7 +828,7 @@ class Mapper( """ - primary_key = None + primary_key: Tuple[Column[Any], ...] """An iterable containing the collection of :class:`_schema.Column` objects which comprise the 'primary key' of the mapped table, from the @@ -837,7 +852,7 @@ class Mapper( """ - class_ = None + class_: Type[_O] """The Python class which this :class:`_orm.Mapper` maps. This is a *read only* attribute determined during mapper construction. @@ -845,7 +860,7 @@ class Mapper( """ - class_manager = None + class_manager: ClassManager[_O] """The :class:`.ClassManager` which maintains event listeners and class-bound descriptors for this :class:`_orm.Mapper`. @@ -1965,7 +1980,7 @@ class Mapper( else self.persist_selectable.description, ) - def _is_orphan(self, state): + def _is_orphan(self, state: InstanceState[_O]) -> bool: orphan_possible = False for mapper in self.iterate_to_root(): for (key, cls) in mapper._delete_orphans: @@ -2804,16 +2819,24 @@ class Mapper( identity_token, ) - def identity_key_from_primary_key(self, primary_key, identity_token=None): + def identity_key_from_primary_key( + self, + primary_key: Tuple[Any, ...], + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[_O]: """Return an identity-map key for use in storing/retrieving an item from an identity map. :param primary_key: A list of values indicating the identifier. """ - return self._identity_class, tuple(primary_key), identity_token + return ( + self._identity_class, + tuple(primary_key), + identity_token, + ) - def identity_key_from_instance(self, instance): + def identity_key_from_instance(self, instance: _O) -> _IdentityKeyType[_O]: """Return the identity key for the given instance, based on its primary key attributes. @@ -2830,8 +2853,10 @@ class Mapper( return self._identity_key_from_state(state, attributes.PASSIVE_OFF) def _identity_key_from_state( - self, state, passive=attributes.PASSIVE_RETURN_NO_VALUE - ): + self, + state: InstanceState[_O], + passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE, + ) -> _IdentityKeyType[_O]: dict_ = state.dict manager = state.manager return ( @@ -2845,7 +2870,7 @@ class Mapper( state.identity_token, ) - def primary_key_from_instance(self, instance): + def primary_key_from_instance(self, instance: _O) -> Tuple[Any, ...]: """Return the list of primary key values for the given instance. @@ -2903,8 +2928,12 @@ class Mapper( return {self._columntoproperty[col].key for col in self._all_pk_cols} def _get_state_attr_by_column( - self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE - ): + self, + state: InstanceState[_O], + dict_: _InstanceDict, + column: Column[Any], + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, + ) -> Any: prop = self._columntoproperty[column] return state.manager[prop.key].impl.get(state, dict_, passive=passive) @@ -3146,7 +3175,14 @@ class Mapper( def _subclass_load_via_in_mapper(self): return self._subclass_load_via_in(self) - def cascade_iterator(self, type_, state, halt_on=None): + def cascade_iterator( + self, + type_: str, + state: InstanceState[_O], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: r"""Iterate each element and its mapper in an object graph, for all relationships that meet the given cascade rule. diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 9a7aa91a03..e2cf1d5b04 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -14,6 +14,7 @@ from functools import reduce from itertools import chain import logging from typing import Any +from typing import Sequence from typing import Tuple from typing import Union @@ -198,12 +199,12 @@ class PathRegistry(HasCacheKey): p = p[0:-1] return p - def serialize(self): + def serialize(self) -> Sequence[Any]: path = self.path return self._serialize_path(path) @classmethod - def deserialize(cls, path: Tuple) -> "PathRegistry": + def deserialize(cls, path: Sequence[Any]) -> PathRegistry: assert path is not None p = cls._deserialize_path(path) return cls.coerce(p) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 355ddc922d..93b49ab254 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -19,6 +19,12 @@ from itertools import chain from itertools import groupby from itertools import zip_longest import operator +from typing import Any +from typing import Dict +from typing import Iterable +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from . import attributes from . import evaluator @@ -47,15 +53,22 @@ from ..sql.dml import UpdateDMLState from ..sql.elements import BooleanClauseList from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +if TYPE_CHECKING: + from .mapper import Mapper + from .session import SessionTransaction + from .state import InstanceState + +_O = TypeVar("_O", bound=object) + def _bulk_insert( - mapper, - mappings, - session_transaction, - isstates, - return_defaults, - render_nulls, -): + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, +) -> None: base_mapper = mapper.base_mapper if session_transaction.session.connection_callable: @@ -126,8 +139,12 @@ def _bulk_insert( def _bulk_update( - mapper, mappings, session_transaction, isstates, update_changed_only -): + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, +) -> None: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index cc9d5a23b3..e498b17b4d 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -7,6 +7,19 @@ from __future__ import annotations +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + from . import exc as orm_exc from .base import class_mapper from .session import Session @@ -17,16 +30,54 @@ from ..util import ScopedRegistry from ..util import ThreadLocalRegistry from ..util import warn from ..util import warn_deprecated +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from .identity import IdentityMap + from .interfaces import ORMOption + from .mapper import Mapper + from .query import Query + from .session import _EntityBindKey + from .session import _PKIdentityArgument + from .session import _SessionBind + from .session import sessionmaker + from .session import SessionTransaction + from ..engine import Connection + from ..engine import Engine + from ..engine import Result + from ..engine import Row + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ExecuteOptionsParameter + from ..engine.result import ScalarResult + from ..sql._typing import _ColumnsClauseArgument + from ..sql.base import Executable + from ..sql.elements import ClauseElement + from ..sql.selectable import ForUpdateArg + + +class _QueryDescriptorType(Protocol): + def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]: + ... + + +_O = TypeVar("_O", bound=object) __all__ = ["scoped_session", "ScopedSessionMixin"] class ScopedSessionMixin: + session_factory: sessionmaker + _support_async: bool + registry: ScopedRegistry[Session] + @property - def _proxied(self): - return self.registry() + def _proxied(self) -> Session: + return self.registry() # type: ignore - def __call__(self, **kw): + def __call__(self, **kw: Any) -> Session: r"""Return the current :class:`.Session`, creating it using the :attr:`.scoped_session.session_factory` if not present. @@ -57,7 +108,7 @@ class ScopedSessionMixin: ) return sess - def configure(self, **kwargs): + def configure(self, **kwargs: Any) -> None: """reconfigure the :class:`.sessionmaker` used by this :class:`.scoped_session`. @@ -120,7 +171,6 @@ class ScopedSessionMixin: "autoflush", "no_autoflush", "info", - "autocommit", ], ) class scoped_session(ScopedSessionMixin): @@ -136,15 +186,20 @@ class scoped_session(ScopedSessionMixin): """ - _support_async = False + _support_async: bool = False - session_factory = None + session_factory: sessionmaker """The `session_factory` provided to `__init__` is stored in this attribute and may be accessed at a later time. This can be useful when a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the database is needed.""" - def __init__(self, session_factory, scopefunc=None): + def __init__( + self, + session_factory: sessionmaker, + scopefunc: Optional[Callable[[], Any]] = None, + ): + """Construct a new :class:`.scoped_session`. :param session_factory: a factory to create new :class:`.Session` @@ -167,7 +222,7 @@ class scoped_session(ScopedSessionMixin): else: self.registry = ThreadLocalRegistry(session_factory) - def remove(self): + def remove(self) -> None: """Dispose of the current :class:`.Session`, if present. This will first call :meth:`.Session.close` method @@ -184,7 +239,9 @@ class scoped_session(ScopedSessionMixin): self.registry().close() self.registry.clear() - def query_property(self, query_cls=None): + def query_property( + self, query_cls: Optional[Type[Query[Any]]] = None + ) -> _QueryDescriptorType: """return a class property which produces a :class:`_query.Query` object against the class and the current :class:`.Session` when called. @@ -211,16 +268,18 @@ class scoped_session(ScopedSessionMixin): """ class query: - def __get__(s, instance, owner): + def __get__( + s, instance: Any, owner: Type[Any] + ) -> Optional[Query[Any]]: try: mapper = class_mapper(owner) - if mapper: - if query_cls: - # custom query class - return query_cls(mapper, session=self.registry()) - else: - # session's configured query class - return self.registry().query(mapper) + assert mapper is not None + if query_cls: + # custom query class + return query_cls(mapper, session=self.registry()) + else: + # session's configured query class + return self.registry().query(mapper) except orm_exc.UnmappedClassError: return None @@ -231,7 +290,7 @@ class scoped_session(ScopedSessionMixin): # code within this block is **programmatically, # statically generated** by tools/generate_proxy_methods.py - def __contains__(self, instance): + def __contains__(self, instance: object) -> bool: r"""Return True if the instance is associated with this session. .. container:: class_bases @@ -247,7 +306,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.__contains__(instance) - def __iter__(self): + def __iter__(self) -> Iterator[object]: r"""Iterate over all pending or persistent instances within this Session. @@ -261,7 +320,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.__iter__() - def add(self, instance: Any, _warn: bool = True) -> None: + def add(self, instance: object, _warn: bool = True) -> None: r"""Place an object in the ``Session``. .. container:: class_bases @@ -280,7 +339,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.add(instance, _warn=_warn) - def add_all(self, instances): + def add_all(self, instances: Iterable[object]) -> None: r"""Add the given collection of instances to this ``Session``. .. container:: class_bases @@ -292,7 +351,9 @@ class scoped_session(ScopedSessionMixin): return self._proxied.add_all(instances) - def begin(self, nested=False, _subtrans=False): + def begin( + self, nested: bool = False, _subtrans: bool = False + ) -> SessionTransaction: r"""Begin a transaction, or nested transaction, on this :class:`.Session`, if one is not already begun. @@ -335,7 +396,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.begin(nested=nested, _subtrans=_subtrans) - def begin_nested(self): + def begin_nested(self) -> SessionTransaction: r"""Begin a "nested" transaction on this Session, e.g. SAVEPOINT. .. container:: class_bases @@ -367,7 +428,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.begin_nested() - def close(self): + def close(self) -> None: r"""Close out the transactional resources and ORM objects used by this :class:`_orm.Session`. @@ -434,7 +495,7 @@ class scoped_session(ScopedSessionMixin): def connection( self, bind_arguments: Optional[Dict[str, Any]] = None, - execution_options: Optional["_ExecuteOptions"] = None, + execution_options: Optional[_ExecuteOptions] = None, ) -> "Connection": r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -476,7 +537,7 @@ class scoped_session(ScopedSessionMixin): bind_arguments=bind_arguments, execution_options=execution_options ) - def delete(self, instance): + def delete(self, instance: object) -> None: r"""Mark an instance as deleted. .. container:: class_bases @@ -493,13 +554,13 @@ class scoped_session(ScopedSessionMixin): def execute( self, - statement: "Executable", - params: Optional["_ExecuteParams"] = None, - execution_options: "_ExecuteOptions" = util.EMPTY_DICT, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ): + ) -> Result: r"""Execute a SQL expression construct. .. container:: class_bases @@ -567,7 +628,9 @@ class scoped_session(ScopedSessionMixin): _add_event=_add_event, ) - def expire(self, instance, attribute_names=None): + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: r"""Expire the attributes on an instance. .. container:: class_bases @@ -613,7 +676,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.expire(instance, attribute_names=attribute_names) - def expire_all(self): + def expire_all(self) -> None: r"""Expires all persistent instances within this Session. .. container:: class_bases @@ -654,7 +717,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.expire_all() - def expunge(self, instance): + def expunge(self, instance: object) -> None: r"""Remove the `instance` from this ``Session``. .. container:: class_bases @@ -670,7 +733,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.expunge(instance) - def expunge_all(self): + def expunge_all(self) -> None: r"""Remove all object instances from this ``Session``. .. container:: class_bases @@ -686,7 +749,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.expunge_all() - def flush(self, objects=None): + def flush(self, objects: Optional[Sequence[Any]] = None) -> None: r"""Flush all the object changes to the database. .. container:: class_bases @@ -719,14 +782,15 @@ class scoped_session(ScopedSessionMixin): def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - execution_options=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: r"""Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -841,12 +905,12 @@ class scoped_session(ScopedSessionMixin): def get_bind( self, - mapper=None, - clause=None, - bind=None, - _sa_skip_events=None, - _sa_skip_for_implicit_returning=False, - ): + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + _sa_skip_events: Optional[bool] = None, + _sa_skip_for_implicit_returning: bool = False, + ) -> Union[Engine, Connection]: r"""Return a "bind" to which this :class:`.Session` is bound. .. container:: class_bases @@ -933,7 +997,9 @@ class scoped_session(ScopedSessionMixin): _sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning, ) - def is_modified(self, instance, include_collections=True): + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: r"""Return ``True`` if the given instance has locally modified attributes. @@ -997,11 +1063,11 @@ class scoped_session(ScopedSessionMixin): def bulk_save_objects( self, - objects, - return_defaults=False, - update_changed_only=True, - preserve_order=True, - ): + objects: Iterable[object], + return_defaults: bool = False, + update_changed_only: bool = True, + preserve_order: bool = True, + ) -> None: r"""Perform a bulk save of the given list of objects. .. container:: class_bases @@ -1109,8 +1175,12 @@ class scoped_session(ScopedSessionMixin): ) def bulk_insert_mappings( - self, mapper, mappings, return_defaults=False, render_nulls=False - ): + self, + mapper: Mapper[Any], + mappings: Iterable[Dict[str, Any]], + return_defaults: bool = False, + render_nulls: bool = False, + ) -> None: r"""Perform a bulk insert of the given list of mapping dictionaries. .. container:: class_bases @@ -1221,7 +1291,9 @@ class scoped_session(ScopedSessionMixin): render_nulls=render_nulls, ) - def bulk_update_mappings(self, mapper, mappings): + def bulk_update_mappings( + self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] + ) -> None: r"""Perform a bulk update of the given list of mapping dictionaries. .. container:: class_bases @@ -1287,7 +1359,13 @@ class scoped_session(ScopedSessionMixin): return self._proxied.bulk_update_mappings(mapper, mappings) - def merge(self, instance, load=True, options=None): + def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: r"""Copy the state of a given instance into a corresponding instance within this :class:`.Session`. @@ -1355,7 +1433,9 @@ class scoped_session(ScopedSessionMixin): return self._proxied.merge(instance, load=load, options=options) - def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query: + def query( + self, *entities: _ColumnsClauseArgument, **kwargs: Any + ) -> Query[Any]: r"""Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -1381,7 +1461,12 @@ class scoped_session(ScopedSessionMixin): return self._proxied.query(*entities, **kwargs) - def refresh(self, instance, attribute_names=None, with_for_update=None): + def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: r"""Expire and refresh attributes on the given instance. .. container:: class_bases @@ -1452,7 +1537,7 @@ class scoped_session(ScopedSessionMixin): with_for_update=with_for_update, ) - def rollback(self): + def rollback(self) -> None: r"""Rollback the current transaction in progress. .. container:: class_bases @@ -1479,12 +1564,12 @@ class scoped_session(ScopedSessionMixin): def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + **kw: Any, + ) -> Any: r"""Execute a statement and return a scalar result. .. container:: class_bases @@ -1509,12 +1594,12 @@ class scoped_session(ScopedSessionMixin): def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + **kw: Any, + ) -> ScalarResult[Any]: r"""Execute a statement and return the results as scalars. .. container:: class_bases @@ -1615,7 +1700,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.new @property - def identity_map(self) -> identity.IdentityMap: + def identity_map(self) -> IdentityMap: r"""Proxy for the :attr:`_orm.Session.identity_map` attribute on behalf of the :class:`_orm.scoping.scoped_session` class. @@ -1624,7 +1709,7 @@ class scoped_session(ScopedSessionMixin): return self._proxied.identity_map @identity_map.setter - def identity_map(self, attr: identity.IdentityMap) -> None: + def identity_map(self, attr: IdentityMap) -> None: self._proxied.identity_map = attr @property @@ -1726,19 +1811,6 @@ class scoped_session(ScopedSessionMixin): return self._proxied.info - @property - def autocommit(self) -> Any: - r"""Proxy for the :attr:`_orm.Session.autocommit` attribute - on behalf of the :class:`_orm.scoping.scoped_session` class. - - """ # noqa: E501 - - return self._proxied.autocommit - - @autocommit.setter - def autocommit(self, attr: Any) -> None: - self._proxied.autocommit = attr - @classmethod def close_all(cls) -> None: r"""Close *all* sessions in memory. @@ -1755,7 +1827,7 @@ class scoped_session(ScopedSessionMixin): return Session.close_all() @classmethod - def object_session(cls, instance: Any) -> "Session": + def object_session(cls, instance: object) -> Optional[Session]: r"""Return the :class:`.Session` to which an object belongs. .. container:: class_bases @@ -1773,13 +1845,13 @@ class scoped_session(ScopedSessionMixin): @classmethod def identity_key( cls, - class_=None, - ident=None, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, *, - instance=None, - row=None, - identity_token=None, - ) -> _IdentityKeyType: + instance: Optional[Any] = None, + row: Optional[Row] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: r"""Return an identity key. .. container:: class_bases diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 77a97936b2..55ce73cf54 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -13,12 +13,20 @@ import itertools import sys import typing from typing import Any +from typing import Callable +from typing import cast from typing import Dict +from typing import Iterable +from typing import Iterator from typing import List +from typing import NoReturn from typing import Optional -from typing import overload +from typing import Sequence +from typing import Set from typing import Tuple from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -30,14 +38,20 @@ from . import loading from . import persistence from . import query from . import state as statelib +from ._typing import is_composite_class +from ._typing import is_user_defined_option from .base import _class_to_mapper -from .base import _IdentityKeyType from .base import _none_set from .base import _state_mapper from .base import instance_str +from .base import LoaderCallableStatus from .base import object_mapper from .base import object_state +from .base import PassiveFlag from .base import state_str +from .context import FromStatement +from .context import ORMCompileState +from .identity import IdentityMap from .query import Query from .state import InstanceState from .state_changes import _StateChange @@ -51,22 +65,41 @@ from .. import util from ..engine import Connection from ..engine import Engine from ..engine.util import TransactionalContext +from ..event import dispatcher +from ..event import EventTarget from ..inspection import inspect from ..sql import coercions from ..sql import dml from ..sql import roles +from ..sql import Select from ..sql import visitors from ..sql.base import CompileState +from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import IdentitySet from ..util.typing import Literal +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from .interfaces import ORMOption + from .interfaces import UserDefinedOption from .mapper import Mapper + from .path_registry import PathRegistry + from ..engine import Result from ..engine import Row + from ..engine.base import Transaction + from ..engine.base import TwoPhaseTransaction + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ExecuteOptionsParameter + from ..engine.result import ScalarResult + from ..event import _InstanceLevelDispatch from ..sql._typing import _ColumnsClauseArgument - from ..sql._typing import _ExecuteOptions - from ..sql._typing import _ExecuteParams from ..sql.base import Executable + from ..sql.elements import ClauseElement from ..sql.schema import Table __all__ = [ @@ -80,14 +113,45 @@ __all__ = [ "object_session", ] -_sessions = weakref.WeakValueDictionary() +_sessions: weakref.WeakValueDictionary[ + int, Session +] = weakref.WeakValueDictionary() """Weak-referencing dictionary of :class:`.Session` objects. """ +_O = TypeVar("_O", bound=object) statelib._sessions = _sessions +_PKIdentityArgument = Union[Any, Tuple[Any, ...]] -def _state_session(state): +_EntityBindKey = Union[Type[_O], "Mapper[_O]"] +_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"] +_SessionBind = Union["Engine", "Connection"] + + +class _ConnectionCallableProto(Protocol): + """a callable that returns a :class:`.Connection` given an instance. + + This callable, when present on a :class:`.Session`, is called only from the + ORM's persistence mechanism (i.e. the unit of work flush process) to allow + for connection-per-instance schemes (i.e. horizontal sharding) to be used + as persistence time. + + This callable is not present on a plain :class:`.Session`, however + is established when using the horizontal sharding extension. + + """ + + def __call__( + self, + mapper: Optional[Mapper[Any]] = None, + instance: Optional[object] = None, + **kw: Any, + ) -> Connection: + ... + + +def _state_session(state: InstanceState[Any]) -> Optional[Session]: """Given an :class:`.InstanceState`, return the :class:`.Session` associated, if any. """ @@ -109,40 +173,17 @@ class _SessionClassMethods: close_all_sessions() - @classmethod - @overload - def identity_key( - cls, - class_: type, - ident: Tuple[Any, ...], - *, - identity_token: Optional[str], - ) -> _IdentityKeyType: - ... - - @classmethod - @overload - def identity_key(cls, *, instance: Any) -> _IdentityKeyType: - ... - - @classmethod - @overload - def identity_key( - cls, class_: type, *, row: "Row", identity_token: Optional[str] - ) -> _IdentityKeyType: - ... - @classmethod @util.preload_module("sqlalchemy.orm.util") def identity_key( cls, - class_=None, - ident=None, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, *, - instance=None, - row=None, - identity_token=None, - ) -> _IdentityKeyType: + instance: Optional[Any] = None, + row: Optional[Row] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: """Return an identity key. This is an alias of :func:`.util.identity_key`. @@ -157,7 +198,7 @@ class _SessionClassMethods: ) @classmethod - def object_session(cls, instance: Any) -> "Session": + def object_session(cls, instance: object) -> Optional[Session]: """Return the :class:`.Session` to which an object belongs. This is an alias of :func:`.object_session`. @@ -205,26 +246,26 @@ class ORMExecuteState(util.MemoizedSlots): "_update_execution_options", ) - session: "Session" - statement: "Executable" - parameters: "_ExecuteParams" - execution_options: "_ExecuteOptions" - local_execution_options: "_ExecuteOptions" + session: Session + statement: Executable + parameters: Optional[_CoreAnyExecuteParams] + execution_options: _ExecuteOptions + local_execution_options: _ExecuteOptions bind_arguments: Dict[str, Any] - _compile_state_cls: Type[context.ORMCompileState] - _starting_event_idx: Optional[int] + _compile_state_cls: Optional[Type[ORMCompileState]] + _starting_event_idx: int _events_todo: List[Any] - _update_execution_options: Optional["_ExecuteOptions"] + _update_execution_options: Optional[_ExecuteOptions] def __init__( self, - session: "Session", - statement: "Executable", - parameters: "_ExecuteParams", - execution_options: "_ExecuteOptions", + session: Session, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams], + execution_options: _ExecuteOptions, bind_arguments: Dict[str, Any], - compile_state_cls: Type[context.ORMCompileState], - events_todo: List[Any], + compile_state_cls: Optional[Type[ORMCompileState]], + events_todo: List[_InstanceLevelDispatch[Session]], ): self.session = session self.statement = statement @@ -237,16 +278,16 @@ class ORMExecuteState(util.MemoizedSlots): self._compile_state_cls = compile_state_cls self._events_todo = list(events_todo) - def _remaining_events(self): + def _remaining_events(self) -> List[_InstanceLevelDispatch[Session]]: return self._events_todo[self._starting_event_idx + 1 :] def invoke_statement( self, - statement=None, - params=None, - execution_options=None, - bind_arguments=None, - ): + statement: Optional[Executable] = None, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + bind_arguments: Optional[Dict[str, Any]] = None, + ) -> Result: """Execute the statement represented by this :class:`.ORMExecuteState`, without re-invoking events that have already proceeded. @@ -270,9 +311,12 @@ class ORMExecuteState(util.MemoizedSlots): :param statement: optional statement to be invoked, in place of the statement currently represented by :attr:`.ORMExecuteState.statement`. - :param params: optional dictionary of parameters which will be merged - into the existing :attr:`.ORMExecuteState.parameters` of this - :class:`.ORMExecuteState`. + :param params: optional dictionary of parameters or list of parameters + which will be merged into the existing + :attr:`.ORMExecuteState.parameters` of this :class:`.ORMExecuteState`. + + .. versionchanged:: 2.0 a list of parameter dictionaries is accepted + for executemany executions. :param execution_options: optional dictionary of execution options will be merged into the existing @@ -302,9 +346,32 @@ class ORMExecuteState(util.MemoizedSlots): _bind_arguments.update(bind_arguments) _bind_arguments["_sa_skip_events"] = True + _params: Optional[_CoreAnyExecuteParams] if params: - _params = dict(self.parameters) - _params.update(params) + if self.is_executemany: + _params = [] + exec_many_parameters = cast( + "List[Dict[str, Any]]", self.parameters + ) + for _existing_params, _new_params in itertools.zip_longest( + exec_many_parameters, + cast("List[Dict[str, Any]]", params), + ): + if _existing_params is None or _new_params is None: + raise sa_exc.InvalidRequestError( + f"Can't apply executemany parameters to " + f"statement; number of parameter sets passed to " + f"Session.execute() ({len(exec_many_parameters)}) " + f"does not match number of parameter sets given " + f"to ORMExecuteState.invoke_statement() " + f"({len(params)})" + ) + _existing_params = dict(_existing_params) + _existing_params.update(_new_params) + _params.append(_existing_params) + else: + _params = dict(cast("Dict[str, Any]", self.parameters)) + _params.update(cast("Dict[str, Any]", params)) else: _params = self.parameters @@ -321,7 +388,7 @@ class ORMExecuteState(util.MemoizedSlots): ) @property - def bind_mapper(self): + def bind_mapper(self) -> Optional[Mapper[Any]]: """Return the :class:`_orm.Mapper` that is the primary "bind" mapper. For an :class:`_orm.ORMExecuteState` object invoking an ORM @@ -349,7 +416,7 @@ class ORMExecuteState(util.MemoizedSlots): return self.bind_arguments.get("mapper", None) @property - def all_mappers(self): + def all_mappers(self) -> Sequence[Mapper[Any]]: """Return a sequence of all :class:`_orm.Mapper` objects that are involved at the top level of this statement. @@ -369,7 +436,7 @@ class ORMExecuteState(util.MemoizedSlots): """ if not self.is_orm_statement: return [] - elif self.is_select: + elif isinstance(self.statement, (Select, FromStatement)): result = [] seen = set() for d in self.statement.column_descriptions: @@ -380,13 +447,13 @@ class ORMExecuteState(util.MemoizedSlots): seen.add(insp.mapper) result.append(insp.mapper) return result - elif self.is_update or self.is_delete: + elif self.statement.is_dml and self.bind_mapper: return [self.bind_mapper] else: return [] @property - def is_orm_statement(self): + def is_orm_statement(self) -> bool: """return True if the operation is an ORM statement. This indicates that the select(), update(), or delete() being @@ -399,44 +466,64 @@ class ORMExecuteState(util.MemoizedSlots): return self._compile_state_cls is not None @property - def is_select(self): + def is_executemany(self) -> bool: + """return True if the parameters are a multi-element list of + dictionaries with more than one dictionary. + + .. versionadded:: 2.0 + + """ + return isinstance(self.parameters, list) + + @property + def is_select(self) -> bool: """return True if this is a SELECT operation.""" return self.statement.is_select @property - def is_insert(self): + def is_insert(self) -> bool: """return True if this is an INSERT operation.""" return self.statement.is_dml and self.statement.is_insert @property - def is_update(self): + def is_update(self) -> bool: """return True if this is an UPDATE operation.""" return self.statement.is_dml and self.statement.is_update @property - def is_delete(self): + def is_delete(self) -> bool: """return True if this is a DELETE operation.""" return self.statement.is_dml and self.statement.is_delete @property - def _is_crud(self): + def _is_crud(self) -> bool: return isinstance(self.statement, (dml.Update, dml.Delete)) - def update_execution_options(self, **opts): + def update_execution_options(self, **opts: _ExecuteOptions) -> None: + """Update the local execution options with new values.""" # TODO: no coverage self.local_execution_options = self.local_execution_options.union(opts) - def _orm_compile_options(self): + def _orm_compile_options( + self, + ) -> Optional[ + Union[ + context.ORMCompileState.default_compile_options, + Type[context.ORMCompileState.default_compile_options], + ] + ]: if not self.is_select: return None opts = self.statement._compile_options - if opts.isinstance(context.ORMCompileState.default_compile_options): - return opts + if opts is not None and opts.isinstance( + context.ORMCompileState.default_compile_options + ): + return opts # type: ignore else: return None @property - def lazy_loaded_from(self): + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: """An :class:`.InstanceState` that is using this statement execution for a lazy load operation. @@ -451,7 +538,7 @@ class ORMExecuteState(util.MemoizedSlots): return self.load_options._lazy_loaded_from @property - def loader_strategy_path(self): + def loader_strategy_path(self) -> Optional[PathRegistry]: """Return the :class:`.PathRegistry` for the current load path. This object represents the "path" in a query along relationships @@ -465,7 +552,7 @@ class ORMExecuteState(util.MemoizedSlots): return None @property - def is_column_load(self): + def is_column_load(self) -> bool: """Return True if the operation is refreshing column-oriented attributes on an existing ORM object. @@ -492,7 +579,7 @@ class ORMExecuteState(util.MemoizedSlots): return opts is not None and opts._for_refresh_state @property - def is_relationship_load(self): + def is_relationship_load(self) -> bool: """Return True if this load is loading objects on behalf of a relationship. @@ -518,7 +605,12 @@ class ORMExecuteState(util.MemoizedSlots): return path is not None and not path.is_root @property - def load_options(self): + def load_options( + self, + ) -> Union[ + context.QueryContext.default_load_options, + Type[context.QueryContext.default_load_options], + ]: """Return the load_options that will be used for this execution.""" if not self.is_select: @@ -531,7 +623,12 @@ class ORMExecuteState(util.MemoizedSlots): ) @property - def update_delete_options(self): + def update_delete_options( + self, + ) -> Union[ + persistence.BulkUDCompileState.default_update_options, + Type[persistence.BulkUDCompileState.default_update_options], + ]: """Return the update_delete_options that will be used for this execution.""" @@ -546,7 +643,7 @@ class ORMExecuteState(util.MemoizedSlots): ) @property - def user_defined_options(self): + def user_defined_options(self) -> Sequence[UserDefinedOption]: """The sequence of :class:`.UserDefinedOptions` that have been associated with the statement being invoked. @@ -554,7 +651,7 @@ class ORMExecuteState(util.MemoizedSlots): return [ opt for opt in self.statement._with_options - if not opt._is_compile_state and not opt._is_legacy_option + if is_user_defined_option(opt) ] @@ -597,14 +694,29 @@ class SessionTransaction(_StateChange, TransactionalContext): """ - _rollback_exception = None + _rollback_exception: Optional[BaseException] = None + + _connections: Dict[ + Union[Engine, Connection], Tuple[Connection, Transaction, bool, bool] + ] + session: Session + _parent: Optional[SessionTransaction] + + _state: SessionTransactionState + + _new: weakref.WeakKeyDictionary[InstanceState[Any], object] + _deleted: weakref.WeakKeyDictionary[InstanceState[Any], object] + _dirty: weakref.WeakKeyDictionary[InstanceState[Any], object] + _key_switches: weakref.WeakKeyDictionary[ + InstanceState[Any], Tuple[Any, Any] + ] def __init__( self, - session, - parent=None, - nested=False, - autobegin=False, + session: Session, + parent: Optional[SessionTransaction] = None, + nested: bool = False, + autobegin: bool = False, ): TransactionalContext._trans_ctx_check(session) @@ -629,7 +741,9 @@ class SessionTransaction(_StateChange, TransactionalContext): self.session.dispatch.after_transaction_create(self.session, self) - def _raise_for_prerequisite_state(self, operation_name, state): + def _raise_for_prerequisite_state( + self, operation_name: str, state: SessionTransactionState + ) -> NoReturn: if state is SessionTransactionState.DEACTIVE: if self._rollback_exception: raise sa_exc.PendingRollbackError( @@ -655,7 +769,7 @@ class SessionTransaction(_StateChange, TransactionalContext): ) @property - def parent(self): + def parent(self) -> Optional[SessionTransaction]: """The parent :class:`.SessionTransaction` of this :class:`.SessionTransaction`. @@ -673,7 +787,7 @@ class SessionTransaction(_StateChange, TransactionalContext): """ return self._parent - nested = False + nested: bool = False """Indicates if this is a nested, or SAVEPOINT, transaction. When :attr:`.SessionTransaction.nested` is True, it is expected @@ -682,33 +796,40 @@ class SessionTransaction(_StateChange, TransactionalContext): """ @property - def is_active(self): + def is_active(self) -> bool: return ( self.session is not None and self._state is SessionTransactionState.ACTIVE ) @property - def _is_transaction_boundary(self): + def _is_transaction_boundary(self) -> bool: return self.nested or not self._parent @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE ) - def connection(self, bindkey, execution_options=None, **kwargs): + def connection( + self, + bindkey: Optional[Mapper[Any]], + execution_options: Optional[_ExecuteOptions] = None, + **kwargs: Any, + ) -> Connection: bind = self.session.get_bind(bindkey, **kwargs) return self._connection_for_bind(bind, execution_options) @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE ) - def _begin(self, nested=False): + def _begin(self, nested: bool = False) -> SessionTransaction: return SessionTransaction(self.session, self, nested=nested) - def _iterate_self_and_parents(self, upto=None): + def _iterate_self_and_parents( + self, upto: Optional[SessionTransaction] = None + ) -> Iterable[SessionTransaction]: current = self - result = () + result: Tuple[SessionTransaction, ...] = () while current: result += (current,) if current._parent is upto: @@ -723,12 +844,14 @@ class SessionTransaction(_StateChange, TransactionalContext): return result - def _take_snapshot(self, autobegin=False): + def _take_snapshot(self, autobegin: bool = False) -> None: if not self._is_transaction_boundary: - self._new = self._parent._new - self._deleted = self._parent._deleted - self._dirty = self._parent._dirty - self._key_switches = self._parent._key_switches + parent = self._parent + assert parent is not None + self._new = parent._new + self._deleted = parent._deleted + self._dirty = parent._dirty + self._key_switches = parent._key_switches return if not autobegin and not self.session._flushing: @@ -739,7 +862,7 @@ class SessionTransaction(_StateChange, TransactionalContext): self._dirty = weakref.WeakKeyDictionary() self._key_switches = weakref.WeakKeyDictionary() - def _restore_snapshot(self, dirty_only=False): + def _restore_snapshot(self, dirty_only: bool = False) -> None: """Restore the restoration state taken before a transaction began. Corresponds to a rollback. @@ -771,7 +894,7 @@ class SessionTransaction(_StateChange, TransactionalContext): if not dirty_only or s.modified or s in self._dirty: s._expire(s.dict, self.session.identity_map._modified) - def _remove_snapshot(self): + def _remove_snapshot(self) -> None: """Remove the restoration state taken before a transaction began. Corresponds to a commit. @@ -788,15 +911,21 @@ class SessionTransaction(_StateChange, TransactionalContext): ) self._deleted.clear() elif self.nested: - self._parent._new.update(self._new) - self._parent._dirty.update(self._dirty) - self._parent._deleted.update(self._deleted) - self._parent._key_switches.update(self._key_switches) + parent = self._parent + assert parent is not None + parent._new.update(self._new) + parent._dirty.update(self._dirty) + parent._deleted.update(self._deleted) + parent._key_switches.update(self._key_switches) @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE ) - def _connection_for_bind(self, bind, execution_options): + def _connection_for_bind( + self, + bind: _SessionBind, + execution_options: Optional[_ExecuteOptions], + ) -> Connection: if bind in self._connections: if execution_options: @@ -829,6 +958,7 @@ class SessionTransaction(_StateChange, TransactionalContext): if execution_options: conn = conn.execution_options(**execution_options) + transaction: Transaction if self.session.twophase and self._parent is None: transaction = conn.begin_twophase() elif self.nested: @@ -837,9 +967,9 @@ class SessionTransaction(_StateChange, TransactionalContext): # if given a future connection already in a transaction, don't # commit that transaction unless it is a savepoint if conn.in_nested_transaction(): - transaction = conn.get_nested_transaction() + transaction = conn._get_required_nested_transaction() else: - transaction = conn.get_transaction() + transaction = conn._get_required_transaction() should_commit = False else: transaction = conn.begin() @@ -861,7 +991,7 @@ class SessionTransaction(_StateChange, TransactionalContext): self.session.dispatch.after_begin(self.session, self, conn) return conn - def prepare(self): + def prepare(self) -> None: if self._parent is not None or not self.session.twophase: raise sa_exc.InvalidRequestError( "'twophase' mode not enabled, or not root transaction; " @@ -872,12 +1002,13 @@ class SessionTransaction(_StateChange, TransactionalContext): @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED ) - def _prepare_impl(self): + def _prepare_impl(self) -> None: if self._parent is None or self.nested: self.session.dispatch.before_commit(self.session) stx = self.session._transaction + assert stx is not None if stx is not self: for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.commit() @@ -897,7 +1028,7 @@ class SessionTransaction(_StateChange, TransactionalContext): if self._parent is None and self.session.twophase: try: for t in set(self._connections.values()): - t[1].prepare() + cast("TwoPhaseTransaction", t[1]).prepare() except: with util.safe_reraise(): self.rollback() @@ -929,9 +1060,7 @@ class SessionTransaction(_StateChange, TransactionalContext): self.close() if _to_root and self._parent: - return self._parent.commit(_to_root=True) - - return self._parent + self._parent.commit(_to_root=True) @_StateChange.declare_states( ( @@ -941,9 +1070,12 @@ class SessionTransaction(_StateChange, TransactionalContext): ), SessionTransactionState.CLOSED, ) - def rollback(self, _capture_exception=False, _to_root=False): + def rollback( + self, _capture_exception: bool = False, _to_root: bool = False + ) -> None: stx = self.session._transaction + assert stx is not None if stx is not self: for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.close() @@ -993,19 +1125,18 @@ class SessionTransaction(_StateChange, TransactionalContext): if self._parent and _capture_exception: self._parent._rollback_exception = sys.exc_info()[1] - if rollback_err: + if rollback_err and rollback_err[1]: raise rollback_err[1].with_traceback(rollback_err[2]) sess.dispatch.after_soft_rollback(sess, self) if _to_root and self._parent: - return self._parent.rollback(_to_root=True) - return self._parent + self._parent.rollback(_to_root=True) @_StateChange.declare_states( _StateChangeStates.ANY, SessionTransactionState.CLOSED ) - def close(self, invalidate=False): + def close(self, invalidate: bool = False) -> None: if self.nested: self.session._nested_transaction = ( self._previous_nested_transaction @@ -1027,25 +1158,30 @@ class SessionTransaction(_StateChange, TransactionalContext): self._state = SessionTransactionState.CLOSED sess = self.session - self.session = None - self._connections = None + # TODO: these two None sets were historically after the + # event hook below, and in 2.0 I changed it this way for some reason, + # and I remember there being a reason, but not what it was. + # Why do we need to get rid of them at all? test_memusage::CycleTest + # passes with these commented out. + # self.session = None # type: ignore + # self._connections = None # type: ignore sess.dispatch.after_transaction_end(sess, self) - def _get_subject(self): + def _get_subject(self) -> Session: return self.session - def _transaction_is_active(self): + def _transaction_is_active(self) -> bool: return self._state is SessionTransactionState.ACTIVE - def _transaction_is_closed(self): + def _transaction_is_closed(self) -> bool: return self._state is SessionTransactionState.CLOSED - def _rollback_can_be_called(self): + def _rollback_can_be_called(self) -> bool: return self._state not in (COMMITTED, CLOSED) -class Session(_SessionClassMethods): +class Session(_SessionClassMethods, EventTarget): """Manages persistence operations for ORM-mapped objects. The Session's usage paradigm is described at :doc:`/orm/session`. @@ -1055,15 +1191,27 @@ class Session(_SessionClassMethods): _is_asyncio = False - identity_map: identity.IdentityMap - _new: Dict["InstanceState", Any] - _deleted: Dict["InstanceState", Any] + dispatch: dispatcher[Session] + + identity_map: IdentityMap + """A mapping of object identities to objects themselves. + + Iterating through ``Session.identity_map.values()`` provides + access to the full set of persistent objects (i.e., those + that have row identity) currently in the session. + + .. seealso:: + + :func:`.identity_key` - helper function to produce the keys used + in this dictionary. + + """ + + _new: Dict[InstanceState[Any], Any] + _deleted: Dict[InstanceState[Any], Any] bind: Optional[Union[Engine, Connection]] - __binds: Dict[ - Union[type, "Mapper", "Table"], - Union[engine.Engine, engine.Connection], - ] - _flusing: bool + __binds: Dict[_SessionBindKey, _SessionBind] + _flushing: bool _warn_on_events: bool _transaction: Optional[SessionTransaction] _nested_transaction: Optional[SessionTransaction] @@ -1072,24 +1220,19 @@ class Session(_SessionClassMethods): expire_on_commit: bool enable_baked_queries: bool twophase: bool - _query_cls: Type[Query] + _query_cls: Type[Query[Any]] def __init__( self, - bind: Optional[Union[engine.Engine, engine.Connection]] = None, + bind: Optional[_SessionBind] = None, autoflush: bool = True, future: Literal[True] = True, expire_on_commit: bool = True, twophase: bool = False, - binds: Optional[ - Dict[ - Union[type, "Mapper", "Table"], - Union[engine.Engine, engine.Connection], - ] - ] = None, + binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None, enable_baked_queries: bool = True, info: Optional[Dict[Any, Any]] = None, - query_cls: Optional[Type[query.Query]] = None, + query_cls: Optional[Type[Query[Any]]] = None, autocommit: Literal[False] = False, ): r"""Construct a new Session. @@ -1249,23 +1392,23 @@ class Session(_SessionClassMethods): _sessions[self.hash_key] = self # used by sqlalchemy.engine.util.TransactionalContext - _trans_context_manager = None + _trans_context_manager: Optional[TransactionalContext] = None - connection_callable = None + connection_callable: Optional[_ConnectionCallableProto] = None - def __enter__(self): + def __enter__(self) -> Session: return self - def __exit__(self, type_, value, traceback): + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: self.close() @contextlib.contextmanager - def _maker_context_manager(self): + def _maker_context_manager(self) -> Iterator[Session]: with self: with self.begin(): yield self - def in_transaction(self): + def in_transaction(self) -> bool: """Return True if this :class:`_orm.Session` has begun a transaction. .. versionadded:: 1.4 @@ -1278,7 +1421,7 @@ class Session(_SessionClassMethods): """ return self._transaction is not None - def in_nested_transaction(self): + def in_nested_transaction(self) -> bool: """Return True if this :class:`_orm.Session` has begun a nested transaction, e.g. SAVEPOINT. @@ -1287,7 +1430,7 @@ class Session(_SessionClassMethods): """ return self._nested_transaction is not None - def get_transaction(self): + def get_transaction(self) -> Optional[SessionTransaction]: """Return the current root transaction in progress, if any. .. versionadded:: 1.4 @@ -1298,7 +1441,7 @@ class Session(_SessionClassMethods): trans = trans._parent return trans - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[SessionTransaction]: """Return the current nested transaction in progress, if any. .. versionadded:: 1.4 @@ -1308,7 +1451,7 @@ class Session(_SessionClassMethods): return self._nested_transaction @util.memoized_property - def info(self): + def info(self) -> Dict[Any, Any]: """A user-modifiable dictionary. The initial value of this dictionary can be populated using the @@ -1320,16 +1463,18 @@ class Session(_SessionClassMethods): """ return {} - def _autobegin(self): + def _autobegin_t(self) -> SessionTransaction: if self._transaction is None: trans = SessionTransaction(self, autobegin=True) assert self._transaction is trans - return True + return trans - return False + return self._transaction - def begin(self, nested=False, _subtrans=False): + def begin( + self, nested: bool = False, _subtrans: bool = False + ) -> SessionTransaction: """Begin a transaction, or nested transaction, on this :class:`.Session`, if one is not already begun. @@ -1364,13 +1509,16 @@ class Session(_SessionClassMethods): """ - if self._autobegin(): + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + if not nested and not _subtrans: - return self._transaction + return trans - if self._transaction is not None: + if trans is not None: if _subtrans or nested: - trans = self._transaction._begin(nested=nested) + trans = trans._begin(nested=nested) assert self._transaction is trans if nested: self._nested_transaction = trans @@ -1386,9 +1534,12 @@ class Session(_SessionClassMethods): trans = SessionTransaction(self) assert self._transaction is trans - return self._transaction # needed for __enter__/__exit__ hook + if TYPE_CHECKING: + assert self._transaction is not None + + return trans # needed for __enter__/__exit__ hook - def begin_nested(self): + def begin_nested(self) -> SessionTransaction: """Begin a "nested" transaction on this Session, e.g. SAVEPOINT. The target database(s) and associated drivers must support SQL @@ -1413,7 +1564,7 @@ class Session(_SessionClassMethods): """ return self.begin(nested=True) - def rollback(self): + def rollback(self) -> None: """Rollback the current transaction in progress. If no transaction is in progress, this method is a pass-through. @@ -1450,11 +1601,11 @@ class Session(_SessionClassMethods): :ref:`unitofwork_transaction` """ - if self._transaction is None: - if not self._autobegin(): - raise sa_exc.InvalidRequestError("No transaction is begun.") + trans = self._transaction + if trans is None: + trans = self._autobegin_t() - self._transaction.commit(_to_root=True) + trans.commit(_to_root=True) def prepare(self) -> None: """Prepare the current transaction in progress for two phase commit. @@ -1467,16 +1618,16 @@ class Session(_SessionClassMethods): :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. """ - if self._transaction is None: - if not self._autobegin(): - raise sa_exc.InvalidRequestError("No transaction is begun.") + trans = self._transaction + if trans is None: + trans = self._autobegin_t() - self._transaction.prepare() + trans.prepare() def connection( self, bind_arguments: Optional[Dict[str, Any]] = None, - execution_options: Optional["_ExecuteOptions"] = None, + execution_options: Optional[_ExecuteOptions] = None, ) -> "Connection": r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -1521,24 +1672,28 @@ class Session(_SessionClassMethods): execution_options=execution_options, ) - def _connection_for_bind(self, engine, execution_options=None, **kw): + def _connection_for_bind( + self, + engine: _SessionBind, + execution_options: Optional[_ExecuteOptions] = None, + **kw: Any, + ) -> Connection: TransactionalContext._trans_ctx_check(self) - if self._transaction is None: - assert self._autobegin() - return self._transaction._connection_for_bind( - engine, execution_options - ) + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + return trans._connection_for_bind(engine, execution_options) def execute( self, - statement: "Executable", - params: Optional["_ExecuteParams"] = None, - execution_options: "_ExecuteOptions" = util.EMPTY_DICT, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ): + ) -> Result: r"""Execute a SQL expression construct. Returns a :class:`_engine.Result` object representing @@ -1603,6 +1758,8 @@ class Session(_SessionClassMethods): compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) + if TYPE_CHECKING: + assert isinstance(compile_state_cls, ORMCompileState) else: compile_state_cls = None @@ -1645,9 +1802,9 @@ class Session(_SessionClassMethods): ) for idx, fn in enumerate(events_todo): orm_exec_state._starting_event_idx = idx - result = fn(orm_exec_state) - if result: - return result + fn_result: Optional[Result] = fn(orm_exec_state) + if fn_result: + return fn_result statement = orm_exec_state.statement execution_options = orm_exec_state.local_execution_options @@ -1655,7 +1812,9 @@ class Session(_SessionClassMethods): bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind) - result = conn.execute(statement, params or {}, execution_options) + result: Result = conn.execute( + statement, params or {}, execution_options + ) if compile_state_cls: result = compile_state_cls.orm_setup_cursor_result( @@ -1671,12 +1830,12 @@ class Session(_SessionClassMethods): def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + **kw: Any, + ) -> Any: """Execute a statement and return a scalar result. Usage and parameters are the same as that of @@ -1695,12 +1854,12 @@ class Session(_SessionClassMethods): def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + **kw: Any, + ) -> ScalarResult[Any]: """Execute a statement and return the results as scalars. Usage and parameters are the same as that of @@ -1722,7 +1881,7 @@ class Session(_SessionClassMethods): **kw, ).scalars() - def close(self): + def close(self) -> None: """Close out the transactional resources and ORM objects used by this :class:`_orm.Session`. @@ -1754,7 +1913,7 @@ class Session(_SessionClassMethods): """ self._close_impl(invalidate=False) - def invalidate(self): + def invalidate(self) -> None: """Close this Session, using connection invalidation. This is a variant of :meth:`.Session.close` that will additionally @@ -1790,13 +1949,13 @@ class Session(_SessionClassMethods): """ self._close_impl(invalidate=True) - def _close_impl(self, invalidate): + def _close_impl(self, invalidate: bool) -> None: self.expunge_all() if self._transaction is not None: for transaction in self._transaction._iterate_self_and_parents(): transaction.close(invalidate) - def expunge_all(self): + def expunge_all(self) -> None: """Remove all object instances from this ``Session``. This is equivalent to calling ``expunge(obj)`` on all objects in this @@ -1812,7 +1971,7 @@ class Session(_SessionClassMethods): statelib.InstanceState._detach_states(all_states, self) - def _add_bind(self, key, bind): + def _add_bind(self, key: _SessionBindKey, bind: _SessionBind) -> None: try: insp = inspect(key) except sa_exc.NoInspectionAvailable as err: @@ -1834,7 +1993,9 @@ class Session(_SessionClassMethods): "Not an acceptable bind target: %s" % key ) - def bind_mapper(self, mapper, bind): + def bind_mapper( + self, mapper: _EntityBindKey[_O], bind: _SessionBind + ) -> None: """Associate a :class:`_orm.Mapper` or arbitrary Python class with a "bind", e.g. an :class:`_engine.Engine` or :class:`_engine.Connection`. @@ -1862,7 +2023,7 @@ class Session(_SessionClassMethods): """ self._add_bind(mapper, bind) - def bind_table(self, table, bind): + def bind_table(self, table: Table, bind: _SessionBind) -> None: """Associate a :class:`_schema.Table` with a "bind", e.g. an :class:`_engine.Engine` or :class:`_engine.Connection`. @@ -1892,12 +2053,12 @@ class Session(_SessionClassMethods): def get_bind( self, - mapper=None, - clause=None, - bind=None, - _sa_skip_events=None, - _sa_skip_for_implicit_returning=False, - ): + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + _sa_skip_events: Optional[bool] = None, + _sa_skip_for_implicit_returning: bool = False, + ) -> Union[Engine, Connection]: """Return a "bind" to which this :class:`.Session` is bound. The "bind" is usually an instance of :class:`_engine.Engine`, @@ -1995,23 +2156,25 @@ class Session(_SessionClassMethods): # look more closely at the mapper. if mapper is not None: try: - mapper = inspect(mapper) + inspected_mapper = inspect(mapper) except sa_exc.NoInspectionAvailable as err: if isinstance(mapper, type): raise exc.UnmappedClassError(mapper) from err else: raise + else: + inspected_mapper = None # match up the mapper or clause in the __binds if self.__binds: # matching mappers and selectables to entries in the # binds dictionary; supported use case. - if mapper: - for cls in mapper.class_.__mro__: + if inspected_mapper: + for cls in inspected_mapper.class_.__mro__: if cls in self.__binds: return self.__binds[cls] if clause is None: - clause = mapper.persist_selectable + clause = inspected_mapper.persist_selectable if clause is not None: plugin_subject = clause._propagate_attrs.get( @@ -2025,6 +2188,8 @@ class Session(_SessionClassMethods): for obj in visitors.iterate(clause): if obj in self.__binds: + if TYPE_CHECKING: + assert isinstance(obj, Table) return self.__binds[obj] # none of the __binds matched, but we have a fallback bind. @@ -2033,17 +2198,19 @@ class Session(_SessionClassMethods): return self.bind context = [] - if mapper is not None: - context.append("mapper %s" % mapper) + if inspected_mapper is not None: + context.append(f"mapper {inspected_mapper}") if clause is not None: context.append("SQL expression") raise sa_exc.UnboundExecutionError( - "Could not locate a bind configured on %s or this Session." - % (", ".join(context),), + f"Could not locate a bind configured on " + f'{", ".join(context)} or this Session.' ) - def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query: + def query( + self, *entities: _ColumnsClauseArgument, **kwargs: Any + ) -> Query[Any]: """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2065,12 +2232,12 @@ class Session(_SessionClassMethods): def _identity_lookup( self, - mapper, - primary_key_identity, - identity_token=None, - passive=attributes.PASSIVE_OFF, - lazy_loaded_from=None, - ): + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Any = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + ) -> Union[Optional[_O], LoaderCallableStatus]: """Locate an object in the identity map. Given a primary key identity, constructs an identity key and then @@ -2117,9 +2284,9 @@ class Session(_SessionClassMethods): ) return loading.get_from_identity(self, mapper, key, passive) - @property + @util.non_memoized_property @contextlib.contextmanager - def no_autoflush(self): + def no_autoflush(self) -> Iterator[Session]: """Return a context manager that disables autoflush. e.g.:: @@ -2145,7 +2312,7 @@ class Session(_SessionClassMethods): finally: self.autoflush = autoflush - def _autoflush(self): + def _autoflush(self) -> None: if self.autoflush and not self._flushing: try: self.flush() @@ -2161,7 +2328,12 @@ class Session(_SessionClassMethods): ) raise e.with_traceback(sys.exc_info()[2]) - def refresh(self, instance, attribute_names=None, with_for_update=None): + def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: """Expire and refresh attributes on the given instance. The selected attributes will first be expired as they would when using @@ -2233,7 +2405,7 @@ class Session(_SessionClassMethods): "A blank dictionary is ambiguous." ) - with_for_update = query.ForUpdateArg._from_argument(with_for_update) + with_for_update = ForUpdateArg._from_argument(with_for_update) stmt = sql.select(object_mapper(instance)) if ( @@ -2251,7 +2423,7 @@ class Session(_SessionClassMethods): "Could not refresh instance '%s'" % instance_str(instance) ) - def expire_all(self): + def expire_all(self) -> None: """Expires all persistent instances within this Session. When any attributes on a persistent instance is next accessed, @@ -2286,7 +2458,9 @@ class Session(_SessionClassMethods): for state in self.identity_map.all_states(): state._expire(state.dict, self.identity_map._modified) - def expire(self, instance, attribute_names=None): + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: """Expire the attributes on an instance. Marks the attributes of an instance as out of date. When an expired @@ -2329,7 +2503,11 @@ class Session(_SessionClassMethods): raise exc.UnmappedInstanceError(instance) from err self._expire_state(state, attribute_names) - def _expire_state(self, state, attribute_names): + def _expire_state( + self, + state: InstanceState[Any], + attribute_names: Optional[Iterable[str]], + ) -> None: self._validate_persistent(state) if attribute_names: state._expire_attributes(state.dict, attribute_names) @@ -2343,7 +2521,9 @@ class Session(_SessionClassMethods): for o, m, st_, dct_ in cascaded: self._conditional_expire(st_) - def _conditional_expire(self, state, autoflush=None): + def _conditional_expire( + self, state: InstanceState[Any], autoflush: Optional[bool] = None + ) -> None: """Expire a state if persistent, else expunge if pending""" if state.key: @@ -2352,7 +2532,7 @@ class Session(_SessionClassMethods): self._new.pop(state) state._detach(self) - def expunge(self, instance): + def expunge(self, instance: object) -> None: """Remove the `instance` from this ``Session``. This will free all internal references to the instance. Cascading @@ -2373,7 +2553,9 @@ class Session(_SessionClassMethods): ) self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded]) - def _expunge_states(self, states, to_transient=False): + def _expunge_states( + self, states: Iterable[InstanceState[Any]], to_transient: bool = False + ) -> None: for state in states: if state in self._new: self._new.pop(state) @@ -2388,7 +2570,7 @@ class Session(_SessionClassMethods): states, self, to_transient=to_transient ) - def _register_persistent(self, states): + def _register_persistent(self, states: Set[InstanceState[Any]]) -> None: """Register all persistent objects from a flush. This is used both for pending objects moving to the persistent @@ -2429,11 +2611,13 @@ class Session(_SessionClassMethods): # state has already replaced this one in the identity # map (see test/orm/test_naturalpks.py ReversePKsTest) self.identity_map.safe_discard(state) - if state in self._transaction._key_switches: - orig_key = self._transaction._key_switches[state][0] + trans = self._transaction + assert trans is not None + if state in trans._key_switches: + orig_key = trans._key_switches[state][0] else: orig_key = state.key - self._transaction._key_switches[state] = ( + trans._key_switches[state] = ( orig_key, instance_key, ) @@ -2470,7 +2654,7 @@ class Session(_SessionClassMethods): for state in set(states).intersection(self._new): self._new.pop(state) - def _register_altered(self, states): + def _register_altered(self, states: Iterable[InstanceState[Any]]) -> None: if self._transaction: for state in states: if state in self._new: @@ -2478,7 +2662,9 @@ class Session(_SessionClassMethods): else: self._transaction._dirty[state] = True - def _remove_newly_deleted(self, states): + def _remove_newly_deleted( + self, states: Iterable[InstanceState[Any]] + ) -> None: persistent_to_deleted = self.dispatch.persistent_to_deleted or None for state in states: if self._transaction: @@ -2498,7 +2684,7 @@ class Session(_SessionClassMethods): if persistent_to_deleted is not None: persistent_to_deleted(self, state) - def add(self, instance: Any, _warn: bool = True) -> None: + def add(self, instance: object, _warn: bool = True) -> None: """Place an object in the ``Session``. Its state will be persisted to the database on the next flush @@ -2518,7 +2704,7 @@ class Session(_SessionClassMethods): self._save_or_update_state(state) - def add_all(self, instances): + def add_all(self, instances: Iterable[object]) -> None: """Add the given collection of instances to this ``Session``.""" if self._warn_on_events: @@ -2527,7 +2713,7 @@ class Session(_SessionClassMethods): for instance in instances: self.add(instance, _warn=False) - def _save_or_update_state(self, state): + def _save_or_update_state(self, state: InstanceState[Any]) -> None: state._orphaned_outside_of_session = False self._save_or_update_impl(state) @@ -2537,7 +2723,7 @@ class Session(_SessionClassMethods): ): self._save_or_update_impl(st_) - def delete(self, instance): + def delete(self, instance: object) -> None: """Mark an instance as deleted. The database delete operation occurs upon ``flush()``. @@ -2553,7 +2739,9 @@ class Session(_SessionClassMethods): self._delete_impl(state, instance, head=True) - def _delete_impl(self, state, obj, head): + def _delete_impl( + self, state: InstanceState[Any], obj: object, head: bool + ) -> None: if state.key is None: if head: @@ -2580,23 +2768,28 @@ class Session(_SessionClassMethods): cascade_states = list( state.manager.mapper.cascade_iterator("delete", state) ) + else: + cascade_states = None self._deleted[state] = obj if head: + if TYPE_CHECKING: + assert cascade_states is not None for o, m, st_, dct_ in cascade_states: self._delete_impl(st_, o, False) def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - execution_options=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -2696,7 +2889,7 @@ class Session(_SessionClassMethods): entity, ident, loading.load_on_pk_identity, - options, + options=options, populate_existing=populate_existing, with_for_update=with_for_update, identity_token=identity_token, @@ -2705,23 +2898,24 @@ class Session(_SessionClassMethods): def _get_impl( self, - entity, - primary_key_identity, - db_load_fn, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - execution_options=None, - ): + entity: _EntityBindKey[_O], + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., _O], + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Optional[_O]: # convert composite types to individual args - if hasattr(primary_key_identity, "__composite_values__"): + if is_composite_class(primary_key_identity): primary_key_identity = primary_key_identity.__composite_values__() - mapper = inspect(entity) + mapper: Optional[Mapper[_O]] = inspect(entity) - if not mapper or not mapper.is_mapper: + if mapper is None or not mapper.is_mapper: raise sa_exc.ArgumentError( "Expected mapped class or mapper, got: %r" % entity ) @@ -2729,7 +2923,7 @@ class Session(_SessionClassMethods): is_dict = isinstance(primary_key_identity, dict) if not is_dict: primary_key_identity = util.to_list( - primary_key_identity, default=(None,) + primary_key_identity, default=[None] ) if len(primary_key_identity) != len(mapper.primary_key): @@ -2770,11 +2964,12 @@ class Session(_SessionClassMethods): if instance is not None: # reject calls for id in identity map but class # mismatch. - if not issubclass(instance.__class__, mapper.class_): + if not isinstance(instance, mapper.class_): return None return instance - elif instance is attributes.PASSIVE_CLASS_MISMATCH: - return None + + # TODO: this was being tested before, but this is not possible + assert instance is not LoaderCallableStatus.PASSIVE_CLASS_MISMATCH # set_label_style() not strictly necessary, however this will ensure # that tablename_colname style is used which at the moment is @@ -2788,7 +2983,7 @@ class Session(_SessionClassMethods): LABEL_STYLE_TABLENAME_PLUS_COL ) if with_for_update is not None: - statement._for_update_arg = query.ForUpdateArg._from_argument( + statement._for_update_arg = ForUpdateArg._from_argument( with_for_update ) @@ -2803,7 +2998,13 @@ class Session(_SessionClassMethods): load_options=load_options, ) - def merge(self, instance, load=True, options=None): + def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: """Copy the state of a given instance into a corresponding instance within this :class:`.Session`. @@ -2866,8 +3067,8 @@ class Session(_SessionClassMethods): if self._warn_on_events: self._flush_warning("Session.merge()") - _recursive = {} - _resolve_conflict_map = {} + _recursive: Dict[InstanceState[Any], object] = {} + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {} if load: # flush current contents if we expect to load data @@ -2890,20 +3091,23 @@ class Session(_SessionClassMethods): def _merge( self, - state, - state_dict, - load=True, - options=None, - _recursive=None, - _resolve_conflict_map=None, - ): + state: InstanceState[_O], + state_dict: _InstanceDict, + *, + options: Optional[Sequence[ORMOption]] = None, + load: bool, + _recursive: Dict[InstanceState[Any], object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> _O: mapper = _state_mapper(state) if state in _recursive: - return _recursive[state] + return cast(_O, _recursive[state]) new_instance = False key = state.key + merged: Optional[_O] + if key is None: if state in self._new: util.warn( @@ -2920,7 +3124,9 @@ class Session(_SessionClassMethods): "load=False." ) key = mapper._identity_key_from_state(state) - key_is_persistent = attributes.NEVER_SET not in key[1] and ( + key_is_persistent = LoaderCallableStatus.NEVER_SET not in key[ + 1 + ] and ( not _none_set.intersection(key[1]) or ( mapper.allow_partial_pks @@ -2941,7 +3147,7 @@ class Session(_SessionClassMethods): if merged is None: if key_is_persistent and key in _resolve_conflict_map: - merged = _resolve_conflict_map[key] + merged = cast(_O, _resolve_conflict_map[key]) elif not load: if state.modified: @@ -2986,19 +3192,21 @@ class Session(_SessionClassMethods): state, state_dict, mapper.version_id_col, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) merged_version = mapper._get_state_attr_by_column( merged_state, merged_dict, mapper.version_id_col, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if ( - existing_version is not attributes.PASSIVE_NO_RESULT - and merged_version is not attributes.PASSIVE_NO_RESULT + existing_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT + and merged_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT and existing_version != merged_version ): raise exc.StaleDataError( @@ -3043,14 +3251,14 @@ class Session(_SessionClassMethods): merged_state.manager.dispatch.load(merged_state, None) return merged - def _validate_persistent(self, state): + def _validate_persistent(self, state: InstanceState[Any]) -> None: if not self.identity_map.contains_state(state): raise sa_exc.InvalidRequestError( "Instance '%s' is not persistent within this Session" % state_str(state) ) - def _save_impl(self, state): + def _save_impl(self, state: InstanceState[Any]) -> None: if state.key is not None: raise sa_exc.InvalidRequestError( "Object '%s' already has an identity - " @@ -3065,7 +3273,9 @@ class Session(_SessionClassMethods): if to_attach: self._after_attach(state, obj) - def _update_impl(self, state, revert_deletion=False): + def _update_impl( + self, state: InstanceState[Any], revert_deletion: bool = False + ) -> None: if state.key is None: raise sa_exc.InvalidRequestError( "Instance '%s' is not persisted" % state_str(state) @@ -3103,13 +3313,13 @@ class Session(_SessionClassMethods): elif revert_deletion: self.dispatch.deleted_to_persistent(self, state) - def _save_or_update_impl(self, state): + def _save_or_update_impl(self, state: InstanceState[Any]) -> None: if state.key is None: self._save_impl(state) else: self._update_impl(state) - def enable_relationship_loading(self, obj): + def enable_relationship_loading(self, obj: object) -> None: """Associate an object with this :class:`.Session` for related object loading. @@ -3174,8 +3384,8 @@ class Session(_SessionClassMethods): if to_attach: self._after_attach(state, obj) - def _before_attach(self, state, obj): - self._autobegin() + def _before_attach(self, state: InstanceState[Any], obj: object) -> bool: + self._autobegin_t() if state.session_id == self.hash_key: return False @@ -3191,7 +3401,7 @@ class Session(_SessionClassMethods): return True - def _after_attach(self, state, obj): + def _after_attach(self, state: InstanceState[Any], obj: object) -> None: state.session_id = self.hash_key if state.modified and state._strong_obj is None: state._strong_obj = obj @@ -3202,7 +3412,7 @@ class Session(_SessionClassMethods): else: self.dispatch.transient_to_pending(self, state) - def __contains__(self, instance): + def __contains__(self, instance: object) -> bool: """Return True if the instance is associated with this session. The instance may be pending or persistent within the Session for a @@ -3215,7 +3425,7 @@ class Session(_SessionClassMethods): raise exc.UnmappedInstanceError(instance) from err return self._contains_state(state) - def __iter__(self): + def __iter__(self) -> Iterator[object]: """Iterate over all pending or persistent instances within this Session. @@ -3224,10 +3434,10 @@ class Session(_SessionClassMethods): list(self._new.values()) + list(self.identity_map.values()) ) - def _contains_state(self, state): + def _contains_state(self, state: InstanceState[Any]) -> bool: return state in self._new or self.identity_map.contains_state(state) - def flush(self, objects=None): + def flush(self, objects: Optional[Sequence[Any]] = None) -> None: """Flush all the object changes to the database. Writes out all pending object creations, deletions and modifications @@ -3261,7 +3471,7 @@ class Session(_SessionClassMethods): finally: self._flushing = False - def _flush_warning(self, method): + def _flush_warning(self, method: Any) -> None: util.warn( "Usage of the '%s' operation is not currently supported " "within the execution stage of the flush process. " @@ -3269,14 +3479,14 @@ class Session(_SessionClassMethods): "event listeners or connection-level operations instead." % method ) - def _is_clean(self): + def _is_clean(self) -> bool: return ( not self.identity_map.check_modified() and not self._deleted and not self._new ) - def _flush(self, objects=None): + def _flush(self, objects: Optional[Sequence[object]] = None) -> None: dirty = self._dirty_states if not dirty and not self._deleted and not self._new: @@ -3398,11 +3608,11 @@ class Session(_SessionClassMethods): def bulk_save_objects( self, - objects, - return_defaults=False, - update_changed_only=True, - preserve_order=True, - ): + objects: Iterable[object], + return_defaults: bool = False, + update_changed_only: bool = True, + preserve_order: bool = True, + ) -> None: """Perform a bulk save of the given list of objects. The bulk save feature allows mapped objects to be used as the @@ -3496,6 +3706,8 @@ class Session(_SessionClassMethods): """ + obj_states: Iterable[InstanceState[Any]] + obj_states = (attributes.instance_state(obj) for obj in objects) if not preserve_order: @@ -3508,7 +3720,9 @@ class Session(_SessionClassMethods): key=lambda state: (id(state.mapper), state.key is not None), ) - def grouping_key(state): + def grouping_key( + state: InstanceState[_O], + ) -> Tuple[Mapper[_O], bool]: return (state.mapper, state.key is not None) for (mapper, isupdate), states in itertools.groupby( @@ -3525,8 +3739,12 @@ class Session(_SessionClassMethods): ) def bulk_insert_mappings( - self, mapper, mappings, return_defaults=False, render_nulls=False - ): + self, + mapper: Mapper[Any], + mappings: Iterable[Dict[str, Any]], + return_defaults: bool = False, + render_nulls: bool = False, + ) -> None: """Perform a bulk insert of the given list of mapping dictionaries. The bulk insert feature allows plain Python dictionaries to be used as @@ -3633,7 +3851,9 @@ class Session(_SessionClassMethods): render_nulls, ) - def bulk_update_mappings(self, mapper, mappings): + def bulk_update_mappings( + self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] + ) -> None: """Perform a bulk update of the given list of mapping dictionaries. The bulk update feature allows plain Python dictionaries to be used as @@ -3696,14 +3916,14 @@ class Session(_SessionClassMethods): def _bulk_save_mappings( self, - mapper, - mappings, - isupdate, - isstates, - return_defaults, - update_changed_only, - render_nulls, - ): + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + isupdate: bool, + isstates: bool, + return_defaults: bool, + update_changed_only: bool, + render_nulls: bool, + ) -> None: mapper = _class_to_mapper(mapper) self._flushing = True @@ -3734,7 +3954,9 @@ class Session(_SessionClassMethods): finally: self._flushing = False - def is_modified(self, instance, include_collections=True): + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: r"""Return ``True`` if the given instance has locally modified attributes. @@ -3800,7 +4022,7 @@ class Session(_SessionClassMethods): continue (added, unchanged, deleted) = attr.impl.get_history( - state, dict_, passive=attributes.NO_CHANGE + state, dict_, passive=PassiveFlag.NO_CHANGE ) if added or deleted: @@ -3809,7 +4031,7 @@ class Session(_SessionClassMethods): return False @property - def is_active(self): + def is_active(self) -> bool: """True if this :class:`.Session` not in "partial rollback" state. .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins @@ -3838,22 +4060,8 @@ class Session(_SessionClassMethods): """ return self._transaction is None or self._transaction.is_active - identity_map = None - """A mapping of object identities to objects themselves. - - Iterating through ``Session.identity_map.values()`` provides - access to the full set of persistent objects (i.e., those - that have row identity) currently in the session. - - .. seealso:: - - :func:`.identity_key` - helper function to produce the keys used - in this dictionary. - - """ - @property - def _dirty_states(self): + def _dirty_states(self) -> Iterable[InstanceState[Any]]: """The set of all persistent states considered dirty. This method returns all states that were modified including @@ -3863,7 +4071,7 @@ class Session(_SessionClassMethods): return self.identity_map._dirty_states() @property - def dirty(self): + def dirty(self) -> IdentitySet: """The set of all persistent instances considered dirty. E.g.:: @@ -3886,7 +4094,7 @@ class Session(_SessionClassMethods): attributes, use the :meth:`.Session.is_modified` method. """ - return util.IdentitySet( + return IdentitySet( [ state.obj() for state in self._dirty_states @@ -3895,13 +4103,13 @@ class Session(_SessionClassMethods): ) @property - def deleted(self): + def deleted(self) -> IdentitySet: "The set of all instances marked as 'deleted' within this ``Session``" return util.IdentitySet(list(self._deleted.values())) @property - def new(self): + def new(self) -> IdentitySet: "The set of all instances marked as 'new' within this ``Session``." return util.IdentitySet(list(self._new.values())) @@ -4002,14 +4210,16 @@ class sessionmaker(_SessionClassMethods): """ + class_: Type[Session] + def __init__( self, - bind=None, - class_=Session, - autoflush=True, - expire_on_commit=True, - info=None, - **kw, + bind: Optional[_SessionBind] = None, + class_: Type[Session] = Session, + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[Dict[Any, Any]] = None, + **kw: Any, ): r"""Construct a new :class:`.sessionmaker`. @@ -4052,7 +4262,7 @@ class sessionmaker(_SessionClassMethods): # events can be associated with it specifically. self.class_ = type(class_.__name__, (class_,), {}) - def begin(self): + def begin(self) -> contextlib.AbstractContextManager[Session]: """Produce a context manager that both provides a new :class:`_orm.Session` as well as a transaction that commits. @@ -4074,7 +4284,7 @@ class sessionmaker(_SessionClassMethods): session = self() return session._maker_context_manager() - def __call__(self, **local_kw): + def __call__(self, **local_kw: Any) -> Session: """Produce a new :class:`.Session` object using the configuration established in this :class:`.sessionmaker`. @@ -4094,7 +4304,7 @@ class sessionmaker(_SessionClassMethods): local_kw.setdefault(k, v) return self.class_(**local_kw) - def configure(self, **new_kw): + def configure(self, **new_kw: Any) -> None: """(Re)configure the arguments for this sessionmaker. e.g.:: @@ -4105,7 +4315,7 @@ class sessionmaker(_SessionClassMethods): """ self.kw.update(new_kw) - def __repr__(self): + def __repr__(self) -> str: return "%s(class_=%r, %s)" % ( self.__class__.__name__, self.class_.__name__, @@ -4113,7 +4323,7 @@ class sessionmaker(_SessionClassMethods): ) -def close_all_sessions(): +def close_all_sessions() -> None: """Close all sessions in memory. This function consults a global registry of all :class:`.Session` objects @@ -4131,7 +4341,7 @@ def close_all_sessions(): sess.close() -def make_transient(instance): +def make_transient(instance: object) -> None: """Alter the state of the given instance so that it is :term:`transient`. .. note:: @@ -4195,7 +4405,7 @@ def make_transient(instance): del state._deleted -def make_transient_to_detached(instance): +def make_transient_to_detached(instance: object) -> None: """Make the given transient instance :term:`detached`. .. note:: @@ -4234,7 +4444,7 @@ def make_transient_to_detached(instance): state._expire_attributes(state.dict, state.unloaded_expirable) -def object_session(instance): +def object_session(instance: object) -> Optional[Session]: """Return the :class:`.Session` to which the given instance belongs. This is essentially the same as the :attr:`.InstanceState.session` diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index c3e4e299ab..7ccda95659 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -14,13 +14,25 @@ defines a large part of the ORM's interactivity. from __future__ import annotations +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar import weakref from . import base from . import exc as orm_exc from . import interfaces +from ._typing import is_collection_impl from .base import ATTR_WAS_SET from .base import INIT_OK +from .base import LoaderCallableStatus from .base import NEVER_SET from .base import NO_VALUE from .base import PASSIVE_NO_INITIALIZE @@ -31,17 +43,47 @@ from .path_registry import PathRegistry from .. import exc as sa_exc from .. import inspection from .. import util +from ..util.typing import Protocol +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _LoaderCallable + from .attributes import AttributeImpl + from .attributes import History + from .base import LoaderCallableStatus + from .base import PassiveFlag + from .identity import IdentityMap + from .instrumentation import ClassManager + from .interfaces import ORMOption + from .mapper import Mapper + from .session import Session + from ..engine import Row + from ..ext.asyncio.session import async_session as _async_provider + from ..ext.asyncio.session import AsyncSession -# late-populated by session.py -_sessions = None +_T = TypeVar("_T", bound=Any) -# optionally late-provided by sqlalchemy.ext.asyncio.session -_async_provider = None +if TYPE_CHECKING: + _sessions: weakref.WeakValueDictionary[int, Session] +else: + # late-populated by session.py + _sessions = None + + +if not TYPE_CHECKING: + # optionally late-provided by sqlalchemy.ext.asyncio.session + + _async_provider = None # noqa + + +class _InstanceDictProto(Protocol): + def __call__(self) -> Optional[IdentityMap]: + ... @inspection._self_inspects -class InstanceState(interfaces.InspectionAttrInfo): +class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): """tracks state information at the instance level. The :class:`.InstanceState` is a key object used by the @@ -67,23 +109,57 @@ class InstanceState(interfaces.InspectionAttrInfo): """ - session_id = None - key = None - runid = None - load_options = () - load_path = PathRegistry.root - insert_order = None - _strong_obj = None - modified = False - expired = False - _deleted = False - _load_pending = False - _orphaned_outside_of_session = False - is_instance = True - identity_token = None - _last_known_values = () - - callables = () + __slots__ = ( + "__dict__", + "__weakref__", + "class_", + "manager", + "obj", + "committed_state", + "expired_attributes", + ) + + manager: ClassManager[_T] + session_id: Optional[int] = None + key: Optional[_IdentityKeyType[_T]] = None + runid: Optional[int] = None + load_options: Tuple[ORMOption, ...] = () + load_path: PathRegistry = PathRegistry.root + insert_order: Optional[int] = None + _strong_obj: Optional[object] = None + obj: weakref.ref[_T] + + committed_state: Dict[str, Any] + + modified: bool = False + expired: bool = False + _deleted: bool = False + _load_pending: bool = False + _orphaned_outside_of_session: bool = False + is_instance: bool = True + identity_token: object = None + _last_known_values: Optional[Dict[str, Any]] = None + + _instance_dict: _InstanceDictProto + """A weak reference, or in the default case a plain callable, that + returns a reference to the current :class:`.IdentityMap`, if any. + + """ + if not TYPE_CHECKING: + + def _instance_dict(self): + """default 'weak reference' for _instance_dict""" + return None + + expired_attributes: Set[str] + """The set of keys which are 'expired' to be loaded by + the manager's deferred scalar loader, assuming no pending + changes. + + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs.""" + + callables: Dict[str, Callable[[InstanceState[_T], PassiveFlag], Any]] """A namespace where a per-state loader callable can be associated. In SQLAlchemy 1.0, this is only used for lazy loaders / deferred @@ -95,23 +171,18 @@ class InstanceState(interfaces.InspectionAttrInfo): """ - def __init__(self, obj, manager): + if not TYPE_CHECKING: + callables = util.EMPTY_DICT + + def __init__(self, obj: _T, manager: ClassManager[_T]): self.class_ = obj.__class__ self.manager = manager self.obj = weakref.ref(obj, self._cleanup) self.committed_state = {} self.expired_attributes = set() - expired_attributes = None - """The set of keys which are 'expired' to be loaded by - the manager's deferred scalar loader, assuming no pending - changes. - - see also the ``unmodified`` collection which is intersected - against this set when a refresh operation occurs.""" - @util.memoized_property - def attrs(self): + def attrs(self) -> util.ReadOnlyProperties[AttributeState]: """Return a namespace representing each attribute on the mapped object, including its current value and history. @@ -123,11 +194,11 @@ class InstanceState(interfaces.InspectionAttrInfo): """ return util.ReadOnlyProperties( - dict((key, AttributeState(self, key)) for key in self.manager) + {key: AttributeState(self, key) for key in self.manager} ) @property - def transient(self): + def transient(self) -> bool: """Return ``True`` if the object is :term:`transient`. .. seealso:: @@ -138,7 +209,7 @@ class InstanceState(interfaces.InspectionAttrInfo): return self.key is None and not self._attached @property - def pending(self): + def pending(self) -> bool: """Return ``True`` if the object is :term:`pending`. @@ -150,7 +221,7 @@ class InstanceState(interfaces.InspectionAttrInfo): return self.key is None and self._attached @property - def deleted(self): + def deleted(self) -> bool: """Return ``True`` if the object is :term:`deleted`. An object that is in the deleted state is guaranteed to @@ -180,7 +251,7 @@ class InstanceState(interfaces.InspectionAttrInfo): return self.key is not None and self._attached and self._deleted @property - def was_deleted(self): + def was_deleted(self) -> bool: """Return True if this object is or was previously in the "deleted" state and has not been reverted to persistent. @@ -204,7 +275,7 @@ class InstanceState(interfaces.InspectionAttrInfo): return self._deleted @property - def persistent(self): + def persistent(self) -> bool: """Return ``True`` if the object is :term:`persistent`. An object that is in the persistent state is guaranteed to @@ -225,7 +296,7 @@ class InstanceState(interfaces.InspectionAttrInfo): return self.key is not None and self._attached and not self._deleted @property - def detached(self): + def detached(self) -> bool: """Return ``True`` if the object is :term:`detached`. .. seealso:: @@ -235,15 +306,15 @@ class InstanceState(interfaces.InspectionAttrInfo): """ return self.key is not None and not self._attached - @property + @util.non_memoized_property @util.preload_module("sqlalchemy.orm.session") - def _attached(self): + def _attached(self) -> bool: return ( self.session_id is not None and self.session_id in util.preloaded.orm_session._sessions ) - def _track_last_known_value(self, key): + def _track_last_known_value(self, key: str) -> None: """Track the last known value of a particular key after expiration operations. @@ -251,12 +322,14 @@ class InstanceState(interfaces.InspectionAttrInfo): """ - if key not in self._last_known_values: - self._last_known_values = dict(self._last_known_values) - self._last_known_values[key] = NO_VALUE + lkv = self._last_known_values + if lkv is None: + self._last_known_values = lkv = {} + if key not in lkv: + lkv[key] = NO_VALUE @property - def session(self): + def session(self) -> Optional[Session]: """Return the owning :class:`.Session` for this instance, or ``None`` if none available. @@ -280,7 +353,7 @@ class InstanceState(interfaces.InspectionAttrInfo): return None @property - def async_session(self): + def async_session(self) -> Optional[AsyncSession]: """Return the owning :class:`_asyncio.AsyncSession` for this instance, or ``None`` if none available. @@ -308,13 +381,17 @@ class InstanceState(interfaces.InspectionAttrInfo): return None @property - def object(self): + def object(self) -> Optional[_T]: """Return the mapped object represented by this - :class:`.InstanceState`.""" + :class:`.InstanceState`. + + Returns None if the object has been garbage collected + + """ return self.obj() @property - def identity(self): + def identity(self) -> Optional[Tuple[Any, ...]]: """Return the mapped identity of the mapped object. This is the primary key identity as persisted by the ORM which can always be passed directly to @@ -334,7 +411,7 @@ class InstanceState(interfaces.InspectionAttrInfo): return self.key[1] @property - def identity_key(self): + def identity_key(self) -> Optional[_IdentityKeyType[_T]]: """Return the identity key for the mapped object. This is the key used to locate the object within @@ -343,29 +420,27 @@ class InstanceState(interfaces.InspectionAttrInfo): """ - # TODO: just change .key to .identity_key across - # the board ? probably return self.key @util.memoized_property - def parents(self): + def parents(self) -> Dict[int, InstanceState[Any]]: return {} @util.memoized_property - def _pending_mutations(self): + def _pending_mutations(self) -> Dict[str, PendingCollection]: return {} @util.memoized_property - def _empty_collections(self): + def _empty_collections(self) -> Dict[Any, Any]: return {} @util.memoized_property - def mapper(self): + def mapper(self) -> Mapper[_T]: """Return the :class:`_orm.Mapper` used for this mapped object.""" return self.manager.mapper @property - def has_identity(self): + def has_identity(self) -> bool: """Return ``True`` if this object has an identity key. This should always have the same value as the @@ -375,7 +450,12 @@ class InstanceState(interfaces.InspectionAttrInfo): return bool(self.key) @classmethod - def _detach_states(self, states, session, to_transient=False): + def _detach_states( + self, + states: Iterable[InstanceState[_T]], + session: Session, + to_transient: bool = False, + ) -> None: persistent_to_detached = ( session.dispatch.persistent_to_detached or None ) @@ -407,17 +487,17 @@ class InstanceState(interfaces.InspectionAttrInfo): state._strong_obj = None - def _detach(self, session=None): + def _detach(self, session: Optional[Session] = None) -> None: if session: InstanceState._detach_states([self], session) else: self.session_id = self._strong_obj = None - def _dispose(self): + def _dispose(self) -> None: + # used by the test suite, apparently self._detach() - del self.obj - def _cleanup(self, ref): + def _cleanup(self, ref: weakref.ref[_T]) -> None: """Weakref callback cleanup. This callable cleans out the state when it is being garbage @@ -445,13 +525,9 @@ class InstanceState(interfaces.InspectionAttrInfo): # assert self not in instance_dict._modified self.session_id = self._strong_obj = None - del self.obj - - def obj(self): - return None @property - def dict(self): + def dict(self) -> _InstanceDict: """Return the instance dict used by the object. Under normal circumstances, this is always synonymous @@ -469,35 +545,39 @@ class InstanceState(interfaces.InspectionAttrInfo): else: return {} - def _initialize_instance(*mixed, **kwargs): + def _initialize_instance(*mixed: Any, **kwargs: Any) -> None: self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa manager = self.manager manager.dispatch.init(self, args, kwargs) try: - return manager.original_init(*mixed[1:], **kwargs) + manager.original_init(*mixed[1:], **kwargs) except: with util.safe_reraise(): manager.dispatch.init_failure(self, args, kwargs) - def get_history(self, key, passive): + def get_history(self, key: str, passive: PassiveFlag) -> History: return self.manager[key].impl.get_history(self, self.dict, passive) - def get_impl(self, key): + def get_impl(self, key: str) -> AttributeImpl: return self.manager[key].impl - def _get_pending_mutation(self, key): + def _get_pending_mutation(self, key: str) -> PendingCollection: if key not in self._pending_mutations: self._pending_mutations[key] = PendingCollection() return self._pending_mutations[key] - def __getstate__(self): - state_dict = {"instance": self.obj()} + def __getstate__(self) -> Dict[str, Any]: + state_dict = { + "instance": self.obj(), + "class_": self.class_, + "committed_state": self.committed_state, + "expired_attributes": self.expired_attributes, + } state_dict.update( (k, self.__dict__[k]) for k in ( - "committed_state", "_pending_mutations", "modified", "expired", @@ -518,21 +598,18 @@ class InstanceState(interfaces.InspectionAttrInfo): return state_dict - def __setstate__(self, state_dict): + def __setstate__(self, state_dict: Dict[str, Any]) -> None: inst = state_dict["instance"] if inst is not None: self.obj = weakref.ref(inst, self._cleanup) self.class_ = inst.__class__ else: - # None being possible here generally new as of 0.7.4 - # due to storage of state in "parents". "class_" - # also new. - self.obj = None + self.obj = lambda: None # type: ignore self.class_ = state_dict["class_"] self.committed_state = state_dict.get("committed_state", {}) - self._pending_mutations = state_dict.get("_pending_mutations", {}) - self.parents = state_dict.get("parents", {}) + self._pending_mutations = state_dict.get("_pending_mutations", {}) # type: ignore # noqa E501 + self.parents = state_dict.get("parents", {}) # type: ignore self.modified = state_dict.get("modified", False) self.expired = state_dict.get("expired", False) if "info" in state_dict: @@ -540,15 +617,7 @@ class InstanceState(interfaces.InspectionAttrInfo): if "callables" in state_dict: self.callables = state_dict["callables"] - try: - self.expired_attributes = state_dict["expired_attributes"] - except KeyError: - self.expired_attributes = set() - # 0.9 and earlier compat - for k in list(self.callables): - if self.callables[k] is self: - self.expired_attributes.add(k) - del self.callables[k] + self.expired_attributes = state_dict["expired_attributes"] else: if "expired_attributes" in state_dict: self.expired_attributes = state_dict["expired_attributes"] @@ -563,57 +632,61 @@ class InstanceState(interfaces.InspectionAttrInfo): ] ) if self.key: - try: - self.identity_token = self.key[2] - except IndexError: - # 1.1 and earlier compat before identity_token - assert len(self.key) == 2 - self.key = self.key + (None,) - self.identity_token = None + self.identity_token = self.key[2] if "load_path" in state_dict: self.load_path = PathRegistry.deserialize(state_dict["load_path"]) state_dict["manager"](self, inst, state_dict) - def _reset(self, dict_, key): + def _reset(self, dict_: _InstanceDict, key: str) -> None: """Remove the given attribute and any callables associated with it.""" old = dict_.pop(key, None) - if old is not None and self.manager[key].impl.collection: - self.manager[key].impl._invalidate_collection(old) + manager_impl = self.manager[key].impl + if old is not None and is_collection_impl(manager_impl): + manager_impl._invalidate_collection(old) self.expired_attributes.discard(key) if self.callables: self.callables.pop(key, None) - def _copy_callables(self, from_): + def _copy_callables(self, from_: InstanceState[Any]) -> None: if "callables" in from_.__dict__: self.callables = dict(from_.callables) @classmethod - def _instance_level_callable_processor(cls, manager, fn, key): + def _instance_level_callable_processor( + cls, manager: ClassManager[_T], fn: _LoaderCallable, key: Any + ) -> Callable[[InstanceState[_T], _InstanceDict, Row], None]: impl = manager[key].impl - if impl.collection: + if is_collection_impl(impl): + fixed_impl = impl - def _set_callable(state, dict_, row): + def _set_callable( + state: InstanceState[_T], dict_: _InstanceDict, row: Row + ) -> None: if "callables" not in state.__dict__: state.callables = {} old = dict_.pop(key, None) if old is not None: - impl._invalidate_collection(old) + fixed_impl._invalidate_collection(old) state.callables[key] = fn else: - def _set_callable(state, dict_, row): + def _set_callable( + state: InstanceState[_T], dict_: _InstanceDict, row: Row + ) -> None: if "callables" not in state.__dict__: state.callables = {} state.callables[key] = fn return _set_callable - def _expire(self, dict_, modified_set): + def _expire( + self, dict_: _InstanceDict, modified_set: Set[InstanceState[Any]] + ) -> None: self.expired = True if self.modified: modified_set.discard(self) @@ -653,7 +726,7 @@ class InstanceState(interfaces.InspectionAttrInfo): if self._last_known_values: self._last_known_values.update( - (k, dict_[k]) for k in self._last_known_values if k in dict_ + {k: dict_[k] for k in self._last_known_values if k in dict_} ) for key in self.manager._all_key_set.intersection(dict_): @@ -661,7 +734,12 @@ class InstanceState(interfaces.InspectionAttrInfo): self.manager.dispatch.expire(self, None) - def _expire_attributes(self, dict_, attribute_names, no_loader=False): + def _expire_attributes( + self, + dict_: _InstanceDict, + attribute_names: Iterable[str], + no_loader: bool = False, + ) -> None: pending = self.__dict__.get("_pending_mutations", None) callables = self.callables @@ -676,15 +754,12 @@ class InstanceState(interfaces.InspectionAttrInfo): if callables and key in callables: del callables[key] old = dict_.pop(key, NO_VALUE) - if impl.collection and old is not NO_VALUE: + if is_collection_impl(impl) and old is not NO_VALUE: impl._invalidate_collection(old) - if ( - self._last_known_values - and key in self._last_known_values - and old is not NO_VALUE - ): - self._last_known_values[key] = old + lkv = self._last_known_values + if lkv is not None and key in lkv and old is not NO_VALUE: + lkv[key] = old self.committed_state.pop(key, None) if pending: @@ -692,7 +767,9 @@ class InstanceState(interfaces.InspectionAttrInfo): self.manager.dispatch.expire(self, attribute_names) - def _load_expired(self, state, passive): + def _load_expired( + self, state: InstanceState[_T], passive: PassiveFlag + ) -> LoaderCallableStatus: """__call__ allows the InstanceState to act as a deferred callable for loading expired attributes, which is also serializable (picklable). @@ -720,12 +797,12 @@ class InstanceState(interfaces.InspectionAttrInfo): return ATTR_WAS_SET @property - def unmodified(self): + def unmodified(self) -> Set[str]: """Return the set of keys which have no uncommitted changes""" return set(self.manager).difference(self.committed_state) - def unmodified_intersection(self, keys): + def unmodified_intersection(self, keys: Iterable[str]) -> Set[str]: """Return self.unmodified.intersection(keys).""" return ( @@ -735,7 +812,7 @@ class InstanceState(interfaces.InspectionAttrInfo): ) @property - def unloaded(self): + def unloaded(self) -> Set[str]: """Return the set of keys which do not have a loaded value. This includes expired attributes and any other attribute that @@ -749,7 +826,7 @@ class InstanceState(interfaces.InspectionAttrInfo): ) @property - def unloaded_expirable(self): + def unloaded_expirable(self) -> Set[str]: """Return the set of keys which do not have a loaded value. This includes expired attributes and any other attribute that @@ -759,19 +836,21 @@ class InstanceState(interfaces.InspectionAttrInfo): return self.unloaded @property - def _unloaded_non_object(self): + def _unloaded_non_object(self) -> Set[str]: return self.unloaded.intersection( attr for attr in self.manager if self.manager[attr].impl.accepts_scalar_loader ) - def _instance_dict(self): - return None - def _modified_event( - self, dict_, attr, previous, collection=False, is_userland=False - ): + self, + dict_: _InstanceDict, + attr: AttributeImpl, + previous: Any, + collection: bool = False, + is_userland: bool = False, + ) -> None: if attr: if not attr.send_modified_events: return @@ -782,6 +861,8 @@ class InstanceState(interfaces.InspectionAttrInfo): ) if attr.key not in self.committed_state or is_userland: if collection: + if TYPE_CHECKING: + assert is_collection_impl(attr) if previous is NEVER_SET: if attr.key in dict_: previous = dict_[attr.key] @@ -790,8 +871,9 @@ class InstanceState(interfaces.InspectionAttrInfo): previous = attr.copy(previous) self.committed_state[attr.key] = previous - if attr.key in self._last_known_values: - self._last_known_values[attr.key] = NO_VALUE + lkv = self._last_known_values + if lkv is not None and attr.key in lkv: + lkv[attr.key] = NO_VALUE # assert self._strong_obj is None or self.modified @@ -823,7 +905,7 @@ class InstanceState(interfaces.InspectionAttrInfo): pass else: if session._transaction is None: - session._autobegin() + session._autobegin_t() if inst is None and attr: raise orm_exc.ObjectDereferencedError( @@ -833,7 +915,7 @@ class InstanceState(interfaces.InspectionAttrInfo): % (self.manager[attr.key], base.state_class_str(self)) ) - def _commit(self, dict_, keys): + def _commit(self, dict_: _InstanceDict, keys: Iterable[str]) -> None: """Commit attributes. This is used by a partial-attribute load operation to mark committed @@ -862,7 +944,9 @@ class InstanceState(interfaces.InspectionAttrInfo): ): del self.callables[key] - def _commit_all(self, dict_, instance_dict=None): + def _commit_all( + self, dict_: _InstanceDict, instance_dict: Optional[IdentityMap] = None + ) -> None: """commit all attributes unconditionally. This is used after a flush() or a full load/refresh @@ -881,7 +965,11 @@ class InstanceState(interfaces.InspectionAttrInfo): self._commit_all_states([(self, dict_)], instance_dict) @classmethod - def _commit_all_states(self, iter_, instance_dict=None): + def _commit_all_states( + self, + iter_: Iterable[Tuple[InstanceState[Any], _InstanceDict]], + instance_dict: Optional[IdentityMap] = None, + ) -> None: """Mass / highly inlined version of commit_all().""" for state, dict_ in iter_: @@ -916,12 +1004,17 @@ class AttributeState: """ - def __init__(self, state, key): + __slots__ = ("state", "key") + + state: InstanceState[Any] + key: str + + def __init__(self, state: InstanceState[Any], key: str): self.state = state self.key = key @property - def loaded_value(self): + def loaded_value(self) -> Any: """The current value of this attribute as loaded from the database. If the value has not been loaded, or is otherwise not present @@ -931,7 +1024,7 @@ class AttributeState: return self.state.dict.get(self.key, NO_VALUE) @property - def value(self): + def value(self) -> Any: """Return the value of this attribute. This operation is equivalent to accessing the object's @@ -944,7 +1037,7 @@ class AttributeState: ) @property - def history(self): + def history(self) -> History: """Return the current **pre-flush** change history for this attribute, via the :class:`.History` interface. @@ -971,7 +1064,7 @@ class AttributeState: """ return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE) - def load_history(self): + def load_history(self) -> History: """Return the current **pre-flush** change history for this attribute, via the :class:`.History` interface. @@ -1008,17 +1101,22 @@ class PendingCollection: """ - def __init__(self): + __slots__ = ("deleted_items", "added_items") + + deleted_items: util.IdentitySet + added_items: util.OrderedIdentitySet + + def __init__(self) -> None: self.deleted_items = util.IdentitySet() self.added_items = util.OrderedIdentitySet() - def append(self, value): + def append(self, value: Any) -> None: if value in self.deleted_items: self.deleted_items.remove(value) else: self.added_items.add(value) - def remove(self, value): + def remove(self, value: Any) -> None: if value in self.added_items: self.added_items.remove(value) else: diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py index 1afeab05bc..b7bf965585 100644 --- a/lib/sqlalchemy/orm/state_changes.py +++ b/lib/sqlalchemy/orm/state_changes.py @@ -16,12 +16,15 @@ from typing import Any from typing import Callable from typing import Optional from typing import Tuple +from typing import TypeVar from typing import Union from .. import exc as sa_exc from .. import util from ..util.typing import Literal +_F = TypeVar("_F", bound=Callable[..., Any]) + class _StateChangeState(Enum): pass @@ -60,7 +63,7 @@ class _StateChange: Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...] ], moves_to: _StateChangeState, - ) -> Callable[..., Any]: + ) -> Callable[[_F], _F]: """Method decorator declaring valid states. :param prerequisite_states: sequence of acceptable prerequisite diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index da098e8c5f..9ff284e733 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -15,6 +15,12 @@ organizes them in order of dependency, and executes. from __future__ import annotations +from typing import Any +from typing import Dict +from typing import Optional +from typing import Set +from typing import TYPE_CHECKING + from . import attributes from . import exc as orm_exc from . import util as orm_util @@ -23,6 +29,15 @@ from .. import util from ..util import topological +if TYPE_CHECKING: + from .dependency import DependencyProcessor + from .interfaces import MapperProperty + from .mapper import Mapper + from .session import Session + from .session import SessionTransaction + from .state import InstanceState + + def track_cascade_events(descriptor, prop): """Establish event listeners on object attributes which handle cascade-on-set/append. @@ -131,7 +146,13 @@ def track_cascade_events(descriptor, prop): class UOWTransaction: - def __init__(self, session): + session: Session + transaction: SessionTransaction + attributes: Dict[str, Any] + deps: util.defaultdict[Mapper[Any], Set[DependencyProcessor]] + mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]] + + def __init__(self, session: Session): self.session = session # dictionary used by external actors to @@ -275,13 +296,13 @@ class UOWTransaction: def register_object( self, - state, - isdelete=False, - listonly=False, - cancel_delete=False, - operation=None, - prop=None, - ): + state: InstanceState[Any], + isdelete: bool = False, + listonly: bool = False, + cancel_delete: bool = False, + operation: Optional[str] = None, + prop: Optional[MapperProperty] = None, + ) -> bool: if not self.session._contains_state(state): # this condition is normal when objects are registered # as part of a relationship cascade operation. it should @@ -408,7 +429,7 @@ class UOWTransaction: [a for a in self.postsort_actions.values() if not a.disabled] ).difference(cycles) - def execute(self): + def execute(self) -> None: postsort_actions = self._generate_actions() postsort_actions = sorted( @@ -435,7 +456,7 @@ class UOWTransaction: for rec in topological.sort(self.dependencies, postsort_actions): rec.execute(self) - def finalize_flush_changes(self): + def finalize_flush_changes(self) -> None: """Mark processed objects as clean / deleted after a successful flush(). diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index baca8f5476..233085f305 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -13,7 +13,6 @@ import typing from typing import Any from typing import Generic from typing import Optional -from typing import overload from typing import Tuple from typing import Type from typing import TypeVar @@ -22,7 +21,6 @@ import weakref from . import attributes # noqa from .base import _class_to_mapper # noqa -from .base import _IdentityKeyType from .base import _never_set # noqa from .base import _none_set # noqa from .base import attribute_str # noqa @@ -62,6 +60,9 @@ from ..util.typing import de_stringify_annotation from ..util.typing import is_origin_of if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InternalEntityType from .mapper import Mapper from ..engine import Row from ..sql._typing import _PropagateAttrsType @@ -297,27 +298,13 @@ def polymorphic_union( return sql.union_all(*result).alias(aliasname) -@overload def identity_key( - class_: type, ident: Tuple[Any, ...], *, identity_token: Optional[str] -) -> _IdentityKeyType: - ... - - -@overload -def identity_key(*, instance: Any) -> _IdentityKeyType: - ... - - -@overload -def identity_key( - class_: type, *, row: "Row", identity_token: Optional[str] -) -> _IdentityKeyType: - ... - - -def identity_key( - class_=None, ident=None, *, instance=None, row=None, identity_token=None + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Row] = None, + identity_token: Optional[Any] = None, ) -> _IdentityKeyType: r"""Generate "identity key" tuples, as are used as keys in the :attr:`.Session.identity_map` dictionary. @@ -634,6 +621,7 @@ class AliasedInsp( sql_base.HasCacheKey, InspectionAttr, MemoizedSlots, + Generic[_T], ): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -699,8 +687,8 @@ class AliasedInsp( def __init__( self, - entity, - inspected, + entity: _EntityType, + inspected: _InternalEntityType, selectable, name, with_polymorphic_mappers, @@ -1797,6 +1785,32 @@ def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type): return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") +def _cleanup_mapped_str_annotation(annotation): + # fix up an annotation that comes in as the form: + # 'Mapped[List[Address]]' so that it instead looks like: + # 'Mapped[List["Address"]]' , which will allow us to get + # "Address" as a string + mm = re.match(r"^(.+?)\[(.+)\]$", annotation) + if mm and mm.group(1) == "Mapped": + stack = [] + inner = mm + while True: + stack.append(inner.group(1)) + g2 = inner.group(2) + inner = re.match(r"^(.+?)\[(.+)\]$", g2) + if inner is None: + stack.append(g2) + break + + # stack: ['Mapped', 'List', 'Address'] + if not re.match(r"""^["'].*["']$""", stack[-1]): + stack[-1] = f'"{stack[-1]}"' + # stack: ['Mapped', 'List', '"Address"'] + + annotation = "[".join(stack) + ("]" * (len(stack) - 1)) + return annotation + + def _extract_mapped_subtype( raw_annotation: Union[type, str], cls: type, @@ -1816,7 +1830,9 @@ def _extract_mapped_subtype( ) return None - annotated = de_stringify_annotation(cls, raw_annotation) + annotated = de_stringify_annotation( + cls, raw_annotation, _cleanup_mapped_str_annotation + ) if is_dataclass_field: return annotated diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index bc1e0672c4..7e3a1c4e8d 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from sqlalchemy.sql.base import Executable from . import roles from .. import util from ..inspection import Inspectable @@ -183,10 +184,14 @@ if TYPE_CHECKING: def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: ... - def is_select_base(t: ReturnsRows) -> TypeGuard[SelectBase]: + def is_select_base( + t: Union[Executable, ReturnsRows] + ) -> TypeGuard[SelectBase]: ... - def is_select_statement(t: ReturnsRows) -> TypeGuard[Select]: + def is_select_statement( + t: Union[Executable, ReturnsRows] + ) -> TypeGuard[Select]: ... def is_table(t: FromClause) -> TypeGuard[TableClause]: diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 7fb9c26026..ccd5e8c40e 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -648,7 +648,9 @@ class CompileState: return None @classmethod - def _get_plugin_class_for_plugin(cls, statement, plugin_name): + def _get_plugin_class_for_plugin( + cls, statement: Executable, plugin_name: str + ) -> Optional[Type[CompileState]]: try: return cls.plugins[ (plugin_name, statement._effective_plugin_target) @@ -790,7 +792,7 @@ class Options(metaclass=_MetaOptions): ) @classmethod - def isinstance(cls, klass): + def isinstance(cls, klass: Type[Any]) -> bool: return issubclass(cls, klass) @hybridmethod @@ -912,6 +914,8 @@ class ExecutableOption(HasCopyInternals): _is_has_cache_key = False + _is_core = True + def _clone(self, **kw): """Create a shallow copy of this ExecutableOption.""" c = self.__class__.__new__(self.__class__) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 4c71ca38b1..623bb0be2e 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -114,7 +114,7 @@ def _deep_is_literal(element): schema.SchemaEventTarget, HasCacheKey, Options, - util.langhelpers._symbol, + util.langhelpers.symbol, ), ) and not hasattr(element, "__clause_element__") diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d7cc327333..99a6baa890 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3037,7 +3037,9 @@ class ForUpdateArg(ClauseElement): skip_locked: bool @classmethod - def _from_argument(cls, with_for_update): + def _from_argument( + cls, with_for_update: Union[ForUpdateArg, None, bool, Dict[str, Any]] + ) -> Optional[ForUpdateArg]: if isinstance(with_for_update, ForUpdateArg): return with_for_update elif with_for_update in (None, False): @@ -3045,7 +3047,7 @@ class ForUpdateArg(ClauseElement): elif with_for_update is True: return ForUpdateArg() else: - return ForUpdateArg(**with_for_update) + return ForUpdateArg(**cast("Dict[str, Any]", with_for_update)) def __eq__(self, other): return ( diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 4253aa61bc..da6292fcf4 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -49,6 +49,7 @@ from .config import combinations_list from .config import db from .config import fixture from .config import requirements as requires +from .config import skip_test from .exclusions import _is_excluded from .exclusions import _server_version from .exclusions import against as _against diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c0c2e7dfb7..6d41231d98 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -96,6 +96,7 @@ from .langhelpers import dictlike_iteritems as dictlike_iteritems from .langhelpers import duck_type_collection as duck_type_collection from .langhelpers import ellipses_string as ellipses_string from .langhelpers import EnsureKWArg as EnsureKWArg +from .langhelpers import FastIntFlag as FastIntFlag from .langhelpers import format_argspec_init as format_argspec_init from .langhelpers import format_argspec_plus as format_argspec_plus from .langhelpers import generic_fn_descriptor as generic_fn_descriptor diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index eb5b16b650..eea76f60b0 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -22,6 +22,7 @@ from typing import Generic from typing import Iterable from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Set @@ -123,7 +124,7 @@ def merge_lists_w_ordering(a, b): return result -def coerce_to_immutabledict(d): +def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]: if not d: return EMPTY_DICT elif isinstance(d, immutabledict): @@ -161,6 +162,8 @@ class FacadeDict(ImmutableDictBase[_KT, _VT]): _DT = TypeVar("_DT", bound=Any) +_F = TypeVar("_F", bound=Any) + class Properties(Generic[_T]): """Provide a __getattr__/__setattr__ interface over a dict.""" @@ -169,7 +172,7 @@ class Properties(Generic[_T]): _data: Dict[str, _T] - def __init__(self, data): + def __init__(self, data: Dict[str, _T]): object.__setattr__(self, "_data", data) def __len__(self) -> int: @@ -178,30 +181,30 @@ class Properties(Generic[_T]): def __iter__(self) -> Iterator[_T]: return iter(list(self._data.values())) - def __dir__(self): + def __dir__(self) -> List[str]: return dir(super(Properties, self)) + [ str(k) for k in self._data.keys() ] - def __add__(self, other): - return list(self) + list(other) + def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: + return list(self) + list(other) # type: ignore - def __setitem__(self, key, obj): + def __setitem__(self, key: str, obj: _T) -> None: self._data[key] = obj def __getitem__(self, key: str) -> _T: return self._data[key] - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._data[key] - def __setattr__(self, key, obj): + def __setattr__(self, key: str, obj: _T) -> None: self._data[key] = obj - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"_data": self._data} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: object.__setattr__(self, "_data", state["_data"]) def __getattr__(self, key: str) -> _T: @@ -213,12 +216,12 @@ class Properties(Generic[_T]): def __contains__(self, key: str) -> bool: return key in self._data - def as_readonly(self) -> "ReadOnlyProperties[_T]": + def as_readonly(self) -> ReadOnlyProperties[_T]: """Return an immutable proxy for this :class:`.Properties`.""" return ReadOnlyProperties(self._data) - def update(self, value): + def update(self, value: Dict[str, _T]) -> None: self._data.update(value) @overload @@ -249,7 +252,7 @@ class Properties(Generic[_T]): def has_key(self, key: str) -> bool: return key in self._data - def clear(self): + def clear(self) -> None: self._data.clear() @@ -318,7 +321,7 @@ class WeakSequence: class OrderedIdentitySet(IdentitySet): - def __init__(self, iterable=None): + def __init__(self, iterable: Optional[Iterable[Any]] = None): IdentitySet.__init__(self) self._members = OrderedDict() if iterable: @@ -615,7 +618,9 @@ class ScopedRegistry(Generic[_T]): scopefunc: _ScopeFuncType registry: Any - def __init__(self, createfunc, scopefunc): + def __init__( + self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any] + ): """Construct a new :class:`.ScopedRegistry`. :param createfunc: A creation function that will generate diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index d649a0bea7..725f6930ee 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -263,52 +263,54 @@ class IdentitySet: """ - def __init__(self, iterable=None): + _members: Dict[int, Any] + + def __init__(self, iterable: Optional[Iterable[Any]] = None): self._members = dict() if iterable: self.update(iterable) - def add(self, value): + def add(self, value: Any) -> None: self._members[id(value)] = value - def __contains__(self, value): + def __contains__(self, value: Any) -> bool: return id(value) in self._members - def remove(self, value): + def remove(self, value: Any) -> None: del self._members[id(value)] - def discard(self, value): + def discard(self, value: Any) -> None: try: self.remove(value) except KeyError: pass - def pop(self): + def pop(self) -> Any: try: pair = self._members.popitem() return pair[1] except KeyError: raise KeyError("pop from an empty set") - def clear(self): + def clear(self) -> None: self._members.clear() - def __cmp__(self, other): + def __cmp__(self, other: Any) -> NoReturn: raise TypeError("cannot compare sets using cmp()") - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, IdentitySet): return self._members == other._members else: return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, IdentitySet): return self._members != other._members else: return True - def issubset(self, iterable): + def issubset(self, iterable: Iterable[Any]) -> bool: if isinstance(iterable, self.__class__): other = iterable else: @@ -322,17 +324,17 @@ class IdentitySet: return False return True - def __le__(self, other): + def __le__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return self.issubset(other) - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return len(self) < len(other) and self.issubset(other) - def issuperset(self, iterable): + def issuperset(self, iterable: Iterable[Any]) -> bool: if isinstance(iterable, self.__class__): other = iterable else: @@ -347,38 +349,38 @@ class IdentitySet: return False return True - def __ge__(self, other): + def __ge__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return self.issuperset(other) - def __gt__(self, other): + def __gt__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return len(self) > len(other) and self.issuperset(other) - def union(self, iterable): + def union(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__class__() members = self._members result._members.update(members) result._members.update((id(obj), obj) for obj in iterable) return result - def __or__(self, other): + def __or__(self, other: Any) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.union(other) - def update(self, iterable): + def update(self, iterable: Iterable[Any]) -> None: self._members.update((id(obj), obj) for obj in iterable) - def __ior__(self, other): + def __ior__(self, other: Any) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.update(other) return self - def difference(self, iterable): + def difference(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__new__(self.__class__) other: Collection[Any] @@ -391,21 +393,21 @@ class IdentitySet: } return result - def __sub__(self, other): + def __sub__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.difference(other) - def difference_update(self, iterable): + def difference_update(self, iterable: Iterable[Any]) -> None: self._members = self.difference(iterable)._members - def __isub__(self, other): + def __isub__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.difference_update(other) return self - def intersection(self, iterable): + def intersection(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__new__(self.__class__) other: Collection[Any] @@ -419,21 +421,21 @@ class IdentitySet: } return result - def __and__(self, other): + def __and__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.intersection(other) - def intersection_update(self, iterable): + def intersection_update(self, iterable: Iterable[Any]) -> None: self._members = self.intersection(iterable)._members - def __iand__(self, other): + def __iand__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.intersection_update(other) return self - def symmetric_difference(self, iterable): + def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__new__(self.__class__) if isinstance(iterable, self.__class__): other = iterable._members @@ -447,37 +449,37 @@ class IdentitySet: ) return result - def __xor__(self, other): + def __xor__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.symmetric_difference(other) - def symmetric_difference_update(self, iterable): + def symmetric_difference_update(self, iterable: Iterable[Any]) -> None: self._members = self.symmetric_difference(iterable)._members - def __ixor__(self, other): + def __ixor__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.symmetric_difference(other) return self - def copy(self): + def copy(self) -> IdentitySet: result = self.__new__(self.__class__) result._members = self._members.copy() return result __copy__ = copy - def __len__(self): + def __len__(self) -> int: return len(self._members) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self._members.values()) - def __hash__(self): + def __hash__(self) -> NoReturn: raise TypeError("set objects are unhashable") - def __repr__(self): + def __repr__(self) -> str: return "%s(%r)" % (type(self).__name__, list(self._members.values())) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 3e89c72bbb..2cb9c45d6b 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -12,6 +12,7 @@ modules, classes, hierarchies, attributes, functions, and methods. from __future__ import annotations import collections +import enum from functools import update_wrapper import hashlib import inspect @@ -671,13 +672,13 @@ def format_argspec_init(method, grouped=True): def create_proxy_methods( - target_cls, - target_cls_sphinx_name, - proxy_cls_sphinx_name, - classmethods=(), - methods=(), - attributes=(), -): + target_cls: Type[Any], + target_cls_sphinx_name: str, + proxy_cls_sphinx_name: str, + classmethods: Sequence[str] = (), + methods: Sequence[str] = (), + attributes: Sequence[str] = (), +) -> Callable[[_T], _T]: """A class decorator indicating attributes should refer to a proxy class. @@ -1539,24 +1540,50 @@ class hybridmethod(Generic[_T]): return self -class _symbol(int): +class symbol(int): + """A constant symbol. + + >>> symbol('foo') is symbol('foo') + True + >>> symbol('foo') + + + A slight refinement of the MAGICCOOKIE=object() pattern. The primary + advantage of symbol() is its repr(). They are also singletons. + + Repeated calls of symbol('name') will all return the same instance. + + In SQLAlchemy 2.0, symbol() is used for the implementation of + ``_FastIntFlag``, but otherwise should be mostly replaced by + ``enum.Enum`` and variants. + + + """ + name: str + symbols: Dict[str, symbol] = {} + _lock = threading.Lock() + def __new__( cls, name: str, doc: Optional[str] = None, canonical: Optional[int] = None, - ) -> "_symbol": - """Construct a new named symbol.""" - assert isinstance(name, str) - if canonical is None: - canonical = hash(name) - v = int.__new__(_symbol, canonical) - v.name = name - if doc: - v.__doc__ = doc - return v + ) -> symbol: + with cls._lock: + sym = cls.symbols.get(name) + if sym is None: + assert isinstance(name, str) + if canonical is None: + canonical = hash(name) + sym = int.__new__(symbol, canonical) + sym.name = name + if doc: + sym.__doc__ = doc + + cls.symbols[name] = sym + return sym def __reduce__(self): return symbol, (self.name, "x", int(self)) @@ -1565,90 +1592,60 @@ class _symbol(int): return repr(self) def __repr__(self): - return "symbol(%r)" % self.name + return f"symbol({self.name!r})" -_symbol.__name__ = "symbol" +class _IntFlagMeta(type): + def __init__( + cls, + classname: str, + bases: Tuple[Type[Any], ...], + dict_: Dict[str, Any], + **kw: Any, + ) -> None: + items: List[symbol] + cls._items = items = [] + for k, v in dict_.items(): + if isinstance(v, int): + sym = symbol(k, canonical=v) + elif not k.startswith("_"): + raise TypeError("Expected integer values for IntFlag") + else: + continue + setattr(cls, k, sym) + items.append(sym) + def __iter__(self) -> Iterator[symbol]: + return iter(self._items) -class symbol: - """A constant symbol. - >>> symbol('foo') is symbol('foo') - True - >>> symbol('foo') - +class _FastIntFlag(metaclass=_IntFlagMeta): + """An 'IntFlag' copycat that isn't slow when performing bitwise + operations. - A slight refinement of the MAGICCOOKIE=object() pattern. The primary - advantage of symbol() is its repr(). They are also singletons. + the ``FastIntFlag`` class will return ``enum.IntFlag`` under TYPE_CHECKING + and ``_FastIntFlag`` otherwise. - Repeated calls of symbol('name') will all return the same instance. + """ - The optional ``doc`` argument assigns to ``__doc__``. This - is strictly so that Sphinx autoattr picks up the docstring we want - (it doesn't appear to pick up the in-module docstring if the datamember - is in a different module - autoattribute also blows up completely). - If Sphinx fixes/improves this then we would no longer need - ``doc`` here. - """ +if TYPE_CHECKING: + from enum import IntFlag - symbols: Dict[str, "_symbol"] = {} - _lock = threading.Lock() + FastIntFlag = IntFlag +else: + FastIntFlag = _FastIntFlag - def __new__( # type: ignore[misc] - cls, - name: str, - doc: Optional[str] = None, - canonical: Optional[int] = None, - ) -> _symbol: - with cls._lock: - sym = cls.symbols.get(name) - if sym is None: - cls.symbols[name] = sym = _symbol(name, doc, canonical) - return sym - @classmethod - def parse_user_argument( - cls, arg, choices, name, resolve_symbol_names=False - ): - """Given a user parameter, parse the parameter into a chosen symbol. - - The user argument can be a string name that matches the name of a - symbol, or the symbol object itself, or any number of alternate choices - such as True/False/ None etc. - - :param arg: the user argument. - :param choices: dictionary of symbol object to list of possible - entries. - :param name: name of the argument. Used in an :class:`.ArgumentError` - that is raised if the parameter doesn't match any available argument. - :param resolve_symbol_names: include the name of each symbol as a valid - entry. - - """ - # note using hash lookup is tricky here because symbol's `__hash__` - # is its int value which we don't want included in the lookup - # explicitly, so we iterate and compare each. - for sym, choice in choices.items(): - if arg is sym: - return sym - elif resolve_symbol_names and arg == sym.name: - return sym - elif arg in choice: - return sym - - if arg is None: - return None - - raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) +_E = TypeVar("_E", bound=enum.Enum) def parse_user_argument_for_enum( arg: Any, - choices: Dict[_T, List[Any]], + choices: Dict[_E, List[Any]], name: str, -) -> Optional[_T]: + resolve_symbol_names: bool = False, +) -> Optional[_E]: """Given a user parameter, parse the parameter into a chosen value from a list of choice objects, typically Enum values. @@ -1663,18 +1660,18 @@ def parse_user_argument_for_enum( that is raised if the parameter doesn't match any available argument. """ - # TODO: use whatever built in thing Enum provides for this, - # if applicable for enum_value, choice in choices.items(): if arg is enum_value: return enum_value + elif resolve_symbol_names and arg == enum_value.name: + return enum_value elif arg in choice: return enum_value if arg is None: return None - raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) + raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}") _creation_order = 1 diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index c861c83b3f..907c510649 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -23,6 +23,8 @@ _FN = TypeVar("_FN", bound=Callable[..., Any]) if TYPE_CHECKING: from sqlalchemy.engine import default as engine_default + from sqlalchemy.orm import session as orm_session + from sqlalchemy.orm import util as orm_util from sqlalchemy.sql import dml as sql_dml from sqlalchemy.sql import util as sql_util diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index df54017da7..dd574f3b0f 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -3,10 +3,12 @@ from __future__ import annotations import sys import typing from typing import Any +from typing import Callable from typing import cast from typing import Dict from typing import ForwardRef from typing import Iterable +from typing import Optional from typing import Tuple from typing import Type from typing import TypeVar @@ -82,7 +84,9 @@ else: def de_stringify_annotation( - cls: Type[Any], annotation: Union[str, Type[Any]] + cls: Type[Any], + annotation: Union[str, Type[Any]], + str_cleanup_fn: Optional[Callable[[str], str]] = None, ) -> Union[str, Type[Any]]: """Resolve annotations that may be string based into real objects. @@ -105,9 +109,13 @@ def de_stringify_annotation( annotation = cast(ForwardRef, annotation).__forward_arg__ if isinstance(annotation, str): + if str_cleanup_fn: + annotation = str_cleanup_fn(annotation) + base_globals: "Dict[str, Any]" = getattr( sys.modules.get(cls.__module__, None), "__dict__", {} ) + try: annotation = eval(annotation, base_globals, None) except NameError: diff --git a/pyproject.toml b/pyproject.toml index 012f1bffa9..8f7f50715a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,10 @@ module = [ "sqlalchemy.engine.*", "sqlalchemy.pool.*", + "sqlalchemy.orm.scoping", + "sqlalchemy.orm.session", + "sqlalchemy.orm.state", + # modules "sqlalchemy.events", "sqlalchemy.exc", diff --git a/test/base/test_utils.py b/test/base/test_utils.py index e22340da68..c5a47ddf97 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -27,6 +27,7 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy.testing.util import picklers from sqlalchemy.util import classproperty from sqlalchemy.util import compat +from sqlalchemy.util import FastIntFlag from sqlalchemy.util import get_callable_argspec from sqlalchemy.util import langhelpers from sqlalchemy.util import preloaded @@ -2300,6 +2301,20 @@ class SymbolTest(fixtures.TestBase): assert sym1 is not sym3 assert sym1 != sym3 + def test_fast_int_flag(self): + class Enum(FastIntFlag): + sym1 = 1 + sym2 = 2 + + sym3 = 3 + + assert Enum.sym1 is not Enum.sym3 + assert Enum.sym1 != Enum.sym3 + + assert Enum.sym1.name == "sym1" + + eq_(list(Enum), [Enum.sym1, Enum.sym2, Enum.sym3]) + def test_pickle(self): sym1 = util.symbol("foo") sym2 = util.symbol("foo") @@ -2338,17 +2353,19 @@ class SymbolTest(fixtures.TestBase): assert (sym1 | sym2) & (sym2 | sym4) def test_parser(self): - sym1 = util.symbol("sym1", canonical=1) - sym2 = util.symbol("sym2", canonical=2) - sym3 = util.symbol("sym3", canonical=4) - sym4 = util.symbol("sym4", canonical=8) + class MyEnum(FastIntFlag): + sym1 = 1 + sym2 = 2 + sym3 = 4 + sym4 = 8 + sym1, sym2, sym3, sym4 = tuple(MyEnum) lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]} lookup_two = {sym1: [], sym2: [True], sym3: [False]} lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []} is_( - util.symbol.parse_user_argument( + langhelpers.parse_user_argument_for_enum( "sym2", lookup_one, "some_name", resolve_symbol_names=True ), sym2, @@ -2357,35 +2374,41 @@ class SymbolTest(fixtures.TestBase): assert_raises_message( exc.ArgumentError, "Invalid value for 'some_name': 'sym2'", - util.symbol.parse_user_argument, + langhelpers.parse_user_argument_for_enum, "sym2", lookup_one, "some_name", ) is_( - util.symbol.parse_user_argument( + langhelpers.parse_user_argument_for_enum( True, lookup_one, "some_name", resolve_symbol_names=False ), sym2, ) is_( - util.symbol.parse_user_argument(sym2, lookup_one, "some_name"), + langhelpers.parse_user_argument_for_enum( + sym2, lookup_one, "some_name" + ), sym2, ) is_( - util.symbol.parse_user_argument(None, lookup_one, "some_name"), + langhelpers.parse_user_argument_for_enum( + None, lookup_one, "some_name" + ), sym4, ) is_( - util.symbol.parse_user_argument(None, lookup_two, "some_name"), + langhelpers.parse_user_argument_for_enum( + None, lookup_two, "some_name" + ), None, ) is_( - util.symbol.parse_user_argument( + langhelpers.parse_user_argument_for_enum( "symbol2", lookup_three, "some_name" ), sym2, @@ -2394,7 +2417,7 @@ class SymbolTest(fixtures.TestBase): assert_raises_message( exc.ArgumentError, "Invalid value for 'some_name': 'foo'", - util.symbol.parse_user_argument, + langhelpers.parse_user_argument_for_enum, "foo", lookup_three, "some_name", diff --git a/test/ext/mypy/plain_files/session.py b/test/ext/mypy/plain_files/session.py new file mode 100644 index 0000000000..24c685e84b --- /dev/null +++ b/test/ext/mypy/plain_files/session.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import List +from typing import Sequence + +from sqlalchemy import create_engine +from sqlalchemy import ForeignKey +from sqlalchemy import select +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + addresses: Mapped[List[Address]] = relationship(back_populates="user") + + +class Address(Base): + __tablename__ = "address" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email: Mapped[str] + + user: Mapped[User] = relationship(back_populates="addresses") + + +e = create_engine("sqlite://") +Base.metadata.create_all(e) + +with Session(e) as sess: + u1 = User(name="u1") + sess.add(u1) + sess.add_all([Address(user=u1, email="e1"), Address(user=u1, email="e2")]) + sess.commit() + +with Session(e) as sess: + users: Sequence[User] = sess.scalars( + select(User), execution_options={"stream_results": False} + ).all() diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index c7022dc31c..f8abd686a0 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -1,9 +1,56 @@ from __future__ import annotations +from typing import List + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.testing import is_ from .test_typed_mapping import MappedColumnTest # noqa -from .test_typed_mapping import RelationshipLHSTest # noqa +from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest """runs the annotation-sensitive tests from test_typed_mappings while having ``from __future__ import annotations`` in effect. """ + + +class RelationshipLHSTest(_RelationshipLHSTest): + def test_bidirectional_literal_annotations(self, decl_base): + """test the 'string cleanup' function in orm/util.py, where + we receive a string annotation like:: + + "Mapped[List[B]]" + + Which then fails to evaluate because we don't have "B" yet. + The annotation is converted on the fly to:: + + 'Mapped[List["B"]]' + + so that when we evaluated it, we get ``Mapped[List["B"]]`` and + can extract "B" as a string. + + """ + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs: Mapped[List[B]] = relationship(back_populates="a") + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + a: Mapped[A] = relationship( + back_populates="bs", primaryjoin=a_id == A.id + ) + + a1 = A(data="data") + b1 = B() + a1.bs.append(b1) + is_(a1, b1.a) diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index d6d229f792..058e1735b6 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -190,6 +190,7 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): }, ], ), + argnames="cols, expected", ) def test_column_descriptions(self, cols, expected): User, Address = self.classes("User", "Address") @@ -211,8 +212,13 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): ) stmt = select(*cols) + eq_(stmt.column_descriptions, expected) + if stmt._propagate_attrs: + stmt = select(*cols).from_statement(stmt) + eq_(stmt.column_descriptions, expected) + @testing.combinations(insert, update, delete, argnames="dml_construct") @testing.combinations( ( diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 79b20e285a..4cecac0de4 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -5,7 +5,9 @@ from unittest.mock import Mock import sqlalchemy as sa from sqlalchemy import delete from sqlalchemy import event +from sqlalchemy import exc as sa_exc from sqlalchemy import ForeignKey +from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal_column @@ -42,6 +44,7 @@ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_not +from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -236,6 +239,84 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): ), ) + def test_override_parameters_executesingle(self): + User = self.classes.User + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + return ctx.invoke_statement(params={"name": "overridden"}) + + orig_params = {"id": 18, "name": "original"} + with self.sql_execution_asserter() as asserter: + sess.execute(insert(User), orig_params) + asserter.assert_( + CompiledSQL( + "INSERT INTO users (id, name) VALUES (:id, :name)", + [{"id": 18, "name": "overridden"}], + ) + ) + # orig params weren't mutated + eq_(orig_params, {"id": 18, "name": "original"}) + + def test_override_parameters_executemany(self): + User = self.classes.User + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + return ctx.invoke_statement( + params=[{"name": "overridden1"}, {"name": "overridden2"}] + ) + + orig_params = [ + {"id": 18, "name": "original1"}, + {"id": 19, "name": "original2"}, + ] + with self.sql_execution_asserter() as asserter: + sess.execute(insert(User), orig_params) + asserter.assert_( + CompiledSQL( + "INSERT INTO users (id, name) VALUES (:id, :name)", + [ + {"id": 18, "name": "overridden1"}, + {"id": 19, "name": "overridden2"}, + ], + ) + ) + # orig params weren't mutated + eq_( + orig_params, + [{"id": 18, "name": "original1"}, {"id": 19, "name": "original2"}], + ) + + def test_override_parameters_executemany_mismatch(self): + User = self.classes.User + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + return ctx.invoke_statement( + params=[{"name": "overridden1"}, {"name": "overridden2"}] + ) + + orig_params = [ + {"id": 18, "name": "original1"}, + {"id": 19, "name": "original2"}, + {"id": 20, "name": "original3"}, + ] + with expect_raises_message( + sa_exc.InvalidRequestError, + r"Can't apply executemany parameters to statement; number " + r"of parameter sets passed to Session.execute\(\) \(3\) does " + r"not match number of parameter sets given to " + r"ORMExecuteState.invoke_statement\(\) \(2\)", + ): + sess.execute(insert(User), orig_params) + def test_chained_events_one(self): sess = Session(testing.db, future=True) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index a4250e375c..c006babc84 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -11,7 +11,6 @@ from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import exc as orm_exc -from sqlalchemy.orm import instrumentation from sqlalchemy.orm import lazyload from sqlalchemy.orm import relationship from sqlalchemy.orm import state as sa_state @@ -410,73 +409,6 @@ class PickleTest(fixtures.MappedTest): u2 = loads(dumps(u1)) eq_(u1, u2) - def test_09_pickle(self): - users = self.tables.users - self.mapper_registry.map_imperatively(User, users) - sess = fixture_session() - sess.add(User(id=1, name="ed")) - sess.commit() - sess.close() - - inst = User(id=1, name="ed") - del inst._sa_instance_state - - state = sa_state.InstanceState.__new__(sa_state.InstanceState) - state_09 = { - "class_": User, - "modified": False, - "committed_state": {}, - "instance": inst, - "callables": {"name": state, "id": state}, - "key": (User, (1,)), - "expired": True, - } - manager = instrumentation._SerializeManager.__new__( - instrumentation._SerializeManager - ) - manager.class_ = User - state_09["manager"] = manager - state.__setstate__(state_09) - eq_(state.expired_attributes, {"name", "id"}) - - sess = fixture_session() - sess.add(inst) - eq_(inst.name, "ed") - # test identity_token expansion - eq_(sa.inspect(inst).key, (User, (1,), None)) - - def test_11_pickle(self): - users = self.tables.users - self.mapper_registry.map_imperatively(User, users) - sess = fixture_session() - u1 = User(id=1, name="ed") - sess.add(u1) - sess.commit() - - sess.close() - - manager = instrumentation._SerializeManager.__new__( - instrumentation._SerializeManager - ) - manager.class_ = User - - state_11 = { - "class_": User, - "modified": False, - "committed_state": {}, - "instance": u1, - "manager": manager, - "key": (User, (1,)), - "expired_attributes": set(), - "expired": True, - } - - state = sa_state.InstanceState.__new__(sa_state.InstanceState) - state.__setstate__(state_11) - - eq_(state.identity_token, None) - eq_(state.identity_key, (User, (1,), None)) - def test_state_info_pickle(self): users = self.tables.users self.mapper_registry.map_imperatively(User, users) diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index f2d7d8569a..33e66d52f6 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -5,6 +5,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import util from sqlalchemy.orm import query from sqlalchemy.orm import relationship from sqlalchemy.orm import scoped_session @@ -158,7 +159,7 @@ class ScopedSessionTest(fixtures.MappedTest): populate_existing=False, with_for_update=None, identity_token=None, - execution_options=None, + execution_options=util.EMPTY_DICT, ), ], ) diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index eec4d878ac..ffc470972f 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -149,13 +149,27 @@ def process_class( iscoroutine = inspect.iscoroutinefunction(fn) - if spec.defaults: - new_defaults = tuple( - _repr_sym("util.EMPTY_DICT") if df is util.EMPTY_DICT else df - for df in spec.defaults - ) + if spec.defaults or spec.kwonlydefaults: elem = list(spec) - elem[3] = tuple(new_defaults) + + if spec.defaults: + new_defaults = tuple( + _repr_sym("util.EMPTY_DICT") + if df is util.EMPTY_DICT + else df + for df in spec.defaults + ) + elem[3] = new_defaults + + if spec.kwonlydefaults: + new_kwonlydefaults = { + name: _repr_sym("util.EMPTY_DICT") + if df is util.EMPTY_DICT + else df + for name, df in spec.kwonlydefaults.items() + } + elem[5] = new_kwonlydefaults + spec = compat.FullArgSpec(*elem) caller_argspec = format_argspec_plus(spec, grouped=False)