From 1fa3e2e3814b4d28deca7426bb3f36e7fb515496 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 28 Apr 2022 16:19:43 -0400 Subject: [PATCH] pep484: attributes and related also implements __slots__ for QueryableAttribute, InstrumentedAttribute, Relationship.Comparator. Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5 --- doc/build/orm/internals.rst | 2 +- .../custom_attributes/custom_management.py | 9 +- lib/sqlalchemy/event/base.py | 6 +- lib/sqlalchemy/ext/hybrid.py | 13 +- lib/sqlalchemy/orm/__init__.py | 2 +- lib/sqlalchemy/orm/_typing.py | 2 + lib/sqlalchemy/orm/attributes.py | 925 +++++++++++++----- lib/sqlalchemy/orm/base.py | 25 +- lib/sqlalchemy/orm/collections.py | 208 ++-- lib/sqlalchemy/orm/dynamic.py | 8 +- lib/sqlalchemy/orm/instrumentation.py | 121 ++- lib/sqlalchemy/orm/interfaces.py | 45 +- lib/sqlalchemy/orm/relationships.py | 77 +- lib/sqlalchemy/orm/state.py | 9 +- lib/sqlalchemy/orm/util.py | 16 +- lib/sqlalchemy/sql/base.py | 4 +- lib/sqlalchemy/sql/cache_key.py | 8 +- lib/sqlalchemy/sql/util.py | 11 +- lib/sqlalchemy/util/_collections.py | 4 +- lib/sqlalchemy/util/langhelpers.py | 13 +- lib/sqlalchemy/util/typing.py | 39 + test/ext/test_extendedattr.py | 57 +- test/ext/test_mutable.py | 5 +- test/orm/test_attributes.py | 298 +++--- test/orm/test_collection.py | 64 +- test/orm/test_deprecations.py | 8 +- test/orm/test_instrumentation.py | 9 +- 27 files changed, 1319 insertions(+), 669 deletions(-) diff --git a/doc/build/orm/internals.rst b/doc/build/orm/internals.rst index f251e43bd0..9aa3b2db67 100644 --- a/doc/build/orm/internals.rst +++ b/doc/build/orm/internals.rst @@ -37,7 +37,7 @@ sections, are listed here. .. autodata:: CompositeProperty -.. autoclass:: AttributeEvent +.. autoclass:: AttributeEventToken :members: .. autoclass:: IdentityMap diff --git a/examples/custom_attributes/custom_management.py b/examples/custom_attributes/custom_management.py index 5ee5a45f83..aa9ea7a689 100644 --- a/examples/custom_attributes/custom_management.py +++ b/examples/custom_attributes/custom_management.py @@ -17,7 +17,7 @@ from sqlalchemy import MetaData from sqlalchemy import Table from sqlalchemy import Text from sqlalchemy.ext.instrumentation import InstrumentationManager -from sqlalchemy.orm import mapper +from sqlalchemy.orm import registry as _reg from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import del_attribute @@ -26,6 +26,9 @@ from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.instrumentation import is_instrumented +registry = _reg() + + class MyClassState(InstrumentationManager): def get_instance_dict(self, class_, instance): return instance._goofy_dict @@ -97,9 +100,9 @@ if __name__ == "__main__": class B(MyClass): pass - mapper(A, table1, properties={"bs": relationship(B)}) + registry.map_imperatively(A, table1, properties={"bs": relationship(B)}) - mapper(B, table2) + registry.map_imperatively(B, table2) a1 = A(name="a1", bs=[B(name="b1"), B(name="b2")]) diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index c16f6870be..83b34a17fc 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -108,10 +108,12 @@ class _Dispatch(_DispatchCommon[_ET]): """ - # In one ORM edge case, an attribute is added to _Dispatch, - # so __dict__ is used in just that case and potentially others. + # "active_history" is an ORM case we add here. ideally a better + # system would be in place for ad-hoc attributes. __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners" + _active_history: bool + _empty_listener_reg: MutableMapping[ Type[_ET], Dict[str, _EmptyListener[_ET]] ] = weakref.WeakKeyDictionary() diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 7200414a18..ea558495b4 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -824,15 +824,14 @@ from ..orm import attributes from ..orm import InspectionAttrExtensionType from ..orm import interfaces from ..orm import ORMDescriptor +from ..sql import roles from ..sql._typing import is_has_clause_element from ..sql.elements import ColumnElement from ..sql.elements import SQLCoreOperations from ..util.typing import Literal from ..util.typing import Protocol - if TYPE_CHECKING: - from ..orm._typing import _ORMColumnExprArgument from ..orm.interfaces import MapperProperty from ..orm.util import AliasedInsp from ..sql._typing import _ColumnExpressionArgument @@ -840,7 +839,6 @@ if TYPE_CHECKING: from ..sql._typing import _HasClauseElement from ..sql._typing import _InfoType from ..sql.operators import OperatorType - from ..sql.roles import ColumnsClauseRole _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) @@ -1290,7 +1288,7 @@ class Comparator(interfaces.PropComparator[_T]): ): self.expression = expression - def __clause_element__(self) -> _ORMColumnExprArgument[_T]: + def __clause_element__(self) -> roles.ColumnsClauseRole: expr = self.expression if is_has_clause_element(expr): ret_expr = expr.__clause_element__() @@ -1306,7 +1304,7 @@ class Comparator(interfaces.PropComparator[_T]): assert isinstance(ret_expr, ColumnElement) return ret_expr - @util.ro_non_memoized_property + @util.non_memoized_property def property(self) -> Optional[interfaces.MapperProperty[_T]]: return None @@ -1345,8 +1343,11 @@ class ExprComparator(Comparator[_T]): else: return [(self.expression, value)] - @util.ro_non_memoized_property + @util.non_memoized_property def property(self) -> Optional[MapperProperty[_T]]: + # this accessor is not normally used, however is accessed by things + # like ORM synonyms if the hybrid is used in this context; the + # .property attribute is not necessarily accessible return self.expression.property # type: ignore def operate( diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 58900ab99a..b7d1df5322 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -41,7 +41,7 @@ from ._orm_constructors import synonym as synonym from ._orm_constructors import SynonymProperty as SynonymProperty from ._orm_constructors import with_loader_criteria as with_loader_criteria from ._orm_constructors import with_polymorphic as with_polymorphic -from .attributes import AttributeEvent as AttributeEvent +from .attributes import AttributeEventToken as AttributeEventToken from .attributes import InstrumentedAttribute as InstrumentedAttribute from .attributes import QueryableAttribute as QueryableAttribute from .base import class_mapper as class_mapper diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 339844f147..29d82340ab 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -47,6 +47,8 @@ if TYPE_CHECKING: _InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"] +_ExternalEntityType = Union[Type[_T], "AliasedClass[_T]"] + _EntityType = Union[ Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]" ] diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 9a6e94e228..9aeaeaa272 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +# mypy: allow-untyped-defs, allow-untyped-calls """Defines instrumentation for class attributes and their interaction with instances. @@ -17,16 +17,18 @@ defines a large part of the ORM's interactivity. from __future__ import annotations -from collections import namedtuple +import dataclasses import operator from typing import Any from typing import Callable -from typing import Collection +from typing import cast +from typing import ClassVar from typing import Dict from typing import List from typing import NamedTuple from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -36,6 +38,7 @@ from typing import Union from . import collections from . import exc as orm_exc from . import interfaces +from ._typing import insp_is_aliased_class from .base import ATTR_EMPTY from .base import ATTR_WAS_SET from .base import CALLABLES_OK @@ -45,6 +48,7 @@ from .base import instance_dict as instance_dict from .base import instance_state as instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED +from .base import LoaderCallableStatus from .base import manager_of_class as manager_of_class from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa @@ -70,17 +74,41 @@ from .. import event from .. import exc from .. import inspection from .. import util +from ..event import dispatcher +from ..event import EventTarget from ..sql import base as sql_base from ..sql import cache_key +from ..sql import coercions from ..sql import roles -from ..sql import traversals from ..sql import visitors +from ..util.typing import Literal +from ..util.typing import TypeGuard if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _LoaderCallable + from ._typing import _O + from .collections import _AdaptedCollectionProtocol + from .collections import CollectionAdapter + from .dynamic import DynamicAttributeImpl from .interfaces import MapperProperty + from .relationships import Relationship from .state import InstanceState - from ..sql.dml import _DMLColumnElement + from .util import AliasedInsp + from ..event.base import _Dispatch + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _AnnotationDict from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.operators import OperatorType + from ..sql.selectable import FromClause + _T = TypeVar("_T") @@ -89,19 +117,27 @@ class NoKey(str): pass +_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]] + NO_KEY = NoKey("no name") +SelfQueryableAttribute = TypeVar( + "SelfQueryableAttribute", bound="QueryableAttribute[Any]" +) + @inspection._self_inspects class QueryableAttribute( + roles.ExpressionElementRole[_T], interfaces._MappedAttribute[_T], interfaces.InspectionAttr, interfaces.PropComparator[_T], - traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, sql_base.Immutable, - cache_key.MemoizedHasCacheKey, + cache_key.SlotsMemoizedHasCacheKey, + util.MemoizedSlots, + EventTarget, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` @@ -121,9 +157,33 @@ class QueryableAttribute( :attr:`_orm.Mapper.attrs` """ + __slots__ = ( + "class_", + "key", + "impl", + "comparator", + "property", + "parent", + "expression", + "_of_type", + "_extra_criteria", + "_slots_dispatch", + "_propagate_attrs", + "_doc", + ) + is_attribute = True + dispatch: dispatcher[QueryableAttribute[_T]] + + class_: _ExternalEntityType[Any] + key: str + parententity: _InternalEntityType[Any] impl: AttributeImpl + comparator: interfaces.PropComparator[_T] + _of_type: Optional[_InternalEntityType[Any]] + _extra_criteria: Tuple[ColumnElement[bool], ...] + _doc: Optional[str] # PropComparator has a __visit_name__ to participate within # traversals. Disambiguate the attribute vs. a comparator. @@ -131,21 +191,30 @@ class QueryableAttribute( def __init__( self, - class_, - key, - parententity, - impl=None, - comparator=None, - of_type=None, - extra_criteria=(), + class_: _ExternalEntityType[_O], + key: str, + parententity: _InternalEntityType[_O], + comparator: interfaces.PropComparator[_T], + impl: Optional[AttributeImpl] = None, + of_type: Optional[_InternalEntityType[Any]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), ): self.class_ = class_ self.key = key - self._parententity = parententity - self.impl = impl + + self._parententity = self.parent = parententity + + # this attribute is non-None after mappers are set up, however in the + # interim class manager setup, there's a check for None to see if it + # needs to be populated, so we assign None here leaving the attribute + # in a temporarily not-type-correct state + self.impl = impl # type: ignore + + assert comparator is not None self.comparator = comparator self._of_type = of_type self._extra_criteria = extra_criteria + self._doc = None manager = opt_manager_of_class(class_) # manager is None in the case of AliasedClass @@ -156,7 +225,7 @@ class QueryableAttribute( if key in base: self.dispatch._update(base[key].dispatch) if base[key].dispatch._active_history: - self.dispatch._active_history = True + self.dispatch._active_history = True # type: ignore _cache_key_traversal = [ ("key", visitors.ExtendedInternalTraversal.dp_string), @@ -165,7 +234,7 @@ class QueryableAttribute( ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ] - def __reduce__(self): + def __reduce__(self) -> Any: # this method is only used in terms of the # sqlalchemy.ext.serializer extension return ( @@ -178,21 +247,19 @@ class QueryableAttribute( ), ) - @util.memoized_property - def _supports_population(self): - return self.impl.supports_population - @property - def _impl_uses_objects(self): + def _impl_uses_objects(self) -> bool: return self.impl.uses_objects - def get_history(self, instance, passive=PASSIVE_OFF): + def get_history( + self, instance: Any, passive: PassiveFlag = PASSIVE_OFF + ) -> History: return self.impl.get_history( instance_state(instance), instance_dict(instance), passive ) - @util.memoized_property - def info(self): + @property + def info(self) -> _InfoType: """Return the 'info' dictionary for the underlying SQL element. The behavior here is as follows: @@ -233,27 +300,28 @@ class QueryableAttribute( """ return self.comparator.info - @util.memoized_property - def parent(self): - """Return an inspection instance representing the parent. + parent: _InternalEntityType[Any] + """Return an inspection instance representing the parent. - This will be either an instance of :class:`_orm.Mapper` - or :class:`.AliasedInsp`, depending upon the nature - of the parent entity which this attribute is associated - with. + This will be either an instance of :class:`_orm.Mapper` + or :class:`.AliasedInsp`, depending upon the nature + of the parent entity which this attribute is associated + with. - """ - return inspection.inspect(self._parententity) + """ - @util.memoized_property - def expression(self): - """The SQL expression object represented by this - :class:`.QueryableAttribute`. + expression: ColumnElement[_T] + """The SQL expression object represented by this + :class:`.QueryableAttribute`. - This will typically be an instance of a :class:`_sql.ColumnElement` - subclass representing a column expression. + This will typically be an instance of a :class:`_sql.ColumnElement` + subclass representing a column expression. + + """ + + def _memoized_attr_expression(self) -> ColumnElement[_T]: + annotations: _AnnotationDict - """ if self.key is NO_KEY: annotations = {"entity_namespace": self._entity_namespace} else: @@ -265,6 +333,8 @@ class QueryableAttribute( ce = self.comparator.__clause_element__() try: + if TYPE_CHECKING: + assert isinstance(ce, ColumnElement) anno = ce._annotate except AttributeError as ae: raise exc.InvalidRequestError( @@ -275,29 +345,42 @@ class QueryableAttribute( else: return anno(annotations) + def _memoized_attr__propagate_attrs(self) -> _PropagateAttrsType: + # this suits the case in coercions where we don't actually + # call ``__clause_element__()`` but still need to get + # resolved._propagate_attrs. See #6558. + return util.immutabledict( + { + "compile_state_plugin": "orm", + "plugin_subject": self._parentmapper, + } + ) + @property - def _entity_namespace(self): + def _entity_namespace(self) -> _InternalEntityType[Any]: return self._parententity @property - def _annotations(self): + def _annotations(self) -> _AnnotationDict: return self.__clause_element__()._annotations def __clause_element__(self) -> ColumnElement[_T]: return self.expression @property - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: return self.expression._from_objects def _bulk_update_tuples( self, value: Any - ) -> List[Tuple[_DMLColumnElement, Any]]: + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: """Return setter tuples for a bulk UPDATE.""" return self.comparator._bulk_update_tuples(value) - def adapt_to_entity(self, adapt_to_entity): + def adapt_to_entity( + self: SelfQueryableAttribute, adapt_to_entity: AliasedInsp[Any] + ) -> SelfQueryableAttribute: assert not self._of_type return self.__class__( adapt_to_entity.entity, @@ -307,7 +390,7 @@ class QueryableAttribute( parententity=adapt_to_entity, ) - def of_type(self, entity): + def of_type(self, entity: _EntityType[_T]) -> QueryableAttribute[_T]: return QueryableAttribute( self.class_, self.key, @@ -318,18 +401,28 @@ class QueryableAttribute( extra_criteria=self._extra_criteria, ) - def and_(self, *other): + def and_( + self, *clauses: _ColumnExpressionArgument[bool] + ) -> interfaces.PropComparator[bool]: + if TYPE_CHECKING: + assert isinstance(self.comparator, Relationship.Comparator) + + exprs = tuple( + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(clauses) + ) + return QueryableAttribute( self.class_, self.key, self._parententity, impl=self.impl, - comparator=self.comparator.and_(*other), + comparator=self.comparator.and_(*exprs), of_type=self._of_type, - extra_criteria=self._extra_criteria + other, + extra_criteria=self._extra_criteria + exprs, ) - def _clone(self, **kw): + def _clone(self, **kw: Any) -> QueryableAttribute[_T]: return QueryableAttribute( self.class_, self.key, @@ -340,19 +433,30 @@ class QueryableAttribute( extra_criteria=self._extra_criteria, ) - def label(self, name): + def label(self, name: Optional[str]) -> Label[_T]: return self.__clause_element__().label(name) - def operate(self, op, *other, **kwargs): - return op(self.comparator, *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 - def reverse_operate(self, op, other, **kwargs): - return op(other, self.comparator, **kwargs) + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 - def hasparent(self, state, optimistic=False): + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: return self.impl.hasparent(state, optimistic=optimistic) is not False - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: + try: + return util.MemoizedSlots.__getattr__(self, key) + except AttributeError: + pass + try: return getattr(self.comparator, key) except AttributeError as err: @@ -367,27 +471,22 @@ class QueryableAttribute( ) ) from err - def __str__(self): - return "%s.%s" % (self.class_.__name__, self.key) + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" - @util.memoized_property - def property(self) -> MapperProperty[_T]: - """Return the :class:`.MapperProperty` associated with this - :class:`.QueryableAttribute`. - - - Return values here will commonly be instances of - :class:`.ColumnProperty` or :class:`.Relationship`. - - - """ + def _memoized_attr_property(self) -> Optional[MapperProperty[Any]]: return self.comparator.property -def _queryable_attribute_unreduce(key, mapped_class, parententity, entity): +def _queryable_attribute_unreduce( + key: str, + mapped_class: Type[_O], + parententity: _InternalEntityType[_O], + entity: _ExternalEntityType[Any], +) -> Any: # this method is only used in terms of the # sqlalchemy.ext.serializer extension - if parententity.is_aliased_class: + if insp_is_aliased_class(parententity): return entity._get_from_serialized(key, mapped_class, parententity) else: return getattr(entity, key) @@ -402,45 +501,60 @@ class InstrumentedAttribute(QueryableAttribute[_T]): """ + __slots__ = () + inherit_cache = True - def __set__(self, instance, value): + # if not TYPE_CHECKING: + + @property # type: ignore + def __doc__(self) -> Optional[str]: # type: ignore + return self._doc + + @__doc__.setter + def __doc__(self, value: Optional[str]) -> None: + self._doc = value + + def __set__(self, instance: object, value: Any) -> None: self.impl.set( instance_state(instance), instance_dict(instance), value, None ) - def __delete__(self, instance): + def __delete__(self, instance: object) -> None: self.impl.delete(instance_state(instance), instance_dict(instance)) @overload - def __get__( - self, instance: None, owner: Type[Any] - ) -> InstrumentedAttribute: + def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]: ... @overload - def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]: + def __get__(self, instance: object, owner: Any) -> _T: ... def __get__( - self, instance: Optional[object], owner: Type[Any] - ) -> Union[InstrumentedAttribute, Optional[_T]]: + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T], _T]: if instance is None: return self dict_ = instance_dict(instance) - if self._supports_population and self.key in dict_: - return dict_[self.key] + if self.impl.supports_population and self.key in dict_: + return dict_[self.key] # type: ignore[no-any-return] else: try: state = instance_state(instance) except AttributeError as err: raise orm_exc.UnmappedInstanceError(instance) from err - return self.impl.get(state, dict_) + return self.impl.get(state, dict_) # type: ignore[no-any-return] -HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"]) -HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False +@dataclasses.dataclass(frozen=True) +class AdHocHasEntityNamespace: + # py37 compat, no slots=True on dataclass + __slots__ = ("entity_namespace",) + entity_namespace: _ExternalEntityType[Any] + is_mapper: ClassVar[bool] = False + is_aliased_class: ClassVar[bool] = False def create_proxied_attribute( @@ -455,7 +569,7 @@ def create_proxied_attribute( # TODO: can move this to descriptor_props if the need for this # function is removed from ext/hybrid.py - class Proxy(QueryableAttribute): + class Proxy(QueryableAttribute[Any]): """Presents the :class:`.QueryableAttribute` interface as a proxy on top of a Python descriptor / :class:`.PropComparator` combination. @@ -464,6 +578,10 @@ def create_proxied_attribute( _extra_criteria = () + # the attribute error catches inside of __getattr__ basically create a + # singularity if you try putting slots on this too + # __slots__ = ("descriptor", "original_property", "_comparator") + def __init__( self, class_, @@ -480,7 +598,15 @@ def create_proxied_attribute( self.original_property = original_property self._comparator = comparator self._adapt_to_entity = adapt_to_entity - self.__doc__ = doc + self._doc = self.__doc__ = doc + + @property + def _parententity(self): + return inspection.inspect(self.class_) + + @property + def parent(self): + return inspection.inspect(self.class_) _is_internal_proxy = True @@ -496,10 +622,6 @@ def create_proxied_attribute( and getattr(self.class_, self.key).impl.uses_objects ) - @property - def _parententity(self): - return inspection.inspect(self.class_, raiseerr=False) - @property def _entity_namespace(self): if hasattr(self._comparator, "_parententity"): @@ -507,7 +629,7 @@ def create_proxied_attribute( else: # used by hybrid attributes which try to remain # agnostic of any ORM concepts like mappers - return HasEntityNamespace(self.class_) + return AdHocHasEntityNamespace(self.class_) @property def property(self): @@ -552,12 +674,22 @@ def create_proxied_attribute( else: return retval - def __str__(self): - return "%s.%s" % (self.class_.__name__, self.key) + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" def __getattr__(self, attribute): """Delegate __getattr__ to the original descriptor and/or comparator.""" + + # this is unfortunately very complicated, and is easily prone + # to recursion overflows when implementations of related + # __getattr__ schemes are changed + + try: + return util.MemoizedSlots.__getattr__(self, attribute) + except AttributeError: + pass + try: return getattr(descriptor, attribute) except AttributeError as err: @@ -602,7 +734,7 @@ OP_BULK_REPLACE = util.symbol("BULK_REPLACE") OP_MODIFIED = util.symbol("MODIFIED") -class AttributeEvent: +class AttributeEventToken: """A token propagated throughout the course of a chain of attribute events. @@ -619,7 +751,8 @@ class AttributeEvent: event handlers, and is used to control the propagation of operations across two mutually-dependent attributes. - .. versionadded:: 0.9.0 + .. versionchanged:: 2.0 Changed the name from ``AttributeEvent`` + to ``AttributeEventToken``. :attribute impl: The :class:`.AttributeImpl` which is the current event initiator. @@ -639,7 +772,7 @@ class AttributeEvent: def __eq__(self, other): return ( - isinstance(other, AttributeEvent) + isinstance(other, AttributeEventToken) and other.impl is self.impl and other.op == self.op ) @@ -652,28 +785,37 @@ class AttributeEvent: return self.impl.hasparent(state) -Event = AttributeEvent +AttributeEvent = AttributeEventToken # legacy +Event = AttributeEventToken # legacy class AttributeImpl: """internal implementation for instrumented attributes.""" collection: bool + default_accepts_scalar_loader: bool + uses_objects: bool + supports_population: bool + dynamic: bool + + _replace_token: AttributeEventToken + _remove_token: AttributeEventToken + _append_token: AttributeEventToken def __init__( self, - class_, - key, - callable_, - dispatch, - trackparent=False, - compare_function=None, - active_history=False, - parent_token=None, - load_on_unexpire=True, - send_modified_events=True, - accepts_scalar_loader=None, - **kwargs, + class_: _ExternalEntityType[_O], + key: str, + callable_: _LoaderCallable, + dispatch: _Dispatch[QueryableAttribute[Any]], + trackparent: bool = False, + compare_function: Optional[Callable[..., bool]] = None, + active_history: bool = False, + parent_token: Optional[AttributeEventToken] = None, + load_on_unexpire: bool = True, + send_modified_events: bool = True, + accepts_scalar_loader: Optional[bool] = None, + **kwargs: Any, ): r"""Construct an AttributeImpl. @@ -743,7 +885,7 @@ class AttributeImpl: self.dispatch._active_history = True self.load_on_unexpire = load_on_unexpire - self._modified_token = Event(self, OP_MODIFIED) + self._modified_token = AttributeEventToken(self, OP_MODIFIED) __slots__ = ( "class_", @@ -760,8 +902,8 @@ class AttributeImpl: "_deferred_history", ) - def __str__(self): - return "%s.%s" % (self.class_.__name__, self.key) + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" def _get_active_history(self): """Backwards compat for impl.active_history""" @@ -773,7 +915,9 @@ class AttributeImpl: active_history = property(_get_active_history, _set_active_history) - def hasparent(self, state, optimistic=False): + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: """Return the boolean value of a `hasparent` flag attached to the given state. @@ -796,7 +940,12 @@ class AttributeImpl: state.parents.get(id(self.parent_token), optimistic) is not False ) - def sethasparent(self, state, parent_state, value): + def sethasparent( + self, + state: InstanceState[Any], + parent_state: InstanceState[Any], + value: bool, + ) -> None: """Set a boolean flag on the given item corresponding to whether or not it is attached to a parent object via the attribute represented by this ``InstrumentedAttribute``. @@ -839,11 +988,16 @@ class AttributeImpl: self, state: InstanceState[Any], dict_: _InstanceDict, - passive=PASSIVE_OFF, + passive: PassiveFlag = PASSIVE_OFF, ) -> History: raise NotImplementedError() - def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: """Return a list of tuples of (state, obj) for all objects in this attribute's current state + history. @@ -861,7 +1015,9 @@ class AttributeImpl: """ raise NotImplementedError() - def _default_value(self, state, dict_): + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> Any: """Produce an empty value for an uninitialized scalar attribute.""" assert self.key not in dict_, ( @@ -877,7 +1033,12 @@ class AttributeImpl: return value - def get(self, state, dict_, passive=PASSIVE_OFF): + def get( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and passive is False, the callable will be executed and the @@ -917,7 +1078,9 @@ class AttributeImpl: else: return self._default_value(state, dict_) - def _fire_loader_callables(self, state, key, passive): + def _fire_loader_callables( + self, state: InstanceState[Any], key: str, passive: PassiveFlag + ) -> Any: if ( self.accepts_scalar_loader and self.load_on_unexpire @@ -932,15 +1095,36 @@ class AttributeImpl: else: return ATTR_EMPTY - def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: self.set(state, dict_, value, initiator, passive=passive) - def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: self.set( state, dict_, None, initiator, passive=passive, check_old=value ) - def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: self.set( state, dict_, @@ -953,17 +1137,25 @@ class AttributeImpl: def set( self, - state, - dict_, - value, - initiator, - passive=PASSIVE_OFF, - check_old=None, - pop=False, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: + raise NotImplementedError() + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: raise NotImplementedError() - def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): + def get_committed_value( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: """return the unchanged value of this attribute""" if self.key in state.committed_state: @@ -996,10 +1188,12 @@ class ScalarAttributeImpl(AttributeImpl): def __init__(self, *arg, **kw): super(ScalarAttributeImpl, self).__init__(*arg, **kw) - self._replace_token = self._append_token = Event(self, OP_REPLACE) - self._remove_token = Event(self, OP_REMOVE) + self._replace_token = self._append_token = AttributeEventToken( + self, OP_REPLACE + ) + self._remove_token = AttributeEventToken(self, OP_REMOVE) - def delete(self, state, dict_): + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: if self.dispatch._active_history: old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) else: @@ -1042,11 +1236,11 @@ class ScalarAttributeImpl(AttributeImpl): state: InstanceState[Any], dict_: Dict[str, Any], value: Any, - initiator: Optional[Event], + initiator: Optional[AttributeEventToken], passive: PassiveFlag = PASSIVE_OFF, check_old: Optional[object] = None, pop: bool = False, - ): + ) -> None: if self.dispatch._active_history: old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) else: @@ -1059,21 +1253,30 @@ class ScalarAttributeImpl(AttributeImpl): state._modified_event(dict_, self, old) dict_[self.key] = value - def fire_replace_event(self, state, dict_, value, previous, initiator): + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: for fn in self.dispatch.set: value = fn( state, value, previous, initiator or self._replace_token ) return value - def fire_remove_event(self, state, dict_, value, initiator): + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: for fn in self.dispatch.remove: fn(state, value, initiator or self._remove_token) - @property - def type(self): - self.property.columns[0].type - class ScalarObjectAttributeImpl(ScalarAttributeImpl): """represents a scalar-holding InstrumentedAttribute, @@ -1090,7 +1293,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): __slots__ = () - def delete(self, state, dict_): + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: if self.dispatch._active_history: old = self.get( state, @@ -1122,7 +1325,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): ): raise AttributeError("%s object does not have a value" % self) - def get_history(self, state, dict_, passive=PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: if self.key in dict_: current = dict_[self.key] else: @@ -1152,7 +1360,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): self, state, current, original=original ) - def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: if self.key in dict_: current = dict_[self.key] elif passive & CALLABLES_OK: @@ -1160,6 +1373,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): else: return [] + ret: _AllPendingType + # can't use __hash__(), can't use __eq__() here if ( current is not None @@ -1184,14 +1399,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): def set( self, - state, - dict_, - value, - initiator, - passive=PASSIVE_OFF, - check_old=None, - pop=False, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: """Set a value on the given InstanceState.""" if self.dispatch._active_history: @@ -1227,7 +1442,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): value = self.fire_replace_event(state, dict_, value, old, initiator) dict_[self.key] = value - def fire_remove_event(self, state, dict_, value, initiator): + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: if self.trackparent and value not in ( None, PASSIVE_NO_RESULT, @@ -1240,7 +1461,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): state._modified_event(dict_, self, value) - def fire_replace_event(self, state, dict_, value, previous, initiator): + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: if self.trackparent: if previous is not value and previous not in ( None, @@ -1263,7 +1491,64 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): return value -class CollectionAttributeImpl(AttributeImpl): +class HasCollectionAdapter: + __slots__ = () + + def _dispose_previous_collection( + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: + raise NotImplementedError() + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + raise NotImplementedError() + + +if TYPE_CHECKING: + + def _is_collection_attribute_impl( + impl: AttributeImpl, + ) -> TypeGuard[CollectionAttributeImpl]: + ... + +else: + _is_collection_attribute_impl = operator.attrgetter("collection") + + +class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): """A collection-holding attribute that instruments changes in membership. Only handles collections of instrumented objects. @@ -1275,12 +1560,14 @@ class CollectionAttributeImpl(AttributeImpl): """ - default_accepts_scalar_loader = False uses_objects = True - supports_population = True collection = True + default_accepts_scalar_loader = False + supports_population = True dynamic = False + _bulk_replace_token: AttributeEventToken + __slots__ = ( "copy", "collection_factory", @@ -1316,9 +1603,9 @@ class CollectionAttributeImpl(AttributeImpl): copy_function = self.__copy self.copy = copy_function self.collection_factory = typecallable - self._append_token = Event(self, OP_APPEND) - self._remove_token = Event(self, OP_REMOVE) - self._bulk_replace_token = Event(self, OP_BULK_REPLACE) + self._append_token = AttributeEventToken(self, OP_APPEND) + self._remove_token = AttributeEventToken(self, OP_REMOVE) + self._bulk_replace_token = AttributeEventToken(self, OP_BULK_REPLACE) self._duck_typed_as = util.duck_type_collection( self.collection_factory() ) @@ -1336,14 +1623,24 @@ class CollectionAttributeImpl(AttributeImpl): def __copy(self, item): return [y for y in collections.collection_adapter(item)] - def get_history(self, state, dict_, passive=PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: current = self.get(state, dict_, passive=passive) if current is PASSIVE_NO_RESULT: return HISTORY_BLANK else: return History.from_collection(self, state, current) - def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: # NOTE: passive is ignored here at the moment if self.key not in dict_: @@ -1383,7 +1680,13 @@ class CollectionAttributeImpl(AttributeImpl): return [(instance_state(o), o) for o in current] - def fire_append_event(self, state, dict_, value, initiator): + def fire_append_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + ) -> _T: for fn in self.dispatch.append: value = fn(state, value, initiator or self._append_token) @@ -1394,13 +1697,24 @@ class CollectionAttributeImpl(AttributeImpl): return value - def fire_append_wo_mutation_event(self, state, dict_, value, initiator): + def fire_append_wo_mutation_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + ) -> _T: for fn in self.dispatch.append_wo_mutation: value = fn(state, value, initiator or self._append_token) return value - def fire_pre_remove_event(self, state, dict_, initiator): + def fire_pre_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + initiator: Optional[AttributeEventToken], + ) -> None: """A special event used for pop() operations. The "remove" event needs to have the item to be removed passed to @@ -1411,7 +1725,13 @@ class CollectionAttributeImpl(AttributeImpl): """ state._modified_event(dict_, self, NO_VALUE, True) - def fire_remove_event(self, state, dict_, value, initiator): + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: if self.trackparent and value is not None: self.sethasparent(instance_state(value), state, False) @@ -1420,7 +1740,7 @@ class CollectionAttributeImpl(AttributeImpl): state._modified_event(dict_, self, NO_VALUE, True) - def delete(self, state, dict_): + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: if self.key not in dict_: return @@ -1433,7 +1753,9 @@ class CollectionAttributeImpl(AttributeImpl): # del is a no-op if collection not present. del dict_[self.key] - def _default_value(self, state, dict_): + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> _AdaptedCollectionProtocol: """Produce an empty collection for an un-initialized attribute""" assert self.key not in dict_, ( @@ -1448,7 +1770,9 @@ class CollectionAttributeImpl(AttributeImpl): adapter._set_empty(user_data) return user_data - def _initialize_collection(self, state): + def _initialize_collection( + self, state: InstanceState[Any] + ) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]: adapter, collection = state.manager.initialize_collection( self.key, state, self.collection_factory @@ -1458,7 +1782,14 @@ class CollectionAttributeImpl(AttributeImpl): return adapter, collection - def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NO_RESULT: value = self.fire_append_event(state, dict_, value, initiator) @@ -1467,9 +1798,18 @@ class CollectionAttributeImpl(AttributeImpl): ), "Collection was loaded during event handling." state._get_pending_mutation(self.key).append(value) else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) collection.append_with_event(value, initiator) - def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NO_RESULT: self.fire_remove_event(state, dict_, value, initiator) @@ -1478,9 +1818,18 @@ class CollectionAttributeImpl(AttributeImpl): ), "Collection was loaded during event handling." state._get_pending_mutation(self.key).remove(value) else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) collection.remove_with_event(value, initiator) - def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: try: # TODO: better solution here would be to add # a "popper" role to collections.py to complement @@ -1491,15 +1840,15 @@ class CollectionAttributeImpl(AttributeImpl): def set( self, - state, - dict_, - value, - initiator=None, - passive=PASSIVE_OFF, - check_old=None, - pop=False, - _adapt=True, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: iterable = orig_iterable = value # pulling a new collection first so that an adaptation exception does @@ -1518,7 +1867,7 @@ class CollectionAttributeImpl(AttributeImpl): and "None" or iterable.__class__.__name__ ) - wanted = self._duck_typed_as.__name__ + wanted = self._duck_typed_as.__name__ # type: ignore raise TypeError( "Incompatible collection type: %s is not %s-like" % (given, wanted) @@ -1560,8 +1909,12 @@ class CollectionAttributeImpl(AttributeImpl): self._dispose_previous_collection(state, old, old_collection, True) def _dispose_previous_collection( - self, state, collection, adapter, fire_event - ): + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: del collection._sa_adapter # discarding old collection make sure it is not referenced in empty @@ -1570,11 +1923,15 @@ class CollectionAttributeImpl(AttributeImpl): if fire_event: self.dispatch.dispose_collection(state, collection, adapter) - def _invalidate_collection(self, collection: Collection) -> None: + def _invalidate_collection( + self, collection: _AdaptedCollectionProtocol + ) -> None: adapter = getattr(collection, "_sa_adapter") adapter.invalidated = True - def set_committed_value(self, state, dict_, value): + def set_committed_value( + self, state: InstanceState[Any], dict_: _InstanceDict, value: Any + ) -> _AdaptedCollectionProtocol: """Set an attribute value on the given instance and 'commit' it.""" collection, user_data = self._initialize_collection(state) @@ -1601,9 +1958,37 @@ class CollectionAttributeImpl(AttributeImpl): return user_data + @overload def get_collection( - self, state, dict_, user_data=None, passive=PASSIVE_OFF - ): + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: """Retrieve the CollectionAdapter associated with the given state. if user_data is None, retrieves it from the state using normal @@ -1612,14 +1997,18 @@ class CollectionAttributeImpl(AttributeImpl): """ if user_data is None: - user_data = self.get(state, dict_, passive=passive) - if user_data is PASSIVE_NO_RESULT: - return user_data + fetch_user_data = self.get(state, dict_, passive=passive) + if fetch_user_data is LoaderCallableStatus.PASSIVE_NO_RESULT: + return fetch_user_data + else: + user_data = cast("_AdaptedCollectionProtocol", fetch_user_data) return user_data._sa_adapter -def backref_listeners(attribute, key, uselist): +def backref_listeners( + attribute: QueryableAttribute[Any], key: str, uselist: bool +) -> None: """Apply listeners to synchronize a two-way relationship.""" # use easily recognizable names for stack traces. @@ -1695,7 +2084,7 @@ def backref_listeners(attribute, key, uselist): check_append_token = child_impl._append_token check_bulk_replace_token = ( child_impl._bulk_replace_token - if child_impl.collection + if _is_collection_attribute_impl(child_impl) else None ) @@ -1728,7 +2117,9 @@ def backref_listeners(attribute, key, uselist): # tokens to test for a recursive loop. check_append_token = child_impl._append_token check_bulk_replace_token = ( - child_impl._bulk_replace_token if child_impl.collection else None + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None ) if ( @@ -1756,6 +2147,8 @@ def backref_listeners(attribute, key, uselist): ) child_impl = child_state.manager[key].impl + check_replace_token: Optional[AttributeEventToken] + # tokens to test for a recursive loop. if not child_impl.collection and not child_impl.dynamic: check_remove_token = child_impl._remove_token @@ -1765,7 +2158,7 @@ def backref_listeners(attribute, key, uselist): check_remove_token = child_impl._remove_token check_replace_token = ( child_impl._bulk_replace_token - if child_impl.collection + if _is_collection_attribute_impl(child_impl) else None ) check_for_dupes_on_remove = False @@ -1848,10 +2241,10 @@ class History(NamedTuple): unchanged: Union[Tuple[()], List[Any]] deleted: Union[Tuple[()], List[Any]] - def __bool__(self): + def __bool__(self) -> bool: return self != HISTORY_BLANK - def empty(self): + def empty(self) -> bool: """Return True if this :class:`.History` has no changes and no existing, unchanged state. @@ -1859,29 +2252,29 @@ class History(NamedTuple): return not bool((self.added or self.deleted) or self.unchanged) - def sum(self): + def sum(self) -> Sequence[Any]: """Return a collection of added + unchanged + deleted.""" return ( (self.added or []) + (self.unchanged or []) + (self.deleted or []) ) - def non_deleted(self): + def non_deleted(self) -> Sequence[Any]: """Return a collection of added + unchanged.""" return (self.added or []) + (self.unchanged or []) - def non_added(self): + def non_added(self) -> Sequence[Any]: """Return a collection of unchanged + deleted.""" return (self.unchanged or []) + (self.deleted or []) - def has_changes(self): + def has_changes(self) -> bool: """Return True if this :class:`.History` has changes.""" return bool(self.added or self.deleted) - def as_state(self): + def as_state(self) -> History: return History( [ (c is not None) and instance_state(c) or None @@ -1898,9 +2291,16 @@ class History(NamedTuple): ) @classmethod - def from_scalar_attribute(cls, attribute, state, current): + def from_scalar_attribute( + cls, + attribute: ScalarAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: original = state.committed_state.get(attribute.key, _NO_HISTORY) + deleted: Union[Tuple[()], List[Any]] + if original is _NO_HISTORY: if current is NO_VALUE: return cls((), (), ()) @@ -1933,8 +2333,14 @@ class History(NamedTuple): @classmethod def from_object_attribute( - cls, attribute, state, current, original=_NO_HISTORY - ): + cls, + attribute: ScalarObjectAttributeImpl, + state: InstanceState[Any], + current: Any, + original: Any = _NO_HISTORY, + ) -> History: + deleted: Union[Tuple[()], List[Any]] + if original is _NO_HISTORY: original = state.committed_state.get(attribute.key, _NO_HISTORY) @@ -1965,7 +2371,12 @@ class History(NamedTuple): return cls([current], (), deleted) @classmethod - def from_collection(cls, attribute, state, current): + def from_collection( + cls, + attribute: CollectionAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: original = state.committed_state.get(attribute.key, _NO_HISTORY) if current is NO_VALUE: return cls((), (), ()) @@ -1999,7 +2410,9 @@ class History(NamedTuple): HISTORY_BLANK = History((), (), ()) -def get_history(obj, key, passive=PASSIVE_OFF): +def get_history( + obj: object, key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: """Return a :class:`.History` record for the given object and attribute key. @@ -2037,36 +2450,47 @@ def get_history(obj, key, passive=PASSIVE_OFF): return get_state_history(instance_state(obj), key, passive) -def get_state_history(state, key, passive=PASSIVE_OFF): +def get_state_history( + state: InstanceState[Any], key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: return state.get_history(key, passive) -def has_parent(cls, obj, key, optimistic=False): +def has_parent( + cls: Type[_O], obj: _O, key: str, optimistic: bool = False +) -> bool: """TODO""" manager = manager_of_class(cls) state = instance_state(obj) return manager.has_parent(state, key, optimistic) -def register_attribute(class_, key, **kw): - comparator = kw.pop("comparator", None) - parententity = kw.pop("parententity", None) - doc = kw.pop("doc", None) - desc = register_descriptor(class_, key, comparator, parententity, doc=doc) +def register_attribute( + class_: Type[_O], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[_O], + doc: Optional[str] = None, + **kw: Any, +) -> InstrumentedAttribute[_T]: + desc = register_descriptor( + class_, key, comparator=comparator, parententity=parententity, doc=doc + ) register_attribute_impl(class_, key, **kw) return desc def register_attribute_impl( - class_, - key, - uselist=False, - callable_=None, - useobject=False, - impl_class=None, - backref=None, - **kw, -): + class_: Type[_O], + key: str, + uselist: bool = False, + callable_: Optional[_LoaderCallable] = None, + useobject: bool = False, + impl_class: Optional[Type[AttributeImpl]] = None, + backref: Optional[str] = None, + **kw: Any, +) -> InstrumentedAttribute[Any]: manager = manager_of_class(class_) if uselist: @@ -2077,10 +2501,18 @@ def register_attribute_impl( else: typecallable = kw.pop("typecallable", None) - dispatch = manager[key].dispatch + dispatch = cast( + "_Dispatch[QueryableAttribute[Any]]", manager[key].dispatch + ) # noqa: E501 + + impl: AttributeImpl if impl_class: - impl = impl_class(class_, key, typecallable, dispatch, **kw) + # TODO: this appears to be the DynamicAttributeImpl constructor + # which is hardcoded + impl = cast("Type[DynamicAttributeImpl]", impl_class)( + class_, key, typecallable, dispatch, **kw + ) elif uselist: impl = CollectionAttributeImpl( class_, key, callable_, dispatch, typecallable=typecallable, **kw @@ -2102,8 +2534,13 @@ def register_attribute_impl( def register_descriptor( - class_, key, comparator=None, parententity=None, doc=None -): + class_: Type[Any], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[Any], + doc: Optional[str] = None, +) -> InstrumentedAttribute[_T]: manager = manager_of_class(class_) descriptor = InstrumentedAttribute( @@ -2116,11 +2553,11 @@ def register_descriptor( return descriptor -def unregister_attribute(class_, key): +def unregister_attribute(class_: Type[Any], key: str) -> None: manager_of_class(class_).uninstrument_attribute(key) -def init_collection(obj, key): +def init_collection(obj: object, key: str) -> CollectionAdapter: """Initialize a collection attribute and return the collection adapter. This function is used to provide direct access to collection internals @@ -2143,7 +2580,9 @@ def init_collection(obj, key): return init_state_collection(state, dict_, key) -def init_state_collection(state, dict_, key): +def init_state_collection( + state: InstanceState[Any], dict_: _InstanceDict, key: str +) -> CollectionAdapter: """Initialize a collection attribute and return the collection adapter. Discards any existing collection which may be there. @@ -2151,6 +2590,9 @@ def init_state_collection(state, dict_, key): """ attr = state.manager[key].impl + if TYPE_CHECKING: + assert isinstance(attr, HasCollectionAdapter) + old = dict_.pop(key, None) # discard old collection if old is not None: old_collection = old._sa_adapter @@ -2182,7 +2624,12 @@ def set_committed_value(instance, key, value): state.manager[key].impl.set_committed_value(state, dict_, value) -def set_attribute(instance, key, value, initiator=None): +def set_attribute( + instance: object, + key: str, + value: Any, + initiator: Optional[AttributeEventToken] = None, +) -> None: """Set the value of an attribute, firing history events. This function may be used regardless of instrumentation @@ -2211,7 +2658,7 @@ def set_attribute(instance, key, value, initiator=None): state.manager[key].impl.set(state, dict_, value, initiator) -def get_attribute(instance, key): +def get_attribute(instance: object, key: str) -> Any: """Get the value of an attribute, firing any callables required. This function may be used regardless of instrumentation @@ -2225,7 +2672,7 @@ def get_attribute(instance, key): return state.manager[key].impl.get(state, dict_) -def del_attribute(instance, key): +def del_attribute(instance: object, key: str) -> None: """Delete the value of an attribute, firing history events. This function may be used regardless of instrumentation @@ -2239,7 +2686,7 @@ def del_attribute(instance, key): state.manager[key].impl.delete(state, dict_) -def flag_modified(instance, key): +def flag_modified(instance: object, key: str) -> None: """Mark an attribute on an instance as 'modified'. This sets the 'modified' flag on the instance and @@ -2262,7 +2709,7 @@ def flag_modified(instance, key): state._modified_event(dict_, impl, NO_VALUE, is_userland=True) -def flag_dirty(instance): +def flag_dirty(instance: object) -> None: """Mark an instance as 'dirty' without any specific attribute mentioned. This is a special operation that will allow the object to travel through diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 367a5332de..0ace9b1cb6 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -38,14 +38,15 @@ from ..util.typing import Literal from ..util.typing import Self if typing.TYPE_CHECKING: + from ._typing import _ExternalEntityType from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute from .instrumentation import ClassManager from .mapper import Mapper from .state import InstanceState + from .util import AliasedClass from ..sql._typing import _InfoType - _T = TypeVar("_T", bound=Any) _O = TypeVar("_O", bound=object) @@ -267,10 +268,22 @@ def _assertions( if TYPE_CHECKING: - def manager_of_class(cls: Type[Any]) -> ClassManager: + def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: + ... + + @overload + def opt_manager_of_class(cls: AliasedClass[Any]) -> None: ... - def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]: + @overload + def opt_manager_of_class( + cls: _ExternalEntityType[_O], + ) -> Optional[ClassManager[_O]]: + ... + + def opt_manager_of_class( + cls: _ExternalEntityType[_O], + ) -> Optional[ClassManager[_O]]: ... def instance_state(instance: _O) -> InstanceState[_O]: @@ -719,7 +732,7 @@ class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly): ... def __get__( - self, instance: object, owner: Any + self, instance: Optional[object], owner: Any ) -> Union[InstrumentedAttribute[_T], _T]: ... @@ -729,10 +742,10 @@ class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly): def __set__( self, instance: Any, value: Union[SQLCoreOperations[_T], _T] - ): + ) -> None: ... - def __delete__(self, instance: Any): + def __delete__(self, instance: Any) -> None: ... diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 717f1d0d68..da0da0fcfc 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +# mypy: allow-untyped-defs, allow-untyped-calls """Support for collections of mapped entities. @@ -109,17 +109,34 @@ import operator import threading import typing from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from .. import exc as sa_exc from .. import util from ..util.compat import inspect_getfullargspec +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from .attributes import CollectionAttributeImpl from .mapped_collection import attribute_mapped_collection from .mapped_collection import column_mapped_collection from .mapped_collection import mapped_collection from .mapped_collection import MappedCollection # noqa: F401 + from .state import InstanceState + __all__ = [ "collection", @@ -132,6 +149,28 @@ __all__ = [ __instrumentation_mutex = threading.Lock() +_CollectionFactoryType = Callable[[], "_AdaptedCollectionProtocol"] + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) +_COL = TypeVar("_COL", bound="Collection[Any]") +_FN = TypeVar("_FN", bound="Callable[..., Any]") + + +class _CollectionConverterProtocol(Protocol): + def __call__(self, collection: _COL) -> _COL: + ... + + +class _AdaptedCollectionProtocol(Protocol): + _sa_adapter: CollectionAdapter + _sa_appender: Callable[..., Any] + _sa_remover: Callable[..., Any] + _sa_iterator: Callable[..., Iterable[Any]] + _sa_converter: _CollectionConverterProtocol + + class collection: """Decorators for entity collection classes. @@ -396,8 +435,13 @@ class collection: return decorator -collection_adapter = operator.attrgetter("_sa_adapter") -"""Fetch the :class:`.CollectionAdapter` for a collection.""" +if TYPE_CHECKING: + + def collection_adapter(collection: Collection[Any]) -> CollectionAdapter: + """Fetch the :class:`.CollectionAdapter` for a collection.""" + +else: + collection_adapter = operator.attrgetter("_sa_adapter") class CollectionAdapter: @@ -423,10 +467,33 @@ class CollectionAdapter: "empty", ) - def __init__(self, attr, owner_state, data): + attr: CollectionAttributeImpl + _key: str + + # this is actually a weakref; see note in constructor + _data: Callable[..., _AdaptedCollectionProtocol] + + owner_state: InstanceState[Any] + _converter: _CollectionConverterProtocol + invalidated: bool + empty: bool + + def __init__( + self, + attr: CollectionAttributeImpl, + owner_state: InstanceState[Any], + data: _AdaptedCollectionProtocol, + ): self.attr = attr self._key = attr.key - self._data = weakref.ref(data) + + # this weakref stays referenced throughout the lifespan of + # CollectionAdapter. so while the weakref can return None, this + # is realistically only during garbage collection of this object, so + # we type this as a callable that returns _AdaptedCollectionProtocol + # in all cases. + self._data = weakref.ref(data) # type: ignore + self.owner_state = owner_state data._sa_adapter = self self._converter = data._sa_converter @@ -437,7 +504,7 @@ class CollectionAdapter: util.warn("This collection has been invalidated.") @property - def data(self): + def data(self) -> _AdaptedCollectionProtocol: "The entity collection being adapted." return self._data() @@ -634,7 +701,10 @@ class CollectionAdapter: def __setstate__(self, d): self._key = d["key"] self.owner_state = d["owner_state"] - self._data = weakref.ref(d["data"]) + + # see note in constructor regarding this type: ignore + self._data = weakref.ref(d["data"]) # type: ignore + self._converter = d["data"]._sa_converter d["data"]._sa_adapter = self self.invalidated = d["invalidated"] @@ -682,7 +752,9 @@ def bulk_replace(values, existing_adapter, new_adapter, initiator=None): existing_adapter.fire_remove_event(member, initiator=initiator) -def prepare_instrumentation(factory): +def prepare_instrumentation( + factory: Union[Type[Collection[Any]], _CollectionFactoryType], +) -> _CollectionFactoryType: """Prepare a callable for future use as a collection class factory. Given a collection class factory (either a type or no-arg callable), @@ -693,18 +765,30 @@ def prepare_instrumentation(factory): into the run-time behavior of collection_class=InstrumentedList. """ + + impl_factory: _CollectionFactoryType + # Convert a builtin to 'Instrumented*' if factory in __canned_instrumentation: - factory = __canned_instrumentation[factory] + impl_factory = __canned_instrumentation[factory] + else: + impl_factory = cast(_CollectionFactoryType, factory) + + cls: Union[_CollectionFactoryType, Type[Collection[Any]]] # Create a specimen - cls = type(factory()) + cls = type(impl_factory()) # Did factory callable return a builtin? if cls in __canned_instrumentation: - # Wrap it so that it returns our 'Instrumented*' - factory = __converting_factory(cls, factory) - cls = factory() + + # if so, just convert. + # in previous major releases, this codepath wasn't working and was + # not covered by tests. prior to that it supplied a "wrapper" + # function that would return the class, though the rationale for this + # case is not known + impl_factory = __canned_instrumentation[cls] + cls = type(impl_factory()) # Instrument the class if needed. if __instrumentation_mutex.acquire(): @@ -714,26 +798,7 @@ def prepare_instrumentation(factory): finally: __instrumentation_mutex.release() - return factory - - -def __converting_factory(specimen_cls, original_factory): - """Return a wrapper that converts a "canned" collection like - set, dict, list into the Instrumented* version. - - """ - - instrumented_cls = __canned_instrumentation[specimen_cls] - - def wrapper(): - collection = original_factory() - return instrumented_cls(collection) - - # often flawed but better than nothing - wrapper.__name__ = "%sWrapper" % original_factory.__name__ - wrapper.__doc__ = original_factory.__doc__ - - return wrapper + return impl_factory def _instrument_class(cls): @@ -763,8 +828,8 @@ def _locate_roles_and_methods(cls): """ - roles = {} - methods = {} + roles: Dict[str, str] = {} + methods: Dict[str, Tuple[Optional[str], Optional[int], Optional[str]]] = {} for supercls in cls.__mro__: for name, method in vars(supercls).items(): @@ -784,7 +849,9 @@ def _locate_roles_and_methods(cls): # transfer instrumentation requests from decorated function # to the combined queue - before, after = None, None + before: Optional[Tuple[str, int]] = None + after: Optional[str] = None + if hasattr(method, "_sa_instrument_before"): op, argument = method._sa_instrument_before assert op in ("fire_append_event", "fire_remove_event") @@ -809,6 +876,7 @@ def _setup_canned_roles(cls, roles, methods): """ collection_type = util.duck_type_collection(cls) if collection_type in __interfaces: + assert collection_type is not None canned_roles, decorators = __interfaces[collection_type] for role, name in canned_roles.items(): roles.setdefault(role, name) @@ -934,9 +1002,9 @@ def _instrument_membership_mutator(method, before, argument, after): getattr(executor, after)(res, initiator) return res - wrapper._sa_instrumented = True + wrapper._sa_instrumented = True # type: ignore[attr-defined] if hasattr(method, "_sa_instrument_role"): - wrapper._sa_instrument_role = method._sa_instrument_role + wrapper._sa_instrument_role = method._sa_instrument_role # type: ignore[attr-defined] # noqa: E501 wrapper.__name__ = method.__name__ wrapper.__doc__ = method.__doc__ return wrapper @@ -990,7 +1058,7 @@ def __before_pop(collection, _sa_initiator=None): executor.fire_pre_remove_event(_sa_initiator) -def _list_decorators(): +def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: """Tailored instrumentation wrappers for any list-like class.""" def _tidy(fn): @@ -1131,7 +1199,7 @@ def _list_decorators(): return l -def _dict_decorators(): +def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: """Tailored instrumentation wrappers for any dict-like mapping class.""" def _tidy(fn): @@ -1255,7 +1323,7 @@ def _set_binops_check_loose(self: Any, obj: Any) -> bool: ) -def _set_decorators(): +def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: """Tailored instrumentation wrappers for any set-like class.""" def _tidy(fn): @@ -1420,36 +1488,52 @@ def _set_decorators(): return l -class InstrumentedList(list): +class InstrumentedList(List[_T]): """An instrumented version of the built-in list.""" -class InstrumentedSet(set): +class InstrumentedSet(Set[_T]): """An instrumented version of the built-in set.""" -class InstrumentedDict(dict): +class InstrumentedDict(Dict[_KT, _VT]): """An instrumented version of the built-in dict.""" -__canned_instrumentation = { - list: InstrumentedList, - set: InstrumentedSet, - dict: InstrumentedDict, -} - -__interfaces = { - list: ( - {"appender": "append", "remover": "remove", "iterator": "__iter__"}, - _list_decorators(), - ), - set: ( - {"appender": "add", "remover": "remove", "iterator": "__iter__"}, - _set_decorators(), - ), - # decorators are required for dicts and object collections. - dict: ({"iterator": "values"}, _dict_decorators()), -} +__canned_instrumentation: util.immutabledict[ + Any, _CollectionFactoryType +] = util.immutabledict( + { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } +) + +__interfaces: util.immutabledict[ + Any, + Tuple[ + Dict[str, str], + Dict[str, Callable[..., Any]], + ], +] = util.immutabledict( + { + list: ( + { + "appender": "append", + "remover": "remove", + "iterator": "__iter__", + }, + _list_decorators(), + ), + set: ( + {"appender": "add", "remover": "remove", "iterator": "__iter__"}, + _set_decorators(), + ), + # decorators are required for dicts and object collections. + dict: ({"iterator": "values"}, _dict_decorators()), + } +) def __go(lcls): diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 63a37d0dae..1b4f573b50 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -64,7 +64,9 @@ class DynaLoader(strategies.AbstractRelationshipLoader): ) -class DynamicAttributeImpl(attributes.AttributeImpl): +class DynamicAttributeImpl( + attributes.HasCollectionAdapter, attributes.AttributeImpl +): uses_objects = True default_accepts_scalar_loader = False supports_population = False @@ -120,11 +122,11 @@ class DynamicAttributeImpl(attributes.AttributeImpl): @util.memoized_property def _append_token(self): - return attributes.Event(self, attributes.OP_APPEND) + return attributes.AttributeEventToken(self, attributes.OP_APPEND) @util.memoized_property def _remove_token(self): - return attributes.Event(self, attributes.OP_REMOVE) + return attributes.AttributeEventToken(self, attributes.OP_REMOVE) def fire_append_event( self, state, dict_, value, initiator, collection_history=None diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 356958562f..85b85215ea 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +# mypy: allow-untyped-defs, allow-untyped-calls """Defines SQLAlchemy's system of class instrumentation. @@ -35,14 +35,19 @@ from __future__ import annotations from typing import Any from typing import Callable +from typing import cast +from typing import Collection from typing import Dict from typing import Generic +from typing import Iterable +from typing import List from typing import Optional from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union import weakref from . import base @@ -51,15 +56,21 @@ from . import exc from . import interfaces from . import state from ._typing import _O +from .attributes import _is_collection_attribute_impl from .. import util from ..event import EventTarget from ..util import HasMemoized +from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: from ._typing import _RegistryType + from .attributes import AttributeImpl from .attributes import InstrumentedAttribute + from .collections import _AdaptedCollectionProtocol + from .collections import _CollectionFactoryType from .decl_base import _MapperConfig + from .events import InstanceEvents from .mapper import Mapper from .state import InstanceState from ..event import dispatcher @@ -74,7 +85,7 @@ class _ExpiredAttributeLoaderProto(Protocol): state: state.InstanceState[Any], toload: Set[str], passive: base.PassiveFlag, - ): + ) -> None: ... @@ -91,7 +102,7 @@ class ClassManager( ): """Tracks state information at the class level.""" - dispatch: dispatcher[ClassManager] + dispatch: dispatcher[ClassManager[_O]] MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR STATE_ATTR = base.DEFAULT_STATE_ATTR @@ -108,8 +119,9 @@ class ClassManager( declarative_scan: Optional[weakref.ref[_MapperConfig]] = None registry: Optional[_RegistryType] = None - @property - @util.deprecated( + _bases: List[ClassManager[Any]] + + @util.deprecated_property( "1.4", message="The ClassManager.deferred_scalar_loader attribute is now " "named expired_attribute_loader", @@ -117,7 +129,7 @@ class ClassManager( def deferred_scalar_loader(self): return self.expired_attribute_loader - @deferred_scalar_loader.setter + @deferred_scalar_loader.setter # type: ignore[no-redef] @util.deprecated( "1.4", message="The ClassManager.deferred_scalar_loader attribute is now " @@ -138,18 +150,23 @@ class ClassManager( self._bases = [ mgr - for mgr in [ - opt_manager_of_class(base) - for base in self.class_.__bases__ - if isinstance(base, type) - ] + for mgr in cast( + "List[Optional[ClassManager[Any]]]", + [ + opt_manager_of_class(base) + for base in self.class_.__bases__ + if isinstance(base, type) + ], + ) if mgr is not None ] for base_ in self._bases: self.update(base_) - self.dispatch._events._new_classmanager_instance(class_, self) + cast( + "InstanceEvents", self.dispatch._events + )._new_classmanager_instance(class_, self) for basecls in class_.__mro__: mgr = opt_manager_of_class(basecls) @@ -263,7 +280,7 @@ class ClassManager( """ - found = {} + found: Dict[str, Any] = {} # constraints: # 1. yield keys in cls.__dict__ order @@ -303,7 +320,7 @@ class ClassManager( return key in self and self[key].impl is not None - def _subclass_manager(self, cls): + def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]: """Create a new ClassManager for a subclass of this ClassManager's class. @@ -321,7 +338,7 @@ class ClassManager( self.install_member("__init__", self.new_init) @util.memoized_property - def _state_constructor(self): + def _state_constructor(self) -> Type[state.InstanceState[_O]]: self.dispatch.first_init(self, self.class_) return state.InstanceState @@ -393,13 +410,15 @@ class ClassManager( if manager: manager.uninstrument_attribute(key, True) - def unregister(self): + def unregister(self) -> None: """remove all instrumentation established by this ClassManager.""" for key in list(self.originals): self.uninstall_member(key) - self.mapper = self.dispatch = self.new_init = None + self.mapper = None # type: ignore + self.dispatch = None # type: ignore + self.new_init = None self.info.clear() for key in list(self): @@ -409,7 +428,9 @@ class ClassManager( if self.MANAGER_ATTR in self.class_.__dict__: delattr(self.class_, self.MANAGER_ATTR) - def install_descriptor(self, key, inst): + def install_descriptor( + self, key: str, inst: InstrumentedAttribute[Any] + ) -> None: if key in (self.STATE_ATTR, self.MANAGER_ATTR): raise KeyError( "%r: requested attribute name conflicts with " @@ -417,10 +438,10 @@ class ClassManager( ) setattr(self.class_, key, inst) - def uninstall_descriptor(self, key): + def uninstall_descriptor(self, key: str) -> None: delattr(self.class_, key) - def install_member(self, key, implementation): + def install_member(self, key: str, implementation: Any) -> None: if key in (self.STATE_ATTR, self.MANAGER_ATTR): raise KeyError( "%r: requested attribute name conflicts with " @@ -429,34 +450,41 @@ class ClassManager( self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR)) setattr(self.class_, key, implementation) - def uninstall_member(self, key): + def uninstall_member(self, key: str) -> None: original = self.originals.pop(key, None) if original is not DEL_ATTR: setattr(self.class_, key, original) else: delattr(self.class_, key) - def instrument_collection_class(self, key, collection_class): + def instrument_collection_class( + self, key: str, collection_class: Type[Collection[Any]] + ) -> _CollectionFactoryType: return collections.prepare_instrumentation(collection_class) - def initialize_collection(self, key, state, factory): + def initialize_collection( + self, + key: str, + state: InstanceState[_O], + factory: _CollectionFactoryType, + ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]: user_data = factory() - adapter = collections.CollectionAdapter( - self.get_impl(key), state, user_data - ) + impl = self.get_impl(key) + assert _is_collection_attribute_impl(impl) + adapter = collections.CollectionAdapter(impl, state, user_data) return adapter, user_data - def is_instrumented(self, key, search=False): + def is_instrumented(self, key: str, search: bool = False) -> bool: if search: return key in self else: return key in self.local_attrs - def get_impl(self, key): + def get_impl(self, key: str) -> AttributeImpl: return self[key].impl @property - def attributes(self): + def attributes(self) -> Iterable[Any]: return iter(self.values()) # InstanceState management @@ -466,22 +494,26 @@ class ClassManager( if state is None: state = self._state_constructor(instance, self) self._state_setter(instance, state) - return instance + return instance # type: ignore[no-any-return] - def setup_instance(self, instance, state=None): + def setup_instance( + self, instance: _O, state: Optional[InstanceState[_O]] = None + ) -> None: if state is None: state = self._state_constructor(instance, self) self._state_setter(instance, state) - def teardown_instance(self, instance): + def teardown_instance(self, instance: _O) -> None: delattr(instance, self.STATE_ATTR) def _serialize( - self, state: state.InstanceState, state_dict: Dict[str, Any] + self, state: InstanceState[_O], state_dict: Dict[str, Any] ) -> _SerializeManager: return _SerializeManager(state, state_dict) - def _new_state_if_none(self, instance): + def _new_state_if_none( + self, instance: _O + ) -> Union[Literal[False], InstanceState[_O]]: """Install a default InstanceState if none is present. A private convenience method used by the __init__ decorator. @@ -503,20 +535,20 @@ class ClassManager( self._state_setter(instance, state) return state - def has_state(self, instance): + def has_state(self, instance: _O) -> bool: return hasattr(instance, self.STATE_ATTR) - def has_parent(self, state, key, optimistic=False): + def has_parent( + self, state: InstanceState[_O], key: str, optimistic: bool = False + ) -> bool: """TODO""" return self.get_impl(key).hasparent(state, optimistic=optimistic) - def __bool__(self): + def __bool__(self) -> bool: """All ClassManagers are non-zero regardless of attribute state.""" return True - __nonzero__ = __bool__ - - def __repr__(self): + def __repr__(self) -> str: return "<%s of %r at %x>" % ( self.__class__.__name__, self.class_, @@ -558,9 +590,11 @@ class _SerializeManager: manager.dispatch.unpickle(state, state_dict) -class InstrumentationFactory: +class InstrumentationFactory(EventTarget): """Factory for new ClassManager instances.""" + dispatch: dispatcher[InstrumentationFactory] + def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]: assert class_ is not None assert opt_manager_of_class(class_) is None @@ -589,11 +623,10 @@ class InstrumentationFactory: def _check_conflicts( self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]] - ): + ) -> None: """Overridden by a subclass to test for conflicting factories.""" - return - def unregister(self, class_): + def unregister(self, class_: Type[_O]) -> None: manager = manager_of_class(class_) manager.unregister() self.dispatch.class_uninstrument(class_) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 3e21b01023..3d093d367c 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -60,6 +60,7 @@ from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey from ..sql.schema import Column from ..sql.type_api import TypeEngine +from ..util.typing import DescriptorReference from ..util.typing import TypedDict if typing.TYPE_CHECKING: @@ -68,7 +69,6 @@ if typing.TYPE_CHECKING: from ._typing import _InstanceDict from ._typing import _InternalEntityType from ._typing import _ORMAdapterProto - from ._typing import _ORMColumnExprArgument from .attributes import InstrumentedAttribute from .context import _MapperEntity from .context import ORMCompileState @@ -89,7 +89,6 @@ if typing.TYPE_CHECKING: from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _DMLColumnArgument from ..sql._typing import _InfoType - from ..sql._typing import _PropagateAttrsType from ..sql.operators import OperatorType from ..sql.util import ColumnAdapter from ..sql.visitors import _TraverseInternalsType @@ -171,12 +170,18 @@ class _MapsColumns(_MappedAttribute[_T]): raise NotImplementedError() +# NOTE: MapperProperty needs to extend _MappedAttribute so that declarative +# typing works, i.e. "Mapped[A] = relationship()". This introduces an +# inconvenience which is that all the MapperProperty objects are treated +# as descriptors by typing tools, which are misled by this as assignment / +# access to a descriptor attribute wants to move through __get__. +# Therefore, references to MapperProperty as an instance variable, such +# as in PropComparator, may have some special typing workarounds such as the +# use of sqlalchemy.util.typing.DescriptorReference to avoid mis-interpretation +# by typing tools @inspection._self_inspects class MapperProperty( - HasCacheKey, - _MappedAttribute[_T], - InspectionAttrInfo, - util.MemoizedSlots, + HasCacheKey, _MappedAttribute[_T], InspectionAttrInfo, util.MemoizedSlots ): """Represent a particular class attribute mapped by :class:`_orm.Mapper`. @@ -522,6 +527,7 @@ class PropComparator(SQLORMOperations[_T]): _parententity: _InternalEntityType[Any] _adapt_to_entity: Optional[AliasedInsp[Any]] + prop: DescriptorReference[MapperProperty[_T]] def __init__( self, @@ -533,11 +539,20 @@ class PropComparator(SQLORMOperations[_T]): self._parententity = adapt_to_entity or parentmapper self._adapt_to_entity = adapt_to_entity - @util.ro_non_memoized_property + @util.non_memoized_property def property(self) -> Optional[MapperProperty[_T]]: + """Return the :class:`.MapperProperty` associated with this + :class:`.PropComparator`. + + + Return values here will commonly be instances of + :class:`.ColumnProperty` or :class:`.Relationship`. + + + """ return self.prop - def __clause_element__(self) -> _ORMColumnExprArgument[_T]: + def __clause_element__(self) -> roles.ColumnsClauseRole: raise NotImplementedError("%r" % self) def _bulk_update_tuples( @@ -567,18 +582,6 @@ class PropComparator(SQLORMOperations[_T]): compatible with QueryableAttribute.""" return self._parententity.mapper - @util.memoized_property - def _propagate_attrs(self) -> _PropagateAttrsType: - # this suits the case in coercions where we don't actually - # call ``__clause_element__()`` but still need to get - # resolved._propagate_attrs. See #6558. - return util.immutabledict( - { - "compile_state_plugin": "orm", - "plugin_subject": self._parentmapper, - } - ) - def _criterion_exists( self, criterion: Optional[_ColumnExpressionArgument[bool]] = None, @@ -657,7 +660,7 @@ class PropComparator(SQLORMOperations[_T]): def and_( self, *criteria: _ColumnExpressionArgument[bool] - ) -> ColumnElement[bool]: + ) -> PropComparator[bool]: """Add additional criteria to the ON clause that's represented by this relationship attribute. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 66021c9c20..514ad7023e 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -71,6 +71,7 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _EntityType + from ._typing import _InternalEntityType from .mapper import Mapper from .util import AliasedClass from .util import AliasedInsp @@ -348,7 +349,7 @@ class Relationship( doc=self.doc, ) - class Comparator(PropComparator[_PT]): + class Comparator(util.MemoizedSlots, PropComparator[_PT]): """Produce boolean, comparison, and other operators for :class:`.Relationship` attributes. @@ -369,8 +370,13 @@ class Relationship( """ - _of_type = None - _extra_criteria = () + __slots__ = ( + "entity", + "mapper", + "property", + "_of_type", + "_extra_criteria", + ) def __init__( self, @@ -389,6 +395,8 @@ class Relationship( self._adapt_to_entity = adapt_to_entity if of_type: self._of_type = of_type + else: + self._of_type = None self._extra_criteria = extra_criteria def adapt_to_entity(self, adapt_to_entity): @@ -399,40 +407,35 @@ class Relationship( of_type=self._of_type, ) - @util.memoized_property - def entity(self): - """The target entity referred to by this - :class:`.Relationship.Comparator`. + entity: _InternalEntityType + """The target entity referred to by this + :class:`.Relationship.Comparator`. - This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp` - object. + This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp` + object. - This is the "target" or "remote" side of the - :func:`_orm.relationship`. + This is the "target" or "remote" side of the + :func:`_orm.relationship`. - """ - # this is a relatively recent change made for - # 1.4.27 as part of #7244. - # TODO: shouldn't _of_type be inspected up front when received? - if self._of_type is not None: - return inspect(self._of_type) - else: - return self.property.entity + """ - @util.memoized_property - def mapper(self): - """The target :class:`_orm.Mapper` referred to by this - :class:`.Relationship.Comparator`. + mapper: Mapper[Any] + """The target :class:`_orm.Mapper` referred to by this + :class:`.Relationship.Comparator`. - This is the "target" or "remote" side of the - :func:`_orm.relationship`. + This is the "target" or "remote" side of the + :func:`_orm.relationship`. - """ - return self.property.mapper + """ - @util.memoized_property - def _parententity(self): - return self.property.parent + def _memoized_attr_entity(self) -> _InternalEntityType: + if self._of_type: + return inspect(self._of_type) + else: + return self.prop.entity + + def _memoized_attr_mapper(self) -> Mapper[Any]: + return self.entity.mapper def _source_selectable(self): if self._adapt_to_entity: @@ -481,7 +484,9 @@ class Relationship( extra_criteria=self._extra_criteria, ) - def and_(self, *other): + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> interfaces.PropComparator[bool]: """Add AND criteria. See :meth:`.PropComparator.and_` for an example. @@ -489,12 +494,17 @@ class Relationship( .. versionadded:: 1.4 """ + exprs = tuple( + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(criteria) + ) + return Relationship.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, of_type=self._of_type, - extra_criteria=self._extra_criteria + other, + extra_criteria=self._extra_criteria + exprs, ) def in_(self, other): @@ -924,8 +934,7 @@ class Relationship( else: return _orm_annotate(self.__negated_contains_or_equals(other)) - @util.memoized_property - def property(self): + def _memoized_attr_property(self): self.prop.parent._check_configure() return self.prop diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index ab32a3981a..49ee701b44 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -23,6 +23,7 @@ from typing import Optional from typing import Set from typing import Tuple from typing import TYPE_CHECKING +from typing import Union import weakref from . import base @@ -43,6 +44,7 @@ from .path_registry import PathRegistry from .. import exc as sa_exc from .. import inspection from .. import util +from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: @@ -53,6 +55,7 @@ if TYPE_CHECKING: from .attributes import History from .base import LoaderCallableStatus from .base import PassiveFlag + from .collections import _AdaptedCollectionProtocol from .identity import IdentityMap from .instrumentation import ClassManager from .interfaces import ORMOption @@ -421,7 +424,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): return self.key @util.memoized_property - def parents(self) -> Dict[int, InstanceState[Any]]: + def parents(self) -> Dict[int, Union[Literal[False], InstanceState[Any]]]: return {} @util.memoized_property @@ -429,7 +432,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): return {} @util.memoized_property - def _empty_collections(self) -> Dict[Any, Any]: + def _empty_collections(self) -> Dict[str, _AdaptedCollectionProtocol]: return {} @util.memoized_property @@ -844,7 +847,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): def _modified_event( self, dict_: _InstanceDict, - attr: AttributeImpl, + attr: Optional[AttributeImpl], previous: Any, collection: bool = False, is_userland: bool = False, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 7e8a6b4c6e..b095e3f7ae 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -24,6 +24,7 @@ from typing import Optional from typing import Sequence from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union import weakref @@ -531,6 +532,8 @@ class AliasedClass( """ + __name__: str + def __init__( self, mapped_class_or_ac: _EntityType[_O], @@ -1529,7 +1532,7 @@ class _ORMJoin(expression.Join): full: bool = False, _left_memo: Optional[Any] = None, _right_memo: Optional[Any] = None, - _extra_criteria: Sequence[ColumnElement[bool]] = (), + _extra_criteria: Tuple[ColumnElement[bool], ...] = (), ): left_info = cast( "Union[FromClause, _InternalEntityType[Any]]", @@ -1547,6 +1550,8 @@ class _ORMJoin(expression.Join): self._right_memo = _right_memo if isinstance(onclause, attributes.QueryableAttribute): + if TYPE_CHECKING: + assert isinstance(onclause.comparator, Relationship.Comparator) on_selectable = onclause.comparator._source_selectable() prop = onclause.property _extra_criteria += onclause._extra_criteria @@ -1728,12 +1733,15 @@ def with_parent( elif isinstance(prop, attributes.QueryableAttribute): if prop._of_type: from_entity = prop._of_type - if not prop_is_relationship(prop.property): + mapper_property = prop.property + if mapper_property is None or not prop_is_relationship( + mapper_property + ): raise sa_exc.ArgumentError( f"Expected relationship property for with_parent(), " - f"got {prop.property}" + f"got {mapper_property}" ) - prop_t = prop.property + prop_t = mapper_property else: prop_t = prop diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index fb959654fe..248b48a250 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -133,6 +133,8 @@ class Immutable: """ + __slots__ = () + _is_immutable = True def unique_params(self, *optionaldict, **kwargs): @@ -145,7 +147,7 @@ class Immutable: return self def _copy_internals( - self, omit_attrs: Iterable[str] = (), **kw: Any + self, *, omit_attrs: Iterable[str] = (), **kw: Any ) -> None: pass diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 15fbc2afb9..c16fbdae13 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -36,7 +36,6 @@ if typing.TYPE_CHECKING: from .elements import BindParameter from .elements import ClauseElement from .visitors import _TraverseInternalsType - from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreSingleExecuteParams @@ -393,6 +392,13 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) +class SlotsMemoizedHasCacheKey(HasCacheKey, util.MemoizedSlots): + __slots__ = () + + def _memoized_method__generate_cache_key(self) -> Optional[CacheKey]: + return HasCacheKey._generate_cache_key(self) + + class CacheKey(NamedTuple): """The key used to identify a SQL statement construct in the SQL compilation cache. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 2e0112f08f..2655adbdc9 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1265,7 +1265,16 @@ class ColumnAdapter(ClauseAdapter): if self.adapt_required and c is col: return None - c._allow_label_resolve = self.allow_label_resolve + # allow_label_resolve is consumed by one case for joined eager loading + # as part of its logic to prevent its own columns from being affected + # by .order_by(). Before full typing were applied to the ORM, this + # logic would set this attribute on the incoming object (which is + # typically a column, but we have a test for it being a non-column + # object) if no column were found. While this seemed to + # have no negative effects, this adjustment should only occur on the + # new column which is assumed to be local to an adapted selectable. + if c is not col: + c._allow_label_resolve = self.allow_label_resolve return c diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index e9b0c93f28..7150dedcf8 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -410,11 +410,11 @@ class UniqueAppender(Generic[_T]): return iter(self.data) -def coerce_generator_arg(arg): +def coerce_generator_arg(arg: Any) -> List[Any]: if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): return list(arg[0]) else: - return arg + return cast("List[Any]", arg) def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]: diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 10110dbbee..24c66bfa4e 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1272,17 +1272,20 @@ class MemoizedSlots: def _fallback_getattr(self, key): raise AttributeError(key) - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: if key.startswith("_memoized_attr_") or key.startswith( "_memoized_method_" ): raise AttributeError(key) - elif hasattr(self, "_memoized_attr_%s" % key): - value = getattr(self, "_memoized_attr_%s" % key)() + # to avoid recursion errors when interacting with other __getattr__ + # schemes that refer to this one, when testing for memoized method + # look at __class__ only rather than going into __getattr__ again. + elif hasattr(self.__class__, f"_memoized_attr_{key}"): + value = getattr(self, f"_memoized_attr_{key}")() setattr(self, key, value) return value - elif hasattr(self, "_memoized_method_%s" % key): - fn = getattr(self, "_memoized_method_%s" % key) + elif hasattr(self.__class__, f"_memoized_method_{key}"): + fn = getattr(self, f"_memoized_method_{key}") def oneshot(*args, **kw): result = fn(*args, **kw) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 4929ba1a65..a95f5ab93c 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -9,6 +9,7 @@ from typing import Callable from typing import cast from typing import Dict from typing import ForwardRef +from typing import Generic from typing import Iterable from typing import Optional from typing import Tuple @@ -213,3 +214,41 @@ def _get_type_name(type_: Type[Any]) -> str: typ_name = getattr(type_, "_name", None) return typ_name # type: ignore + + +class DescriptorProto(Protocol): + def __get__(self, instance: object, owner: Any) -> Any: + ... + + def __set__(self, instance: Any, value: Any) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + +_DESC = TypeVar("_DESC", bound=DescriptorProto) + + +class DescriptorReference(Generic[_DESC]): + """a descriptor that refers to a descriptor. + + used for cases where we need to have an instance variable referring to an + object that is itself a descriptor, which typically confuses typing tools + as they don't know when they should use ``__get__`` or not when referring + to the descriptor assignment as an instance variable. See + sqlalchemy.orm.interfaces.PropComparator.prop + + """ + + def __get__(self, instance: object, owner: Any) -> _DESC: + ... + + def __set__(self, instance: Any, value: _DESC) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + +# $def ro_descriptor_reference(fn: Callable[]) diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index 7830fcee68..c222a0a933 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -26,6 +26,13 @@ from sqlalchemy.testing import ne_ from sqlalchemy.testing.util import decorator +def _register_attribute(class_, key, **kw): + kw.setdefault("comparator", object()) + kw.setdefault("parententity", object()) + + attributes.register_attribute(class_, key, **kw) + + @decorator def modifies_instrumentation_finders(fn, *args, **kw): pristine = instrumentation.instrumentation_finders[:] @@ -205,13 +212,9 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): pass register_class(User) - attributes.register_attribute( - User, "user_id", uselist=False, useobject=False - ) - attributes.register_attribute( - User, "user_name", uselist=False, useobject=False - ) - attributes.register_attribute( + _register_attribute(User, "user_id", uselist=False, useobject=False) + _register_attribute(User, "user_name", uselist=False, useobject=False) + _register_attribute( User, "email_address", uselist=False, useobject=False ) @@ -238,13 +241,13 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): pass register_class(User) - attributes.register_attribute( + _register_attribute( User, "user_id", uselist=False, useobject=False ) - attributes.register_attribute( + _register_attribute( User, "user_name", uselist=False, useobject=False ) - attributes.register_attribute( + _register_attribute( User, "email_address", uselist=False, useobject=False ) @@ -284,12 +287,8 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): manager = register_class(Foo) manager.expired_attribute_loader = loader - attributes.register_attribute( - Foo, "a", uselist=False, useobject=False - ) - attributes.register_attribute( - Foo, "b", uselist=False, useobject=False - ) + _register_attribute(Foo, "a", uselist=False, useobject=False) + _register_attribute(Foo, "b", uselist=False, useobject=False) if base is object: assert Foo not in ( @@ -360,13 +359,13 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): def func3(state, passive): return "this is the shared attr" - attributes.register_attribute( + _register_attribute( Foo, "element", uselist=False, callable_=func1, useobject=True ) - attributes.register_attribute( + _register_attribute( Foo, "element2", uselist=False, callable_=func3, useobject=True ) - attributes.register_attribute( + _register_attribute( Bar, "element", uselist=False, callable_=func2, useobject=True ) @@ -388,7 +387,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): register_class(Post) register_class(Blog) - attributes.register_attribute( + _register_attribute( Post, "blog", uselist=False, @@ -396,7 +395,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): trackparent=True, useobject=True, ) - attributes.register_attribute( + _register_attribute( Blog, "posts", uselist=True, @@ -438,15 +437,11 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): register_class(Foo) register_class(Bar) - attributes.register_attribute( - Foo, "name", uselist=False, useobject=False - ) - attributes.register_attribute( + _register_attribute(Foo, "name", uselist=False, useobject=False) + _register_attribute( Foo, "bars", uselist=True, trackparent=True, useobject=True ) - attributes.register_attribute( - Bar, "name", uselist=False, useobject=False - ) + _register_attribute(Bar, "name", uselist=False, useobject=False) f1 = Foo() f1.name = "f1" @@ -517,10 +512,8 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): pass register_class(Foo) - attributes.register_attribute( - Foo, "name", uselist=False, useobject=False - ) - attributes.register_attribute( + _register_attribute(Foo, "name", uselist=False, useobject=False) + _register_attribute( Foo, "bars", uselist=True, trackparent=True, useobject=True ) diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index 3370fa1b55..81cecb08b3 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -147,7 +147,10 @@ class _MutableDictTestBase(_MutableDictTestFixture): canary.mock_calls, [ mock.call( - f1, attributes.Event(Foo.data.impl, attributes.OP_MODIFIED) + f1, + attributes.AttributeEventToken( + Foo.data.impl, attributes.OP_MODIFIED + ), ) ], ) diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 8e6ddf9d4e..e1274a8051 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -36,6 +36,13 @@ def _set_callable(state, dict_, key, callable_): fn(state, dict_, None) +def _register_attribute(class_, key, **kw): + kw.setdefault("comparator", object()) + kw.setdefault("parententity", object()) + + attributes.register_attribute(class_, key, **kw) + + class AttributeImplAPITest(fixtures.MappedTest): def _scalar_obj_fixture(self): class A: @@ -46,7 +53,7 @@ class AttributeImplAPITest(fixtures.MappedTest): instrumentation.register_class(A) instrumentation.register_class(B) - attributes.register_attribute(A, "b", uselist=False, useobject=True) + _register_attribute(A, "b", uselist=False, useobject=True) return A, B def _collection_obj_fixture(self): @@ -58,7 +65,7 @@ class AttributeImplAPITest(fixtures.MappedTest): instrumentation.register_class(A) instrumentation.register_class(B) - attributes.register_attribute(A, "b", uselist=True, useobject=True) + _register_attribute(A, "b", uselist=True, useobject=True) return A, B def test_scalar_obj_remove_invalid(self): @@ -228,13 +235,9 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(User) - attributes.register_attribute( - User, "user_id", uselist=False, useobject=False - ) - attributes.register_attribute( - User, "user_name", uselist=False, useobject=False - ) - attributes.register_attribute( + _register_attribute(User, "user_id", uselist=False, useobject=False) + _register_attribute(User, "user_name", uselist=False, useobject=False) + _register_attribute( User, "email_address", uselist=False, useobject=False ) u = User() @@ -263,28 +266,22 @@ class AttributesTest(fixtures.ORMTest): def test_pickleness(self): instrumentation.register_class(MyTest) instrumentation.register_class(MyTest2) - attributes.register_attribute( - MyTest, "user_id", uselist=False, useobject=False - ) - attributes.register_attribute( + _register_attribute(MyTest, "user_id", uselist=False, useobject=False) + _register_attribute( MyTest, "user_name", uselist=False, useobject=False ) - attributes.register_attribute( + _register_attribute( MyTest, "email_address", uselist=False, useobject=False ) - attributes.register_attribute( - MyTest2, "a", uselist=False, useobject=False - ) - attributes.register_attribute( - MyTest2, "b", uselist=False, useobject=False - ) + _register_attribute(MyTest2, "a", uselist=False, useobject=False) + _register_attribute(MyTest2, "b", uselist=False, useobject=False) # shouldn't be pickling callables at the class level def somecallable(state, passive): return None - attributes.register_attribute( + _register_attribute( MyTest, "mt2", uselist=True, @@ -350,9 +347,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "bars", uselist=True, useobject=True - ) + _register_attribute(Foo, "bars", uselist=True, useobject=True) assert_raises_message( orm_exc.ObjectDereferencedError, @@ -367,9 +362,7 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(User) - attributes.register_attribute( - User, "user_name", uselist=False, useobject=False - ) + _register_attribute(User, "user_name", uselist=False, useobject=False) class Blog: name = User.user_name @@ -388,7 +381,7 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, "b", uselist=False, useobject=False) + _register_attribute(Foo, "b", uselist=False, useobject=False) f1 = Foo() @@ -417,7 +410,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, "b", uselist=False, useobject=True) + _register_attribute(Foo, "b", uselist=False, useobject=True) f1 = Foo() @@ -444,7 +437,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, "b", uselist=True, useobject=True) + _register_attribute(Foo, "b", uselist=True, useobject=True) f1 = Foo() @@ -472,8 +465,8 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) manager = attributes.manager_of_class(Foo) manager.expired_attribute_loader = loader - attributes.register_attribute(Foo, "a", uselist=False, useobject=False) - attributes.register_attribute(Foo, "b", uselist=False, useobject=False) + _register_attribute(Foo, "a", uselist=False, useobject=False) + _register_attribute(Foo, "b", uselist=False, useobject=False) f = Foo() attributes.instance_state(f)._expire( @@ -518,12 +511,8 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(MyTest) manager = attributes.manager_of_class(MyTest) manager.expired_attribute_loader = loader - attributes.register_attribute( - MyTest, "a", uselist=False, useobject=False - ) - attributes.register_attribute( - MyTest, "b", uselist=False, useobject=False - ) + _register_attribute(MyTest, "a", uselist=False, useobject=False) + _register_attribute(MyTest, "b", uselist=False, useobject=False) m = MyTest() attributes.instance_state(m)._expire( @@ -544,19 +533,13 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(User) instrumentation.register_class(Address) - attributes.register_attribute( - User, "user_id", uselist=False, useobject=False - ) - attributes.register_attribute( - User, "user_name", uselist=False, useobject=False - ) - attributes.register_attribute( - User, "addresses", uselist=True, useobject=True - ) - attributes.register_attribute( + _register_attribute(User, "user_id", uselist=False, useobject=False) + _register_attribute(User, "user_name", uselist=False, useobject=False) + _register_attribute(User, "addresses", uselist=True, useobject=True) + _register_attribute( Address, "address_id", uselist=False, useobject=False ) - attributes.register_attribute( + _register_attribute( Address, "email_address", uselist=False, useobject=False ) @@ -613,7 +596,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Blog) # set up instrumented attributes with backrefs - attributes.register_attribute( + _register_attribute( Post, "blog", uselist=False, @@ -621,7 +604,7 @@ class AttributesTest(fixtures.ORMTest): trackparent=True, useobject=True, ) - attributes.register_attribute( + _register_attribute( Blog, "posts", uselist=True, @@ -675,7 +658,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute(Post, "blog", useobject=True) + _register_attribute(Post, "blog", useobject=True) assert_raises_message( AssertionError, "This AttributeImpl is not configured to track parents.", @@ -714,13 +697,13 @@ class AttributesTest(fixtures.ORMTest): def func3(state, passive): return "this is the shared attr" - attributes.register_attribute( + _register_attribute( Foo, "element", uselist=False, callable_=func1, useobject=True ) - attributes.register_attribute( + _register_attribute( Foo, "element2", uselist=False, callable_=func3, useobject=True ) - attributes.register_attribute( + _register_attribute( Bar, "element", uselist=False, callable_=func2, useobject=True ) @@ -766,9 +749,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "element", uselist=False, useobject=True - ) + _register_attribute(Foo, "element", uselist=False, useobject=True) el = Element() x = Bar() x.element = el @@ -804,13 +785,13 @@ class AttributesTest(fixtures.ORMTest): def func2(state, passive): return [bar1, bar2, bar3] - attributes.register_attribute( + _register_attribute( Foo, "col1", uselist=False, callable_=func1, useobject=True ) - attributes.register_attribute( + _register_attribute( Foo, "col2", uselist=True, callable_=func2, useobject=True ) - attributes.register_attribute(Bar, "id", uselist=False, useobject=True) + _register_attribute(Bar, "id", uselist=False, useobject=True) x = Foo() attributes.instance_state(x)._commit_all(attributes.instance_dict(x)) x.col2.append(bar4) @@ -828,10 +809,10 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( + _register_attribute( Foo, "element", uselist=False, trackparent=True, useobject=True ) - attributes.register_attribute( + _register_attribute( Bar, "element", uselist=False, trackparent=True, useobject=True ) f1 = Foo() @@ -878,7 +859,7 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute( + _register_attribute( Foo, "collection", uselist=True, typecallable=set, useobject=True ) assert attributes.manager_of_class(Foo).is_instrumented("collection") @@ -888,7 +869,7 @@ class AttributesTest(fixtures.ORMTest): "collection" ) try: - attributes.register_attribute( + _register_attribute( Foo, "collection", uselist=True, @@ -911,7 +892,7 @@ class AttributesTest(fixtures.ORMTest): def remove(self, item): del self[item.foo] - attributes.register_attribute( + _register_attribute( Foo, "collection", uselist=True, @@ -925,7 +906,7 @@ class AttributesTest(fixtures.ORMTest): pass try: - attributes.register_attribute( + _register_attribute( Foo, "collection", uselist=True, @@ -952,7 +933,7 @@ class AttributesTest(fixtures.ORMTest): def remove(self, item): pass - attributes.register_attribute( + _register_attribute( Foo, "collection", uselist=True, @@ -970,9 +951,9 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, "a", useobject=False) - attributes.register_attribute(Foo, "b", useobject=False) - attributes.register_attribute(Foo, "c", useobject=False) + _register_attribute(Foo, "a", useobject=False) + _register_attribute(Foo, "b", useobject=False) + _register_attribute(Foo, "c", useobject=False) f1 = Foo() state = attributes.instance_state(f1) @@ -1026,7 +1007,7 @@ class GetNoValueTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) if expected is not None: - attributes.register_attribute( + _register_attribute( Foo, "attr", useobject=True, @@ -1034,9 +1015,7 @@ class GetNoValueTest(fixtures.ORMTest): callable_=lazy_callable, ) else: - attributes.register_attribute( - Foo, "attr", useobject=True, uselist=False - ) + _register_attribute(Foo, "attr", useobject=True, uselist=False) f1 = self.f1 = Foo() return ( @@ -1092,9 +1071,7 @@ class UtilTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "coll", uselist=True, useobject=True - ) + _register_attribute(Foo, "coll", uselist=True, useobject=True) f1 = Foo() b1 = Bar() @@ -1123,10 +1100,8 @@ class UtilTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "col_list", uselist=True, useobject=True - ) - attributes.register_attribute( + _register_attribute(Foo, "col_list", uselist=True, useobject=True) + _register_attribute( Foo, "col_set", uselist=True, useobject=True, typecallable=set ) @@ -1146,8 +1121,8 @@ class UtilTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, "a", uselist=False, useobject=False) - attributes.register_attribute(Bar, "b", uselist=False, useobject=False) + _register_attribute(Foo, "a", uselist=False, useobject=False) + _register_attribute(Bar, "b", uselist=False, useobject=False) @event.listens_for(Foo.a, "set") def sync_a(target, value, oldvalue, initiator): @@ -1182,14 +1157,14 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Student) instrumentation.register_class(Course) - attributes.register_attribute( + _register_attribute( Student, "courses", uselist=True, backref="students", useobject=True, ) - attributes.register_attribute( + _register_attribute( Course, "students", uselist=True, backref="courses", useobject=True ) @@ -1217,7 +1192,7 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute( + _register_attribute( Post, "blog", uselist=False, @@ -1225,7 +1200,7 @@ class BackrefTest(fixtures.ORMTest): trackparent=True, useobject=True, ) - attributes.register_attribute( + _register_attribute( Blog, "posts", uselist=True, @@ -1266,11 +1241,11 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Port) instrumentation.register_class(Jack) - attributes.register_attribute( + _register_attribute( Port, "jack", uselist=False, useobject=True, backref="port" ) - attributes.register_attribute( + _register_attribute( Jack, "port", uselist=False, useobject=True, backref="jack" ) @@ -1306,7 +1281,7 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Parent) instrumentation.register_class(Child) instrumentation.register_class(SubChild) - attributes.register_attribute( + _register_attribute( Parent, "child", uselist=False, @@ -1314,7 +1289,7 @@ class BackrefTest(fixtures.ORMTest): parent_token=p_token, useobject=True, ) - attributes.register_attribute( + _register_attribute( Child, "parent", uselist=False, @@ -1322,7 +1297,7 @@ class BackrefTest(fixtures.ORMTest): parent_token=c_token, useobject=True, ) - attributes.register_attribute( + _register_attribute( SubChild, "parent", uselist=False, @@ -1354,7 +1329,7 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Parent) instrumentation.register_class(SubParent) instrumentation.register_class(Child) - attributes.register_attribute( + _register_attribute( Parent, "children", uselist=True, @@ -1362,7 +1337,7 @@ class BackrefTest(fixtures.ORMTest): parent_token=p_token, useobject=True, ) - attributes.register_attribute( + _register_attribute( SubParent, "children", uselist=True, @@ -1370,7 +1345,7 @@ class BackrefTest(fixtures.ORMTest): parent_token=p_token, useobject=True, ) - attributes.register_attribute( + _register_attribute( Child, "parent", uselist=False, @@ -1444,15 +1419,11 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): instrumentation.register_class(A) instrumentation.register_class(B) instrumentation.register_class(C) - attributes.register_attribute(C, "a", backref="c", useobject=True) - attributes.register_attribute(C, "b", backref="c", useobject=True) + _register_attribute(C, "a", backref="c", useobject=True) + _register_attribute(C, "b", backref="c", useobject=True) - attributes.register_attribute( - A, "c", backref="a", useobject=True, uselist=True - ) - attributes.register_attribute( - B, "c", backref="b", useobject=True, uselist=True - ) + _register_attribute(A, "c", backref="a", useobject=True, uselist=True) + _register_attribute(B, "c", backref="b", useobject=True, uselist=True) return A, B, C @@ -1470,15 +1441,11 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): instrumentation.register_class(B) instrumentation.register_class(C) - attributes.register_attribute( - C, "a", backref="c", useobject=True, uselist=True - ) - attributes.register_attribute( - C, "b", backref="c", useobject=True, uselist=True - ) + _register_attribute(C, "a", backref="c", useobject=True, uselist=True) + _register_attribute(C, "b", backref="c", useobject=True, uselist=True) - attributes.register_attribute(A, "c", backref="a", useobject=True) - attributes.register_attribute(B, "c", backref="b", useobject=True) + _register_attribute(A, "c", backref="a", useobject=True) + _register_attribute(B, "c", backref="b", useobject=True) return A, B, C @@ -1492,14 +1459,10 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): instrumentation.register_class(A) instrumentation.register_class(B) - attributes.register_attribute(A, "b", backref="a1", useobject=True) - attributes.register_attribute( - B, "a1", backref="b", useobject=True, uselist=True - ) + _register_attribute(A, "b", backref="a1", useobject=True) + _register_attribute(B, "a1", backref="b", useobject=True, uselist=True) - attributes.register_attribute( - B, "a2", backref="b", useobject=True, uselist=True - ) + _register_attribute(B, "a2", backref="b", useobject=True, uselist=True) return A, B @@ -1542,7 +1505,7 @@ class PendingBackrefTest(fixtures.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute( + _register_attribute( Post, "blog", uselist=False, @@ -1550,7 +1513,7 @@ class PendingBackrefTest(fixtures.ORMTest): trackparent=True, useobject=True, ) - attributes.register_attribute( + _register_attribute( Blog, "posts", uselist=True, @@ -1777,7 +1740,7 @@ class HistoryTest(fixtures.TestBase): pass instrumentation.register_class(Foo) - attributes.register_attribute( + _register_attribute( Foo, "someattr", uselist=uselist, @@ -1797,7 +1760,7 @@ class HistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( + _register_attribute( Foo, "someattr", uselist=uselist, @@ -2617,7 +2580,7 @@ class HistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( + _register_attribute( Foo, "someattr", uselist=True, @@ -2675,12 +2638,8 @@ class HistoryTest(fixtures.TestBase): pass instrumentation.register_class(Foo) - attributes.register_attribute( - Foo, "someattr", uselist=True, useobject=True - ) - attributes.register_attribute( - Foo, "id", uselist=False, useobject=False - ) + _register_attribute(Foo, "someattr", uselist=True, useobject=True) + _register_attribute(Foo, "id", uselist=False, useobject=False) instrumentation.register_class(Bar) hi = Bar(name="hi") there = Bar(name="there") @@ -2868,7 +2827,7 @@ class HistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( + _register_attribute( Foo, "bars", uselist=True, @@ -2876,7 +2835,7 @@ class HistoryTest(fixtures.TestBase): trackparent=True, useobject=True, ) - attributes.register_attribute( + _register_attribute( Bar, "foo", uselist=False, @@ -2945,7 +2904,7 @@ class LazyloadHistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( + _register_attribute( Foo, "bars", uselist=True, @@ -2954,7 +2913,7 @@ class LazyloadHistoryTest(fixtures.TestBase): callable_=lazyload, useobject=True, ) - attributes.register_attribute( + _register_attribute( Bar, "foo", uselist=False, @@ -3004,7 +2963,7 @@ class LazyloadHistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( + _register_attribute( Foo, "bars", uselist=True, @@ -3063,7 +3022,7 @@ class LazyloadHistoryTest(fixtures.TestBase): return lazy_load instrumentation.register_class(Foo) - attributes.register_attribute( + _register_attribute( Foo, "bar", uselist=False, callable_=lazyload, useobject=False ) lazy_load = "hi" @@ -3119,7 +3078,7 @@ class LazyloadHistoryTest(fixtures.TestBase): return lazy_load instrumentation.register_class(Foo) - attributes.register_attribute( + _register_attribute( Foo, "bar", uselist=False, @@ -3184,7 +3143,7 @@ class LazyloadHistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( + _register_attribute( Foo, "bar", uselist=False, @@ -3254,18 +3213,12 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "data", uselist=False, useobject=False - ) - attributes.register_attribute( - Foo, "barlist", uselist=True, useobject=True - ) - attributes.register_attribute( + _register_attribute(Foo, "data", uselist=False, useobject=False) + _register_attribute(Foo, "barlist", uselist=True, useobject=True) + _register_attribute( Foo, "barset", typecallable=set, uselist=True, useobject=True ) - attributes.register_attribute( - Bar, "data", uselist=False, useobject=False - ) + _register_attribute(Bar, "data", uselist=False, useobject=False) event.listen(Foo.data, "set", on_set, retval=True) event.listen(Foo.barlist, "append", append, retval=True) event.listen(Foo.barset, "append", append, retval=True) @@ -3291,12 +3244,8 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "data", uselist=False, useobject=False - ) - attributes.register_attribute( - Foo, "barlist", uselist=True, useobject=True - ) + _register_attribute(Foo, "data", uselist=False, useobject=False) + _register_attribute(Foo, "barlist", uselist=True, useobject=True) event.listen(Foo.data, "set", canary.set, named=True) event.listen(Foo.barlist, "append", canary.append, named=True) @@ -3312,21 +3261,21 @@ class ListenerTest(fixtures.ORMTest): [ call.set( oldvalue=attributes.NO_VALUE, - initiator=attributes.Event( + initiator=attributes.AttributeEventToken( Foo.data.impl, attributes.OP_REPLACE ), target=f1, value=5, ), call.append( - initiator=attributes.Event( + initiator=attributes.AttributeEventToken( Foo.barlist.impl, attributes.OP_APPEND ), target=f1, value=b1, ), call.remove( - initiator=attributes.Event( + initiator=attributes.AttributeEventToken( Foo.barlist.impl, attributes.OP_REMOVE ), target=f1, @@ -3344,9 +3293,7 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "barlist", uselist=True, useobject=True - ) + _register_attribute(Foo, "barlist", uselist=True, useobject=True) canary = Mock() event.listen(Foo.barlist, "init_collection", canary.init) @@ -3389,9 +3336,7 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "barlist", uselist=True, useobject=True - ) + _register_attribute(Foo, "barlist", uselist=True, useobject=True) canary = [] def append(state, child, initiator): @@ -3428,7 +3373,7 @@ class ListenerTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, "bar") + _register_attribute(Foo, "bar") event.listen(Foo.bar, "modified", canary) f1 = Foo() @@ -3436,7 +3381,14 @@ class ListenerTest(fixtures.ORMTest): attributes.flag_modified(f1, "bar") eq_( canary.mock_calls, - [call(f1, attributes.Event(Foo.bar.impl, attributes.OP_MODIFIED))], + [ + call( + f1, + attributes.AttributeEventToken( + Foo.bar.impl, attributes.OP_MODIFIED + ), + ) + ], ) def test_none_init_scalar(self): @@ -3446,7 +3398,7 @@ class ListenerTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, "bar") + _register_attribute(Foo, "bar") event.listen(Foo.bar, "set", canary) @@ -3462,7 +3414,7 @@ class ListenerTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, "bar", useobject=True) + _register_attribute(Foo, "bar", useobject=True) event.listen(Foo.bar, "set", canary) @@ -3482,7 +3434,7 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, "bar", useobject=True, uselist=True) + _register_attribute(Foo, "bar", useobject=True, uselist=True) event.listen(Foo.bar, "set", canary) @@ -3622,17 +3574,17 @@ class EventPropagateTest(fixtures.TestBase): instrumentation.register_class(classes[3]) def attr_a(): - attributes.register_attribute( + _register_attribute( classes[0], "attrib", uselist=False, useobject=useobject ) def attr_b(): - attributes.register_attribute( + _register_attribute( classes[1], "attrib", uselist=False, useobject=useobject ) def attr_c(): - attributes.register_attribute( + _register_attribute( classes[2], "attrib", uselist=False, useobject=useobject ) @@ -3702,7 +3654,7 @@ class CollectionInitTest(fixtures.TestBase): self.B = B instrumentation.register_class(A) instrumentation.register_class(B) - attributes.register_attribute(A, "bs", uselist=True, useobject=True) + _register_attribute(A, "bs", uselist=True, useobject=True) def test_bulk_replace_resets_empty(self): A = self.A @@ -3761,7 +3713,7 @@ class TestUnlink(fixtures.TestBase): self.B = B instrumentation.register_class(A) instrumentation.register_class(B) - attributes.register_attribute(A, "bs", uselist=True, useobject=True) + _register_attribute(A, "bs", uselist=True, useobject=True) def test_expired(self): A, B = self.A, self.B diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 806d98a69b..f7e8ac9f68 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -28,6 +28,13 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +def _register_attribute(class_, key, **kw): + kw.setdefault("comparator", object()) + kw.setdefault("parententity", object()) + + return attributes.register_attribute(class_, key, **kw) + + class Canary: def __init__(self): self.data = set() @@ -123,7 +130,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -175,7 +182,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute( + _register_attribute( Foo, "attr", uselist=True, @@ -212,7 +219,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -457,7 +464,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -664,7 +671,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -706,7 +713,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -974,7 +981,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -1115,7 +1122,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -1176,7 +1183,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -1304,7 +1311,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -1534,7 +1541,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, @@ -1695,7 +1702,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - d = attributes.register_attribute( + d = _register_attribute( Foo, "attr", uselist=True, typecallable=Custom, useobject=True ) canary.listen(d) @@ -1769,9 +1776,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): canary = Canary() creator = self.entity_maker instrumentation.register_class(Foo) - d = attributes.register_attribute( - Foo, "attr", uselist=True, useobject=True - ) + d = _register_attribute(Foo, "attr", uselist=True, useobject=True) canary.listen(d) obj = Foo() @@ -2321,10 +2326,19 @@ class CustomCollectionsTest(fixtures.MappedTest): replaced = set([id(b) for b in list(f.bars.values())]) ne_(existing, replaced) - def test_list(self): - self._test_list(list) + @testing.combinations("direct", "as_callable", argnames="factory_type") + def test_list(self, factory_type): + if factory_type == "as_callable": + # test passing as callable - def test_list_no_setslice(self): + # this codepath likely was not working for many major + # versions, at least through 1.3 + self._test_list(lambda: []) + else: + self._test_list(list) + + @testing.combinations("direct", "as_callable", argnames="factory_type") + def test_list_no_setslice(self, factory_type): class ListLike: def __init__(self): self.data = list() @@ -2367,7 +2381,15 @@ class CustomCollectionsTest(fixtures.MappedTest): def __repr__(self): return "ListLike(%s)" % repr(self.data) - self._test_list(ListLike) + if factory_type == "as_callable": + # test passing as callable + + # this codepath likely was not working for many major + # versions, at least through 1.3 + + self._test_list(lambda: ListLike()) + else: + self._test_list(ListLike) def _test_list(self, listcls): someothertable, sometable = ( @@ -2598,9 +2620,7 @@ class InstrumentationTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute( - Foo, "attr", uselist=True, useobject=True - ) + _register_attribute(Foo, "attr", uselist=True, useobject=True) f1 = Foo() f1.attr.append(3) diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 587c498eae..a5506a2967 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -887,7 +887,13 @@ class InstrumentationTest(fixtures.ORMTest): instrumentation.register_class(Foo) attributes.register_attribute( - Foo, "attr", uselist=True, typecallable=MyDict, useobject=True + Foo, + "attr", + parententity=object(), + comparator=object(), + uselist=True, + typecallable=MyDict, + useobject=True, ) f = Foo() diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index 437129af16..99d5498d68 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -736,7 +736,14 @@ class MiscTest(fixtures.MappedTest): pass manager = instrumentation.register_class(A) - attributes.register_attribute(A, "x", uselist=False, useobject=False) + attributes.register_attribute( + A, + "x", + comparator=object(), + parententity=object(), + uselist=False, + useobject=False, + ) assert instrumentation.manager_of_class(A) is manager instrumentation.unregister_class(A) -- 2.47.2