From aeeff72e806420bf85e2e6723b1f941df38a3e1a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 15 Apr 2022 11:05:36 -0400 Subject: [PATCH] pep-484: ORM public API, constructors for the moment, abandoning using @overload with relationship() and mapped_column(). The overloads are very difficult to get working at all, and the overloads that were there all wouldn't pass on mypy. various techniques of getting them to "work", meaning having right hand side dictate what's legal on the left, have mixed success and wont give consistent results; additionally, it's legal to have Optional / non-optional independent of nullable in any case for columns. relationship cases are less ambiguous but mypy was not going along with things. we have a comprehensive system of allowing left side annotations to drive the right side, in the absense of explicit settings on the right. so type-centric SQLAlchemy will be left-side driven just like dataclasses, and the various flags and switches on the right side will just not be needed very much. in other matters, one surprise, forgot to remove string support from orm.join(A, B, "somename") or do deprecations for it in 1.4. This is a really not-directly-used structure barely mentioned in the docs for many years, the example shows a relationship being used, not a string, so we will just change it to raise the usual error here. Change-Id: Iefbbb8d34548b538023890ab8b7c9a5d9496ec6e --- lib/sqlalchemy/engine/cursor.py | 1 - lib/sqlalchemy/engine/util.py | 2 +- lib/sqlalchemy/ext/asyncio/scoping.py | 81 ++- lib/sqlalchemy/ext/hybrid.py | 20 +- lib/sqlalchemy/ext/instrumentation.py | 37 +- lib/sqlalchemy/inspection.py | 28 +- lib/sqlalchemy/orm/_orm_constructors.py | 558 ++++++--------- lib/sqlalchemy/orm/_typing.py | 51 +- lib/sqlalchemy/orm/attributes.py | 8 +- lib/sqlalchemy/orm/base.py | 77 ++- lib/sqlalchemy/orm/context.py | 44 +- lib/sqlalchemy/orm/decl_api.py | 8 +- lib/sqlalchemy/orm/decl_base.py | 19 +- lib/sqlalchemy/orm/descriptor_props.py | 6 +- lib/sqlalchemy/orm/events.py | 13 +- lib/sqlalchemy/orm/exc.py | 5 +- lib/sqlalchemy/orm/instrumentation.py | 81 ++- lib/sqlalchemy/orm/interfaces.py | 302 +++++--- lib/sqlalchemy/orm/loading.py | 7 +- lib/sqlalchemy/orm/mapper.py | 649 +++++++++++------- lib/sqlalchemy/orm/path_registry.py | 445 ++++++++---- lib/sqlalchemy/orm/properties.py | 67 +- lib/sqlalchemy/orm/query.py | 7 +- lib/sqlalchemy/orm/relationships.py | 114 +-- lib/sqlalchemy/orm/session.py | 21 +- lib/sqlalchemy/orm/strategies.py | 5 +- lib/sqlalchemy/orm/util.py | 539 +++++++++------ lib/sqlalchemy/sql/_elements_constructors.py | 7 +- lib/sqlalchemy/sql/_typing.py | 21 +- lib/sqlalchemy/sql/base.py | 17 +- lib/sqlalchemy/sql/coercions.py | 1 + lib/sqlalchemy/sql/compiler.py | 2 +- lib/sqlalchemy/sql/ddl.py | 3 +- lib/sqlalchemy/sql/elements.py | 37 +- lib/sqlalchemy/sql/lambdas.py | 4 +- lib/sqlalchemy/sql/roles.py | 11 +- lib/sqlalchemy/sql/schema.py | 135 +--- lib/sqlalchemy/sql/selectable.py | 4 +- lib/sqlalchemy/sql/util.py | 98 ++- lib/sqlalchemy/sql/visitors.py | 8 +- lib/sqlalchemy/testing/plugin/plugin_base.py | 16 - lib/sqlalchemy/util/_collections.py | 5 +- lib/sqlalchemy/util/_py_collections.py | 6 +- lib/sqlalchemy/util/deprecations.py | 17 +- lib/sqlalchemy/util/langhelpers.py | 2 + lib/sqlalchemy/util/preloaded.py | 2 + lib/sqlalchemy/util/typing.py | 8 +- pyproject.toml | 20 +- .../mypy/plain_files/association_proxy_one.py | 4 +- .../plain_files/experimental_relationship.py | 13 +- test/ext/mypy/plain_files/hybrid_one.py | 2 +- 
test/ext/mypy/plain_files/hybrid_two.py | 4 +- test/ext/mypy/plain_files/mapped_column.py | 32 +- test/ext/mypy/plain_files/sql_operations.py | 2 + .../plain_files/trad_relationship_uselist.py | 22 +- .../plain_files/traditional_relationship.py | 16 +- .../plugin_files/relationship_6255_one.py | 2 +- test/ext/mypy/plugin_files/typing_err3.py | 4 +- test/ext/test_extendedattr.py | 24 +- test/orm/inheritance/test_basic.py | 27 + test/orm/test_cascade.py | 42 +- test/orm/test_instrumentation.py | 9 +- test/orm/test_joins.py | 46 +- test/orm/test_mapper.py | 48 +- test/orm/test_options.py | 4 +- test/orm/test_query.py | 17 + test/orm/test_utils.py | 25 +- test/sql/test_selectable.py | 22 + 68 files changed, 2397 insertions(+), 1587 deletions(-) diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 72102ac264..ccf5736756 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1638,7 +1638,6 @@ class CursorResult(Result): :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial` """ # noqa: E501 - try: return self.context.rowcount except BaseException as e: diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 529b2ca73b..45f6bf20bf 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -48,7 +48,7 @@ def connection_memoize(key: str) -> Callable[[_C], _C]: connection.info[key] = val = fn(self, connection) return val - return decorated # type: ignore[return-value] + return decorated # type: ignore class _TConsSubject(Protocol): diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 33cf3f745a..c7a6e2ca01 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -76,7 +76,6 @@ if TYPE_CHECKING: "expunge", "expunge_all", "flush", - "get", "get_bind", "is_modified", "invalidate", @@ -204,6 +203,49 @@ class async_scoped_session: await self.registry().close() self.registry.clear() + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. 
seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + + """ # noqa: E501 + + # this was proxied but Mypy is requiring the return type to be + # clarified + + # work around: + # https://github.com/python/typing/discussions/1143 + return_value = await self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + return return_value + # START PROXY METHODS async_scoped_session # code within this block is **programmatically, @@ -632,43 +674,6 @@ class async_scoped_session: return await self._proxied.flush(objects=objects) - async def get( - self, - entity: _EntityBindKey[_O], - ident: _PKIdentityArgument, - *, - options: Optional[Sequence[ORMOption]] = None, - populate_existing: bool = False, - with_for_update: Optional[ForUpdateArg] = None, - identity_token: Optional[Any] = None, - execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - ) -> Optional[_O]: - r"""Return an instance based on the given primary key identifier, - or ``None`` if not found. - - .. container:: class_bases - - Proxied for the :class:`_asyncio.AsyncSession` class on - behalf of the :class:`_asyncio.scoping.async_scoped_session` class. - - .. seealso:: - - :meth:`_orm.Session.get` - main documentation for get - - - - """ # noqa: E501 - - return await self._proxied.get( - entity, - ident, - options=options, - populate_existing=populate_existing, - with_for_update=with_for_update, - identity_token=identity_token, - execution_options=execution_options, - ) - def get_bind( self, mapper: Optional[_EntityBindKey[_O]] = None, diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index be872804e2..7200414a18 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -832,6 +832,8 @@ from ..util.typing import Protocol if TYPE_CHECKING: + from ..orm._typing import _ORMColumnExprArgument + from ..orm.interfaces import MapperProperty from ..orm.util import AliasedInsp from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _DMLColumnArgument @@ -840,7 +842,6 @@ if TYPE_CHECKING: from ..sql.operators import OperatorType from ..sql.roles import ColumnsClauseRole - _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _T_con = TypeVar("_T_con", bound=Any, contravariant=True) @@ -1289,7 +1290,7 @@ class Comparator(interfaces.PropComparator[_T]): ): self.expression = expression - def __clause_element__(self) -> ColumnsClauseRole: + def __clause_element__(self) -> _ORMColumnExprArgument[_T]: expr = self.expression if is_has_clause_element(expr): ret_expr = expr.__clause_element__() @@ -1298,10 +1299,15 @@ class Comparator(interfaces.PropComparator[_T]): assert isinstance(expr, ColumnElement) ret_expr = expr + if TYPE_CHECKING: + # see test_hybrid->test_expression_isnt_clause_element + # that exercises the usual place this is caught if not + # true + assert isinstance(ret_expr, ColumnElement) return ret_expr - @util.non_memoized_property - def property(self) -> Any: + @util.ro_non_memoized_property + def property(self) -> Optional[interfaces.MapperProperty[_T]]: return None def adapt_to_entity( @@ -1325,7 +1331,7 @@ class ExprComparator(Comparator[_T]): def __getattr__(self, key: str) -> Any: return getattr(self.expression, key) - @util.non_memoized_property + @util.ro_non_memoized_property def info(self) -> _InfoType: return self.hybrid.info @@ -1339,8 +1345,8 @@ class 
ExprComparator(Comparator[_T]): else: return [(self.expression, value)] - @util.non_memoized_property - def property(self) -> Any: + @util.ro_non_memoized_property + def property(self) -> Optional[MapperProperty[_T]]: return self.expression.property # type: ignore def operate( diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 72448fbdc8..b1138a4ad8 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -25,6 +25,7 @@ from ..orm import exc as orm_exc from ..orm import instrumentation as orm_instrumentation from ..orm.instrumentation import _default_dict_getter from ..orm.instrumentation import _default_manager_getter +from ..orm.instrumentation import _default_opt_manager_getter from ..orm.instrumentation import _default_state_getter from ..orm.instrumentation import ClassManager from ..orm.instrumentation import InstrumentationFactory @@ -140,7 +141,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): hierarchy = util.class_hierarchy(cls) factories = set() for member in hierarchy: - manager = self.manager_of_class(member) + manager = self.opt_manager_of_class(member) if manager is not None: factories.add(manager.factory) else: @@ -161,17 +162,34 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): del self._state_finders[class_] del self._dict_finders[class_] - def manager_of_class(self, cls): - if cls is None: - return None + def opt_manager_of_class(self, cls): try: - finder = self._manager_finders.get(cls, _default_manager_getter) + finder = self._manager_finders.get( + cls, _default_opt_manager_getter + ) except TypeError: # due to weakref lookup on invalid object return None else: return finder(cls) + def manager_of_class(self, cls): + try: + finder = self._manager_finders.get(cls, _default_manager_getter) + except TypeError: + # due to weakref lookup on invalid object + raise orm_exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) + else: + manager = finder(cls) + if manager is None: + raise orm_exc.UnmappedClassError( + cls, + f"Can't locate an instrumentation manager for class {cls}", + ) + return manager + def state_of(self, instance): if instance is None: raise AttributeError("None has no persistent state.") @@ -384,6 +402,7 @@ def _install_instrumented_lookups(): instance_state=_instrumentation_factory.state_of, instance_dict=_instrumentation_factory.dict_of, manager_of_class=_instrumentation_factory.manager_of_class, + opt_manager_of_class=_instrumentation_factory.opt_manager_of_class, ) ) @@ -395,16 +414,19 @@ def _reinstall_default_lookups(): instance_state=_default_state_getter, instance_dict=_default_dict_getter, manager_of_class=_default_manager_getter, + opt_manager_of_class=_default_opt_manager_getter, ) ) _instrumentation_factory._extended = False def _install_lookups(lookups): - global instance_state, instance_dict, manager_of_class + global instance_state, instance_dict + global manager_of_class, opt_manager_of_class instance_state = lookups["instance_state"] instance_dict = lookups["instance_dict"] manager_of_class = lookups["manager_of_class"] + opt_manager_of_class = lookups["opt_manager_of_class"] orm_base.instance_state = ( attributes.instance_state ) = orm_instrumentation.instance_state = instance_state @@ -414,3 +436,6 @@ def _install_lookups(lookups): orm_base.manager_of_class = ( attributes.manager_of_class ) = orm_instrumentation.manager_of_class = manager_of_class + orm_base.opt_manager_of_class = ( 
+ attributes.opt_manager_of_class + ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index 6b06c0d6b6..01c740fe41 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -34,6 +34,7 @@ from typing import Any from typing import Callable from typing import Dict from typing import Generic +from typing import Optional from typing import overload from typing import Type from typing import TypeVar @@ -43,6 +44,9 @@ from . import exc from .util.typing import Literal _T = TypeVar("_T", bound=Any) +_F = TypeVar("_F", bound=Callable[..., Any]) + +_IN = TypeVar("_IN", bound="Inspectable[Any]") _registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {} @@ -53,11 +57,22 @@ class Inspectable(Generic[_T]): This allows typing to set up a linkage between an object that can be inspected and the type of inspection it returns. + Unfortunately we cannot at the moment get all classes that are + returned by inspection to suit this interface as we get into + MRO issues. + """ + __slots__ = () + @overload -def inspect(subject: Inspectable[_T], raiseerr: bool = True) -> _T: +def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: + ... + + +@overload +def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: ... @@ -108,9 +123,9 @@ def inspect(subject: Any, raiseerr: bool = True) -> Any: def _inspects( - *types: type, -) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]: - def decorate(fn_or_cls: Callable[[Any], Any]) -> Callable[[Any], Any]: + *types: Type[Any], +) -> Callable[[_F], _F]: + def decorate(fn_or_cls: _F) -> _F: for type_ in types: if type_ in _registrars: raise AssertionError( @@ -122,7 +137,10 @@ def _inspects( return decorate -def _self_inspects(cls: Type[_T]) -> Type[_T]: +_TT = TypeVar("_TT", bound="Type[Any]") + + +def _self_inspects(cls: _TT) -> _TT: if cls in _registrars: raise AssertionError("Type %s is already " "registered" % cls) _registrars[cls] = True diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 7690c05dec..457ad5c5a6 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -9,20 +9,19 @@ from __future__ import annotations import typing from typing import Any +from typing import Callable from typing import Collection -from typing import Dict -from typing import List from typing import Optional from typing import overload -from typing import Set from typing import Type +from typing import TYPE_CHECKING from typing import Union -from . import mapper as mapperlib +from . import mapperlib as mapperlib +from ._typing import _O from .base import Mapped from .descriptor_props import Composite from .descriptor_props import Synonym -from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn from .query import AliasOption @@ -37,11 +36,29 @@ from .. import sql from .. 
import util from ..exc import InvalidRequestError from ..sql.base import SchemaEventTarget -from ..sql.selectable import Alias +from ..sql.schema import SchemaConst from ..sql.selectable import FromClause -from ..sql.type_api import TypeEngine from ..util.typing import Literal +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ORMColumnExprArgument + from .descriptor_props import _CompositeAttrType + from .interfaces import PropComparator + from .query import Query + from .relationships import _LazyLoadArgumentType + from .relationships import _ORMBackrefArgument + from .relationships import _ORMColCollectionArgument + from .relationships import _ORMOrderByArgument + from .relationships import _RelationshipJoinConditionArgument + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + from ..sql._typing import _TypeEngineArgument + from ..sql.schema import _ServerDefaultType + from ..sql.schema import FetchedValue + from ..sql.selectable import Alias + from ..sql.selectable import Subquery + _T = typing.TypeVar("_T") @@ -61,7 +78,7 @@ SynonymProperty = Synonym "for entities to be matched up to a query that is established " "via :meth:`.Query.from_statement` and now does nothing.", ) -def contains_alias(alias) -> AliasOption: +def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption: r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. @@ -70,134 +87,36 @@ def contains_alias(alias) -> AliasOption: return AliasOption(alias) -# see test/ext/mypy/plain_files/mapped_column.py for mapped column -# typing tests - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[None] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[None] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[True] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Optional[_T]]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[True] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Optional[_T]]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[False] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[False] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: Literal[True] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... 
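# Editor's note -- illustrative sketch, not part of the patch.  The overloads
# removed above tried to let nullable=... on the right-hand side dictate
# Optional[...] on the left; per the commit message, typing is instead driven
# by the left-hand annotation.  A minimal sketch of that style, assuming the
# 2.0-style DeclarativeBase / Mapped / mapped_column API:
from typing import Optional

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"

    # the left-hand Mapped[int] fixes the Python type; primary_key stays on
    # the right as a plain column argument
    id: Mapped[int] = mapped_column(primary_key=True)

    # Mapped[Optional[str]] implies a nullable column
    fullname: Mapped[Optional[str]] = mapped_column(String(100))

    # a non-Optional annotation implies NOT NULL; no nullable= flag needed
    nickname: Mapped[str] = mapped_column(String(30))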
- - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: Literal[True] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __name: str, - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[Any], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[Any], SchemaEventTarget] + ] = None, *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = SchemaConst.NULL_UNSPECIFIED, + primary_key: Optional[bool] = False, + deferred: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + **dialect_kwargs: Any, +) -> MappedColumn[Any]: r"""construct a new ORM-mapped :class:`_schema.Column` construct. The :func:`_orm.mapped_column` function provides an ORM-aware and @@ -363,12 +282,45 @@ def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": """ - return MappedColumn(*args, **kw) + return MappedColumn( + __name_pos, + __type_pos, + *args, + name=name, + type_=type_, + autoincrement=autoincrement, + default=default, + doc=doc, + key=key, + index=index, + unique=unique, + info=info, + nullable=nullable, + onupdate=onupdate, + primary_key=primary_key, + server_default=server_default, + server_onupdate=server_onupdate, + quote=quote, + comment=comment, + system=system, + deferred=deferred, + **dialect_kwargs, + ) def column_property( - column: sql.ColumnElement[_T], *additional_columns, **kwargs -) -> "ColumnProperty[_T]": + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + descriptor: Optional[Any] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: r"""Provide a column-level property for use with a mapping. 
Column-based properties can normally be applied to the mapper's @@ -452,13 +404,25 @@ def column_property( expressions """ - return ColumnProperty(column, *additional_columns, **kwargs) + return ColumnProperty( + column, + *additional_columns, + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + descriptor=descriptor, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) @overload def composite( class_: Type[_T], - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[_T]: ... @@ -466,7 +430,7 @@ def composite( @overload def composite( - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[Any]: ... @@ -474,7 +438,7 @@ def composite( def composite( class_: Any = None, - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. @@ -529,13 +493,13 @@ def composite( def with_loader_criteria( - entity_or_base, - where_criteria, - loader_only=False, - include_aliases=False, - propagate_to_loaders=True, - track_closure_variables=True, -) -> "LoaderCriteriaOption": + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, +) -> LoaderCriteriaOption: """Add additional WHERE criteria to the load for all occurrences of a particular entity. @@ -711,180 +675,40 @@ def with_loader_criteria( ) -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Type[Set] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Set[Any]]: - ... - - -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Type[List] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[Any]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Literal[False] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[_T]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Literal[True] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[List] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[_T]]: - ... 
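# Editor's note -- illustrative sketch, not part of the patch.  The newly
# annotated with_loader_criteria() above takes an entity (or a base class)
# plus a boolean column expression.  A hedged, self-contained sketch using a
# made-up Account model and the 2.0-style declarative API:
from sqlalchemy import select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    with_loader_criteria,
)


class Base(DeclarativeBase):
    pass


class Account(Base):
    __tablename__ = "account"
    id: Mapped[int] = mapped_column(primary_key=True)
    archived: Mapped[bool] = mapped_column(default=False)


# the criteria applies to every occurrence of Account in the statement,
# including lazy loads, unless loader_only=True is passed
stmt = select(Account).options(
    with_loader_criteria(Account, Account.archived.is_(False))
)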
- - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[Set] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Set[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[Dict[Any, Any]] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Dict[Any, _T]]: - ... - - -@overload -def relationship( - argument: _RelationshipArgumentType[_T], - secondary=..., - *, - uselist: Literal[None] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=None, - back_populates=None, - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]] = ..., - secondary=..., - *, - uselist: Literal[True] = ..., - collection_class: Any = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload def relationship( - argument: Literal[None] = ..., - secondary=..., - *, - uselist: Optional[bool] = ..., - collection_class: Any = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -def relationship( - argument: Optional[_RelationshipArgumentType[_T]] = None, - secondary=None, + argument: Optional[_RelationshipArgumentType[Any]] = None, + secondary: Optional[FromClause] = None, *, uselist: Optional[bool] = None, - collection_class: Optional[Type[Collection]] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[_ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: bool = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[Type[PropComparator[Any]]] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, **kw: Any, ) -> Relationship[Any]: """Provide a relationship between two mapped classes. @@ -1098,13 +922,6 @@ def relationship( :ref:`error_qzyx` - usage example - :param bake_queries=True: - Legacy parameter, not used. - - .. versionchanged:: 1.4.23 the "lambda caching" system is no longer - used by loader strategies and the ``bake_queries`` parameter - has no effect. - :param cascade: A comma-separated list of cascade rules which determines how Session operations should be "cascaded" from parent to child. 
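# Editor's note -- illustrative sketch, not part of the patch.  With the
# relationship() overloads dropped, the collection type also comes from the
# left-hand annotation rather than uselist / collection_class flags.  A
# hedged sketch, assuming the same 2.0-style declarative API as above:
from typing import List

from sqlalchemy import ForeignKey
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Parent(Base):
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Mapped[List["Child"]] tells the ORM (and the type checker) that this
    # is a list-based collection; no uselist or collection_class needed
    children: Mapped[List["Child"]] = relationship(back_populates="parent")


class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
    # a scalar annotation implies a many-to-one (uselist=False)
    parent: Mapped["Parent"] = relationship(back_populates="children")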
@@ -1701,18 +1518,42 @@ def relationship( primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, back_populates=back_populates, + order_by=order_by, + backref=backref, + overlaps=overlaps, + post_update=post_update, + cascade=cascade, + viewonly=viewonly, + lazy=lazy, + passive_deletes=passive_deletes, + passive_updates=passive_updates, + active_history=active_history, + enable_typechecks=enable_typechecks, + foreign_keys=foreign_keys, + remote_side=remote_side, + join_depth=join_depth, + comparator_factory=comparator_factory, + single_parent=single_parent, + innerjoin=innerjoin, + distinct_target_key=distinct_target_key, + load_on_pending=load_on_pending, + query_class=query_class, + info=info, + omit_join=omit_join, + sync_backref=sync_backref, **kw, ) def synonym( - name, - map_column=None, - descriptor=None, - comparator_factory=None, - doc=None, - info=None, -) -> "Synonym[Any]": + name: str, + *, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> Synonym[Any]: """Denote an attribute name as a synonym to a mapped property, in that the attribute will mirror the value and expression behavior of another attribute. @@ -1951,8 +1792,8 @@ def deferred(*columns, **kw): def query_expression( - default_expr: sql.ColumnElement[_T] = sql.null(), -) -> "Mapped[_T]": + default_expr: _ORMColumnExprArgument[_T] = sql.null(), +) -> Mapped[_T]: """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -2010,33 +1851,33 @@ def clear_mappers(): @overload def aliased( - element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> "AliasedClass[_T]": + element: _EntityType[_O], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedClass[_O]: ... @overload def aliased( - element: "FromClause", - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> "Alias": + element: FromClause, + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> FromClause: ... def aliased( - element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> Union["AliasedClass[_T]", "Alias"]: + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> Union[AliasedClass[_O], FromClause]: """Produce an alias of the given element, usually an :class:`.AliasedClass` instance. @@ -2233,9 +2074,7 @@ def with_polymorphic( ) -def join( - left, right, onclause=None, isouter=False, full=False, join_to_left=None -): +def join(left, right, onclause=None, isouter=False, full=False): r"""Produce an inner join between left and right clauses. :func:`_orm.join` is an extension to the core join interface @@ -2270,16 +2109,11 @@ def join( See :ref:`orm_queryguide_joins` for information on modern usage of ORM level joins. - .. deprecated:: 0.8 - - the ``join_to_left`` parameter is deprecated, and will be removed - in a future release. The parameter has no effect. 
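# Editor's note -- illustrative sketch, not part of the patch.  The aliased()
# overloads above distinguish ORM entities (returning AliasedClass) from
# plain Core selectables (returning a Core alias).  A small self-contained
# sketch with a made-up self-referential Node model:
from typing import Optional

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import DeclarativeBase, Mapped, aliased, mapped_column


class Base(DeclarativeBase):
    pass


class Node(Base):
    __tablename__ = "node"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("node.id"))


# a mapped class in -> AliasedClass out; attributes remain instrumented
parent_node = aliased(Node, name="parent_node")
stmt = select(Node, parent_node).join(
    parent_node, Node.parent_id == parent_node.id
)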
- """ return _ORMJoin(left, right, onclause, isouter, full) -def outerjoin(left, right, onclause=None, full=False, join_to_left=None): +def outerjoin(left, right, onclause=None, full=False): """Produce a left outer join between left and right clauses. This is the "outer join" version of the :func:`_orm.join` function, diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 4250cdbe1f..339844f147 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -2,6 +2,7 @@ from __future__ import annotations import operator from typing import Any +from typing import Callable from typing import Dict from typing import Optional from typing import Tuple @@ -10,7 +11,9 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.orm.interfaces import UserDefinedOption +from ..sql import roles +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnElement from ..util.typing import Protocol from ..util.typing import TypeGuard @@ -18,8 +21,12 @@ if TYPE_CHECKING: from .attributes import AttributeImpl from .attributes import CollectionAttributeImpl from .base import PassiveFlag + from .decl_api import registry as _registry_type from .descriptor_props import _CompositeClassProto + from .interfaces import MapperProperty + from .interfaces import UserDefinedOption from .mapper import Mapper + from .relationships import Relationship from .state import InstanceState from .util import AliasedClass from .util import AliasedInsp @@ -27,21 +34,39 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) + +# I would have preferred this were bound=object however it seems +# to not travel in all situations when defined in that way. _O = TypeVar("_O", bound=Any) """The 'ORM mapped object' type. -I would have preferred this were bound=object however it seems -to not travel in all situations when defined in that way. + """ +if TYPE_CHECKING: + _RegistryType = _registry_type + _InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"] -_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"] +_EntityType = Union[ + Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]" +] _InstanceDict = Dict[str, Any] _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]] +_ORMColumnExprArgument = Union[ + ColumnElement[_T], + _HasClauseElement, + roles.ExpressionElementRole[_T], +] + +# somehow Protocol didn't want to work for this one +_ORMAdapterProto = Callable[ + [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T] +] + class _LoaderCallable(Protocol): def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any: @@ -60,10 +85,28 @@ def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]: if TYPE_CHECKING: + def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]: + ... + + def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: + ... + + def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: + ... + + def prop_is_relationship( + prop: MapperProperty[Any], + ) -> TypeGuard[Relationship[Any]]: + ... + def is_collection_impl( impl: AttributeImpl, ) -> TypeGuard[CollectionAttributeImpl]: ... 
else: + insp_is_mapper_property = operator.attrgetter("is_property") + insp_is_mapper = operator.attrgetter("is_mapper") + insp_is_aliased_class = operator.attrgetter("is_aliased_class") is_collection_impl = operator.attrgetter("collection") + prop_is_relationship = operator.attrgetter("_is_relationship") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 33ce96a192..41d944c57d 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -44,7 +44,7 @@ from .base import instance_dict as instance_dict from .base import instance_state as instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED -from .base import manager_of_class +from .base import manager_of_class as manager_of_class from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa from .base import NO_AUTOFLUSH @@ -52,6 +52,7 @@ from .base import NO_CHANGE # noqa from .base import NO_RAISE from .base import NO_VALUE from .base import NON_PERSISTENT_OK # noqa +from .base import opt_manager_of_class as opt_manager_of_class from .base import PASSIVE_CLASS_MISMATCH # noqa from .base import PASSIVE_NO_FETCH from .base import PASSIVE_NO_FETCH_RELATED # noqa @@ -74,6 +75,7 @@ from ..sql import traversals from ..sql import visitors if TYPE_CHECKING: + from .interfaces import MapperProperty from .state import InstanceState from ..sql.dml import _DMLColumnElement from ..sql.elements import ColumnElement @@ -146,7 +148,7 @@ class QueryableAttribute( self._of_type = of_type self._extra_criteria = extra_criteria - manager = manager_of_class(class_) + manager = opt_manager_of_class(class_) # manager is None in the case of AliasedClass if manager: # propagate existing event listeners from @@ -370,7 +372,7 @@ class QueryableAttribute( return "%s.%s" % (self.class_.__name__, self.key) @util.memoized_property - def property(self): + def property(self) -> MapperProperty[_T]: """Return the :class:`.MapperProperty` associated with this :class:`.QueryableAttribute`. diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 3fa855a4bd..054d52d83b 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -26,24 +26,25 @@ from typing import TypeVar from typing import Union from . import exc +from ._typing import insp_is_mapper from .. import exc as sa_exc from .. import inspection from .. 
import util from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly -from ..util.typing import Concatenate from ..util.typing import Literal -from ..util.typing import ParamSpec from ..util.typing import Self if typing.TYPE_CHECKING: from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute + from .instrumentation import ClassManager from .mapper import Mapper from .state import InstanceState from ..sql._typing import _InfoType + _T = TypeVar("_T", bound=Any) _O = TypeVar("_O", bound=object) @@ -246,21 +247,15 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_Fn = TypeVar("_Fn", bound=Callable) -_Args = ParamSpec("_Args") +_F = TypeVar("_F", bound=Callable) _Self = TypeVar("_Self") def _assertions( *assertions: Any, -) -> Callable[ - [Callable[Concatenate[_Self, _Fn, _Args], _Self]], - Callable[Concatenate[_Self, _Fn, _Args], _Self], -]: +) -> Callable[[_F], _F]: @util.decorator - def generate( - fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs - ) -> _Self: + def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self: for assertion in assertions: assertion(self, fn.__name__) fn(self, *args, **kw) @@ -269,13 +264,13 @@ def _assertions( return generate -# these can be replaced by sqlalchemy.ext.instrumentation -# if augmented class instrumentation is enabled. -def manager_of_class(cls): - return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None) +if TYPE_CHECKING: + def manager_of_class(cls: Type[Any]) -> ClassManager: + ... -if TYPE_CHECKING: + def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]: + ... def instance_state(instance: _O) -> InstanceState[_O]: ... @@ -284,6 +279,20 @@ if TYPE_CHECKING: ... else: + # these can be replaced by sqlalchemy.ext.instrumentation + # if augmented class instrumentation is enabled. 
+ + def manager_of_class(cls): + try: + return cls.__dict__[DEFAULT_MANAGER_ATTR] + except KeyError as ke: + raise exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) from ke + + def opt_manager_of_class(cls): + return cls.__dict__.get(DEFAULT_MANAGER_ATTR) + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) instance_dict = operator.attrgetter("__dict__") @@ -458,11 +467,12 @@ else: _state_mapper = util.dottedgetter("manager.mapper") -@inspection._inspects(type) -def _inspect_mapped_class(class_, configure=False): +def _inspect_mapped_class( + class_: Type[_O], configure: bool = False +) -> Optional[Mapper[_O]]: try: - class_manager = manager_of_class(class_) - if not class_manager.is_mapped: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: return None mapper = class_manager.mapper except exc.NO_STATE: @@ -473,7 +483,28 @@ def _inspect_mapped_class(class_, configure=False): return mapper -def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: +@inspection._inspects(type) +def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]: + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + return None + else: + return mapper + + +def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: + insp = inspection.inspect(arg, raiseerr=False) + if insp_is_mapper(insp): + return insp + + raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}") + + +def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]: """Given a class, return the primary :class:`_orm.Mapper` associated with the key. @@ -502,8 +533,8 @@ def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: class InspectionAttr: - """A base class applied to all ORM objects that can be returned - by the :func:`_sa.inspect` function. + """A base class applied to all ORM objects and attributes that are + related to things that can be returned by the :func:`_sa.inspect` function. The attributes defined here allow the usage of simple boolean checks to test basic facts about the object returned. 
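# Editor's note -- illustrative sketch, not part of the patch.  The split
# above gives two lookup styles: manager_of_class() now raises
# UnmappedClassError for unmapped classes, while opt_manager_of_class()
# returns None.  A hedged sketch of the difference (these are internal
# helpers, shown for illustration only):
from sqlalchemy.orm import base as orm_base
from sqlalchemy.orm.exc import UnmappedClassError


class NotMapped:
    pass


# the "opt_" variant is the tolerant lookup
assert orm_base.opt_manager_of_class(NotMapped) is None

# the strict variant raises instead of returning None
try:
    orm_base.manager_of_class(NotMapped)
except UnmappedClassError:
    pass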
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 419da65f7f..4fee2d383d 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -63,11 +63,14 @@ from ..sql.visitors import InternalTraversal if TYPE_CHECKING: from ._typing import _InternalEntityType + from .mapper import Mapper + from .query import Query from ..sql.compiler import _CompilerStackEntry from ..sql.dml import _DMLTableElement from ..sql.elements import ColumnElement from ..sql.selectable import _LabelConventionCallable from ..sql.selectable import SelectBase + from ..sql.type_api import TypeEngine _path_registry = PathRegistry.root @@ -211,6 +214,9 @@ class ORMCompileState(CompileState): _for_refresh_state = False _render_for_subquery = False + attributes: Dict[Any, Any] + global_attributes: Dict[Any, Any] + statement: Union[Select, FromStatement] select_statement: Union[Select, FromStatement] _entities: List[_QueryEntity] @@ -1930,7 +1936,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): assert right_mapper adapter = ORMAdapter( - right, equivalents=right_mapper._equivalent_columns + inspect(right), equivalents=right_mapper._equivalent_columns ) # if an alias() on the right side was generated, @@ -2075,14 +2081,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _column_descriptions( - query_or_select_stmt, compile_state=None, legacy=False + query_or_select_stmt: Union[Query, Select, FromStatement], + compile_state: Optional[ORMSelectCompileState] = None, + legacy: bool = False, ) -> List[ORMColumnDescription]: if compile_state is None: compile_state = ORMSelectCompileState._create_entities_collection( query_or_select_stmt, legacy=legacy ) ctx = compile_state - return [ + d = [ { "name": ent._label_name, "type": ent.type, @@ -2093,17 +2101,10 @@ def _column_descriptions( else None, } for ent, insp_ent in [ - ( - _ent, - ( - inspect(_ent.entity_zero) - if _ent.entity_zero is not None - else None - ), - ) - for _ent in ctx._entities + (_ent, _ent.entity_zero) for _ent in ctx._entities ] ] + return d def _legacy_filter_by_entity_zero(query_or_augmented_select): @@ -2157,6 +2158,11 @@ class _QueryEntity: _null_column_type = False use_id_for_hash = False + _label_name: Optional[str] + type: Union[Type[Any], TypeEngine[Any]] + expr: Union[_InternalEntityType, ColumnElement[Any]] + entity_zero: Optional[_InternalEntityType] + def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() @@ -2234,6 +2240,13 @@ class _MapperEntity(_QueryEntity): "_polymorphic_discriminator", ) + expr: _InternalEntityType + mapper: Mapper[Any] + entity_zero: _InternalEntityType + is_aliased_class: bool + path: PathRegistry + _label_name: str + def __init__( self, compile_state, entity, entities_collection, is_current_entities ): @@ -2389,6 +2402,13 @@ class _BundleEntity(_QueryEntity): "supports_single_entity", ) + _entities: List[_QueryEntity] + bundle: Bundle + type: Type[Any] + _label_name: str + supports_single_entity: bool + expr: Bundle + def __init__( self, compile_state, diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 70507015bc..0c990f8099 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -50,7 +50,7 @@ from ..util import hybridproperty from ..util import typing as compat_typing if typing.TYPE_CHECKING: - from .state import InstanceState # noqa + from .state import InstanceState _T = TypeVar("_T", bound=Any) @@ -280,7 +280,7 @@ class 
declared_attr(interfaces._MappedAttribute[_T]): # for the span of the declarative scan_attributes() phase. # to achieve this we look at the class manager that's configured. cls = owner - manager = attributes.manager_of_class(cls) + manager = attributes.opt_manager_of_class(cls) if manager is None: if not re.match(r"^__.+__$", self.fget.__name__): # if there is no manager at all, then this class hasn't been @@ -1294,8 +1294,8 @@ def as_declarative(**kw): @inspection._inspects( DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept ) -def _inspect_decl_meta(cls): - mp = _inspect_mapped_class(cls) +def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]: + mp: Mapper[Any] = _inspect_mapped_class(cls) if mp is None: if _DeferredMapperConfig.has_cls(cls): _DeferredMapperConfig.raise_unmapped_for_cls(cls) diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 804d05ce19..9c79a4172b 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -12,6 +12,8 @@ import collections from typing import Any from typing import Dict from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING import weakref from . import attributes @@ -42,6 +44,10 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +if TYPE_CHECKING: + from ._typing import _O + from ._typing import _RegistryType + def _declared_mapping_info(cls): # deferred mapping @@ -121,7 +127,7 @@ def _dive_for_cls_manager(cls): return None for base in cls.__mro__: - manager = attributes.manager_of_class(base) + manager = attributes.opt_manager_of_class(base) if manager: return manager return None @@ -171,7 +177,7 @@ class _MapperConfig: @classmethod def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw): - manager = attributes.manager_of_class(cls) + manager = attributes.opt_manager_of_class(cls) if manager and manager.class_ is cls_: raise exc.InvalidRequestError( "Class %r already has been " "instrumented declaratively" % cls @@ -191,7 +197,12 @@ class _MapperConfig: return cfg_cls(registry, cls_, dict_, table, mapper_kw) - def __init__(self, registry, cls_, mapper_kw): + def __init__( + self, + registry: _RegistryType, + cls_: Type[Any], + mapper_kw: Dict[str, Any], + ): self.cls = util.assert_arg_type(cls_, type, "cls_") self.classname = cls_.__name__ self.properties = util.OrderedDict() @@ -206,7 +217,7 @@ class _MapperConfig: init_method=registry.constructor, ) else: - manager = attributes.manager_of_class(self.cls) + manager = attributes.opt_manager_of_class(self.cls) if not manager or not manager.is_mapped: raise exc.InvalidRequestError( "Class %s has no primary mapper configured. 
Configure " diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 8beac472e3..4738d8c2c9 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -122,7 +122,11 @@ class DescriptorProperty(MapperProperty[_T]): _CompositeAttrType = Union[ - str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]" + str, + "Column[_T]", + "MappedColumn[_T]", + "InstrumentedAttribute[_T]", + "Mapped[_T]", ] diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index c531e7cf19..331c224eef 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -11,6 +11,9 @@ from __future__ import annotations from typing import Any +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING import weakref from . import instrumentation @@ -27,6 +30,10 @@ from .. import exc from .. import util from ..util.compat import inspect_getfullargspec +if TYPE_CHECKING: + from ._typing import _O + from .instrumentation import ClassManager + class InstrumentationEvents(event.Events): """Events related to class instrumentation events. @@ -214,7 +221,7 @@ class InstanceEvents(event.Events): if issubclass(target, mapperlib.Mapper): return instrumentation.ClassManager else: - manager = instrumentation.manager_of_class(target) + manager = instrumentation.opt_manager_of_class(target) if manager: return manager else: @@ -613,8 +620,8 @@ class _EventsHold(event.RefCollection): class _InstanceEventsHold(_EventsHold): all_holds = weakref.WeakKeyDictionary() - def resolve(self, class_): - return instrumentation.manager_of_class(class_) + def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: + return instrumentation.opt_manager_of_class(class_) class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents): pass diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index 00829ecbb7..529a7cd01f 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -203,7 +203,10 @@ def _default_unmapped(cls) -> Optional[str]: try: mappers = base.manager_of_class(cls).mappers - except (TypeError,) + NO_STATE: + except ( + UnmappedClassError, + TypeError, + ) + NO_STATE: mappers = {} name = _safe_cls_name(cls) diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 0d4b630dad..88ceacd076 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -33,10 +33,13 @@ alternate instrumentation forms. from __future__ import annotations from typing import Any +from typing import Callable from typing import Dict from typing import Generic from typing import Optional from typing import Set +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar import weakref @@ -53,7 +56,9 @@ from ..util import HasMemoized from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _RegistryType from .attributes import InstrumentedAttribute + from .decl_base import _MapperConfig from .mapper import Mapper from .state import InstanceState from ..event import dispatcher @@ -72,6 +77,11 @@ class _ExpiredAttributeLoaderProto(Protocol): ... +class _ManagerFactory(Protocol): + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: + ... 
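# Editor's note -- illustrative sketch, not part of the patch.  The
# _ManagerFactory protocol above types "a callable taking a class and
# returning a ClassManager".  A generic, self-contained sketch of the same
# callback-protocol pattern, using made-up names:
from typing import Protocol


class Widget:
    def __init__(self, name: str) -> None:
        self.name = name


class _WidgetFactory(Protocol):
    def __call__(self, name: str) -> Widget:
        ...


def build(factory: _WidgetFactory) -> Widget:
    # any callable matching the signature satisfies the protocol
    return factory("example")


# the class itself is an acceptable factory under this protocol
widget = build(Widget)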
+ + class ClassManager( HasMemoized, Dict[str, "InstrumentedAttribute[Any]"], @@ -90,12 +100,12 @@ class ClassManager( expired_attribute_loader: _ExpiredAttributeLoaderProto "previously known as deferred_scalar_loader" - init_method = None + init_method: Optional[Callable[..., None]] - factory = None + factory: Optional[_ManagerFactory] - declarative_scan = None - registry = None + declarative_scan: Optional[weakref.ref[_MapperConfig]] = None + registry: Optional[_RegistryType] = None @property @util.deprecated( @@ -122,11 +132,13 @@ class ClassManager( self.local_attrs = {} self.originals = {} self._finalized = False + self.factory = None + self.init_method = None self._bases = [ mgr for mgr in [ - manager_of_class(base) + opt_manager_of_class(base) for base in self.class_.__bases__ if isinstance(base, type) ] @@ -139,7 +151,7 @@ class ClassManager( self.dispatch._events._new_classmanager_instance(class_, self) for basecls in class_.__mro__: - mgr = manager_of_class(basecls) + mgr = opt_manager_of_class(basecls) if mgr is not None: self.dispatch._update(mgr.dispatch) @@ -155,16 +167,18 @@ class ClassManager( def _update_state( self, - finalize=False, - mapper=None, - registry=None, - declarative_scan=None, - expired_attribute_loader=None, - init_method=None, - ): + finalize: bool = False, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[ + _ExpiredAttributeLoaderProto + ] = None, + init_method: Optional[Callable[..., None]] = None, + ) -> None: if mapper: - self.mapper = mapper + self.mapper = mapper # type: ignore[assignment] if registry: registry._add_manager(self) if declarative_scan: @@ -350,7 +364,7 @@ class ClassManager( def subclass_managers(self, recursive): for cls in self.class_.__subclasses__(): - mgr = manager_of_class(cls) + mgr = opt_manager_of_class(cls) if mgr is not None and mgr is not self: yield mgr if recursive: @@ -374,7 +388,7 @@ class ClassManager( self._reset_memoizations() del self[key] for cls in self.class_.__subclasses__(): - manager = manager_of_class(cls) + manager = opt_manager_of_class(cls) if manager: manager.uninstrument_attribute(key, True) @@ -523,7 +537,7 @@ class _SerializeManager: manager.dispatch.pickle(state, d) def __call__(self, state, inst, state_dict): - state.manager = manager = manager_of_class(self.class_) + state.manager = manager = opt_manager_of_class(self.class_) if manager is None: raise exc.UnmappedInstanceError( inst, @@ -546,9 +560,9 @@ class _SerializeManager: class InstrumentationFactory: """Factory for new ClassManager instances.""" - def create_manager_for_cls(self, class_): + def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]: assert class_ is not None - assert manager_of_class(class_) is None + assert opt_manager_of_class(class_) is None # give a more complicated subclass # a chance to do what it wants here @@ -557,6 +571,8 @@ class InstrumentationFactory: if factory is None: factory = ClassManager manager = factory(class_) + else: + assert manager is not None self._check_conflicts(class_, factory) @@ -564,11 +580,15 @@ class InstrumentationFactory: return manager - def _locate_extended_factory(self, class_): + def _locate_extended_factory( + self, class_: Type[_O] + ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]: """Overridden by a subclass to do an extended lookup.""" return None, None - def _check_conflicts(self, class_, factory): + def _check_conflicts( + self, 
class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]] + ): """Overridden by a subclass to test for conflicting factories.""" return @@ -590,24 +610,25 @@ instance_state = _default_state_getter = base.instance_state instance_dict = _default_dict_getter = base.instance_dict manager_of_class = _default_manager_getter = base.manager_of_class +opt_manager_of_class = _default_opt_manager_getter = base.opt_manager_of_class def register_class( - class_, - finalize=True, - mapper=None, - registry=None, - declarative_scan=None, - expired_attribute_loader=None, - init_method=None, -): + class_: Type[_O], + finalize: bool = True, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None, + init_method: Optional[Callable[..., None]] = None, +) -> ClassManager[_O]: """Register class instrumentation. Returns the existing or newly created class manager. """ - manager = manager_of_class(class_) + manager = opt_manager_of_class(class_) if manager is None: manager = _instrumentation_factory.create_manager_for_cls(class_) manager._update_state( diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index abc1300d8d..0ca62b7e35 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -21,10 +21,15 @@ from __future__ import annotations import collections import typing from typing import Any +from typing import Callable from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterator from typing import List from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TypeVar @@ -45,7 +50,6 @@ from .base import NotExtension as NotExtension from .base import ONETOMANY as ONETOMANY from .base import SQLORMOperations from .. import ColumnElement -from .. import inspect from .. import inspection from .. 
import util from ..sql import operators @@ -53,19 +57,47 @@ from ..sql import roles from ..sql import visitors from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey -from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.type_api import TypeEngine from ..util.typing import TypedDict + if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _ORMAdapterProto + from ._typing import _ORMColumnExprArgument + from .attributes import InstrumentedAttribute + from .context import _MapperEntity + from .context import ORMCompileState from .decl_api import RegistryType + from .loading import _PopulatorDict + from .mapper import Mapper + from .path_registry import AbstractEntityRegistry + from .path_registry import PathRegistry + from .query import Query + from .session import Session + from .state import InstanceState + from .strategy_options import _LoadElement + from .util import AliasedInsp + from .util import CascadeOptions + from .util import ORMAdapter + from ..engine.result import Result + from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _DMLColumnArgument from ..sql._typing import _InfoType + from ..sql._typing import _PropagateAttrsType + from ..sql.operators import OperatorType + from ..sql.util import ColumnAdapter + from ..sql.visitors import _TraverseInternalsType _T = TypeVar("_T", bound=Any) +_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") + class ORMStatementRole(roles.StatementRole): __slots__ = () @@ -91,7 +123,9 @@ class ORMFromClauseRole(roles.StrictFromClauseRole): class ORMColumnDescription(TypedDict): name: str - type: Union[Type, TypeEngine] + # TODO: add python_type and sql_type here; combining them + # into "type" is a bad idea + type: Union[Type[Any], TypeEngine[Any]] aliased: bool expr: _ColumnsClauseArgument entity: Optional[_ColumnsClauseArgument] @@ -102,10 +136,10 @@ class _IntrospectsAnnotations: def declarative_scan( self, - registry: "RegistryType", - cls: type, + registry: RegistryType, + cls: Type[Any], key: str, - annotation: Optional[type], + annotation: Optional[Type[Any]], is_dataclass_field: Optional[bool], ) -> None: """Perform class-specific initializaton at early declarative scanning @@ -124,12 +158,12 @@ class _MapsColumns(_MappedAttribute[_T]): __slots__ = () @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: """return a MapperProperty to be assigned to the declarative mapping""" raise NotImplementedError() @property - def columns_to_assign(self) -> List[Column]: + def columns_to_assign(self) -> List[Column[_T]]: """A list of Column objects that should be declaratively added to the new Table object. @@ -139,7 +173,10 @@ class _MapsColumns(_MappedAttribute[_T]): @inspection._self_inspects class MapperProperty( - HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots + HasCacheKey, + _MappedAttribute[_T], + InspectionAttrInfo, + util.MemoizedSlots, ): """Represent a particular class attribute mapped by :class:`_orm.Mapper`. 
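As a purely illustrative aside (not part of the patch): the keys declared on ORMColumnDescription above are the same keys found in the per-entry dictionaries of Query.column_descriptions. The User/Address mapping below is hypothetical, written in the annotated 2.0 style that this patch series is building toward, and the later sketches in this section reuse it.

    from typing import List

    from sqlalchemy import ForeignKey, String
    from sqlalchemy.orm import (
        Mapped,
        Session,
        declarative_base,
        mapped_column,
        relationship,
    )

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user_account"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(50))
        addresses: Mapped[List["Address"]] = relationship(
            "Address", back_populates="user", cascade="all, delete-orphan"
        )


    class Address(Base):
        __tablename__ = "address"
        id: Mapped[int] = mapped_column(primary_key=True)
        user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
        email_address: Mapped[str] = mapped_column(String(100))
        user: Mapped["User"] = relationship("User", back_populates="addresses")


    for desc in Session().query(User, User.name).column_descriptions:
        # each entry is shaped like ORMColumnDescription:
        # {"name": ..., "type": ..., "aliased": ..., "expr": ..., "entity": ...}
        print(desc["name"], desc["type"], desc["aliased"])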
@@ -160,12 +197,12 @@ class MapperProperty( "info", ) - _cache_key_traversal = [ + _cache_key_traversal: _TraverseInternalsType = [ ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), ("key", visitors.ExtendedInternalTraversal.dp_string), ] - cascade = frozenset() + cascade: Optional[CascadeOptions] = None """The set of 'cascade' attribute names. This collection is checked before the 'cascade_iterator' method is called. @@ -184,14 +221,20 @@ class MapperProperty( """The :class:`_orm.PropComparator` instance that implements SQL expression construction on behalf of this mapped attribute.""" - @property - def _links_to_entity(self): - """True if this MapperProperty refers to a mapped entity. + key: str + """name of class attribute""" - Should only be True for Relationship, False for all others. + parent: Mapper[Any] + """the :class:`.Mapper` managing this property.""" - """ - raise NotImplementedError() + _is_relationship = False + + _links_to_entity: bool + """True if this MapperProperty refers to a mapped entity. + + Should only be True for Relationship, False for all others. + + """ def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined @@ -217,7 +260,14 @@ class MapperProperty( """ return {} - def setup(self, context, query_entity, path, adapter, **kwargs): + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: PathRegistry, + adapter: Optional[ColumnAdapter], + **kwargs: Any, + ) -> None: """Called by Query for the purposes of constructing a SQL statement. Each MapperProperty associated with the target mapper processes the @@ -227,16 +277,30 @@ class MapperProperty( """ def create_row_processor( - self, context, query_entity, path, mapper, result, adapter, populators - ): + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: PathRegistry, + mapper: Mapper[Any], + result: Result, + adapter: Optional[ColumnAdapter], + populators: _PopulatorDict, + ) -> None: """Produce row processing functions and append to the given set of populators lists. """ def cascade_iterator( - self, type_, state, dict_, visited_states, halt_on=None - ): + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: """Iterate through instances related to the given instance for a particular 'cascade', starting with this MapperProperty. @@ -251,7 +315,7 @@ class MapperProperty( return iter(()) - def set_parent(self, parent, init): + def set_parent(self, parent: Mapper[Any], init: bool) -> None: """Set the parent mapper that references this MapperProperty. This method is overridden by some subclasses to perform extra @@ -260,7 +324,7 @@ class MapperProperty( """ self.parent = parent - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: """Hook called by the Mapper to the property to initiate instrumentation of the class attribute managed by this MapperProperty. @@ -280,11 +344,11 @@ class MapperProperty( """ - def __init__(self): + def __init__(self) -> None: self._configure_started = False self._configure_finished = False - def init(self): + def init(self) -> None: """Called after all mappers are created to assemble relationships between mappers and perform other post-mapper-creation initialization steps. 
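Continuing the hypothetical User/Address sketch above: the MapperProperty attributes that gain annotations here (key, parent, cascade) are observable through the public inspection API. The exact cascade contents depend on the relationship() configuration, so the comment below is indicative only.

    from sqlalchemy import inspect

    user_mapper = inspect(User)
    rel_prop = user_mapper.relationships["addresses"]   # a Relationship property

    print(rel_prop.key)                  # "addresses"
    print(rel_prop.parent is user_mapper)  # True: the owning Mapper
    print(rel_prop.cascade)              # a CascadeOptions frozenset, e.g.
                                         # {"save-update", "merge", "delete",
                                         #  "delete-orphan", ...}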
@@ -296,7 +360,7 @@ class MapperProperty( self._configure_finished = True @property - def class_attribute(self): + def class_attribute(self) -> InstrumentedAttribute[_T]: """Return the class-bound descriptor corresponding to this :class:`.MapperProperty`. @@ -319,9 +383,9 @@ class MapperProperty( """ - return getattr(self.parent.class_, self.key) + return getattr(self.parent.class_, self.key) # type: ignore - def do_init(self): + def do_init(self) -> None: """Perform subclass-specific initialization post-mapper-creation steps. @@ -330,7 +394,7 @@ class MapperProperty( """ - def post_instrument_class(self, mapper): + def post_instrument_class(self, mapper: Mapper[Any]) -> None: """Perform instrumentation adjustments that need to occur after init() has completed. @@ -347,21 +411,21 @@ class MapperProperty( def merge( self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, - _recursive, - _resolve_conflict_map, - ): + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Set[InstanceState[Any]], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: """Merge the attribute represented by this ``MapperProperty`` from source to destination object. """ - def __repr__(self): + def __repr__(self) -> str: return "<%s at 0x%x; %s>" % ( self.__class__.__name__, id(self), @@ -452,21 +516,28 @@ class PropComparator(SQLORMOperations[_T]): """ - __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" + __slots__ = "prop", "_parententity", "_adapt_to_entity" __visit_name__ = "orm_prop_comparator" + _parententity: _InternalEntityType[Any] + _adapt_to_entity: Optional[AliasedInsp[Any]] + def __init__( self, - prop, - parentmapper, - adapt_to_entity=None, + prop: MapperProperty[_T], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, ): - self.prop = self.property = prop + self.prop = prop self._parententity = adapt_to_entity or parentmapper self._adapt_to_entity = adapt_to_entity - def __clause_element__(self): + @util.ro_non_memoized_property + def property(self) -> Optional[MapperProperty[_T]]: + return self.prop + + def __clause_element__(self) -> _ORMColumnExprArgument[_T]: raise NotImplementedError("%r" % self) def _bulk_update_tuples( @@ -480,22 +551,24 @@ class PropComparator(SQLORMOperations[_T]): """ - return [(self.__clause_element__(), value)] + return [(cast("_DMLColumnArgument", self.__clause_element__()), value)] - def adapt_to_entity(self, adapt_to_entity): + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> PropComparator[_T]: """Return a copy of this PropComparator which will use the given :class:`.AliasedInsp` to produce corresponding expressions. """ return self.__class__(self.prop, self._parententity, adapt_to_entity) - @property - def _parentmapper(self): + @util.ro_non_memoized_property + def _parentmapper(self) -> Mapper[Any]: """legacy; this is renamed to _parententity to be compatible with QueryableAttribute.""" - return inspect(self._parententity).mapper + return self._parententity.mapper - @property - def _propagate_attrs(self): + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: # this suits the case in coercions where we don't actually # call ``__clause_element__()`` but still need to get # resolved._propagate_attrs. See #6558. 
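Another small sketch against the hypothetical User model from above: MapperProperty.class_attribute and the comparator-level .prop / .property accessors annotated in this hunk point at each other from opposite directions.

    from sqlalchemy import inspect

    name_prop = inspect(User).attrs["name"]      # a ColumnProperty

    # class_attribute resolves to the class-bound InstrumentedAttribute
    print(name_prop.class_attribute is User.name)   # expected: True

    # and the attribute's comparator-level .property leads back to the
    # same MapperProperty instance
    print(User.name.property is name_prop)           # expected: True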
@@ -507,12 +580,14 @@ class PropComparator(SQLORMOperations[_T]): ) def _criterion_exists( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[Any]: return self.prop.comparator._criterion_exists(criterion, **kwargs) - @property - def adapter(self): + @util.ro_non_memoized_property + def adapter(self) -> Optional[_ORMAdapterProto[_T]]: """Produce a callable that adapts column expressions to suit an aliased version of this comparator. @@ -522,20 +597,20 @@ class PropComparator(SQLORMOperations[_T]): else: return self._adapt_to_entity._adapt_element - @util.non_memoized_property + @util.ro_non_memoized_property def info(self) -> _InfoType: - return self.property.info + return self.prop.info @staticmethod - def _any_op(a, b, **kwargs): + def _any_op(a: Any, b: Any, **kwargs: Any) -> Any: return a.any(b, **kwargs) @staticmethod - def _has_op(left, other, **kwargs): + def _has_op(left: Any, other: Any, **kwargs: Any) -> Any: return left.has(other, **kwargs) @staticmethod - def _of_type_op(a, class_): + def _of_type_op(a: Any, class_: Any) -> Any: return a.of_type(class_) any_op = cast(operators.OperatorType, _any_op) @@ -545,16 +620,16 @@ class PropComparator(SQLORMOperations[_T]): if typing.TYPE_CHECKING: def operate( - self, op: operators.OperatorType, *other: Any, **kwargs: Any - ) -> "SQLCoreOperations[Any]": + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... def reverse_operate( - self, op: operators.OperatorType, other: Any, **kwargs: Any - ) -> "SQLCoreOperations[Any]": + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... - def of_type(self, class_) -> "SQLORMOperations[_T]": + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: r"""Redefine this object in terms of a polymorphic subclass, :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` construct. @@ -578,9 +653,11 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.of_type_op, class_) + return self.operate(PropComparator.of_type_op, class_) # type: ignore - def and_(self, *criteria) -> "SQLORMOperations[_T]": + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> ColumnElement[bool]: """Add additional criteria to the ON clause that's represented by this relationship attribute. @@ -606,10 +683,12 @@ class PropComparator(SQLORMOperations[_T]): :func:`.with_loader_criteria` """ - return self.operate(operators.and_, *criteria) + return self.operate(operators.and_, *criteria) # type: ignore def any( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[bool]: r"""Return a SQL expression representing true if this element references a member which meets the given criterion. @@ -626,10 +705,14 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.any_op, criterion, **kwargs) + return self.operate( # type: ignore + PropComparator.any_op, criterion, **kwargs + ) def has( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[bool]: r"""Return a SQL expression representing true if this element references a member which meets the given criterion. 
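A hypothetical usage sketch for the any(), has() and and_() comparator methods documented above, again using the illustrative User/Address mapping from the first sketch.

    from sqlalchemy import select

    # one-to-many membership test via any()
    stmt1 = select(User).where(
        User.addresses.any(Address.email_address.endswith("@example.com"))
    )

    # many-to-one existence test via has()
    stmt2 = select(Address).where(Address.user.has(User.name == "spongebob"))

    # and_(): supplement the ON clause of a relationship join
    stmt3 = select(User).join(
        User.addresses.and_(Address.email_address != "noreply@example.com")
    )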
@@ -646,7 +729,9 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.has_op, criterion, **kwargs) + return self.operate( # type: ignore + PropComparator.has_op, criterion, **kwargs + ) class StrategizedProperty(MapperProperty[_T]): @@ -674,23 +759,30 @@ class StrategizedProperty(MapperProperty[_T]): "strategy_key", ) inherit_cache = True - strategy_wildcard_key = None + strategy_wildcard_key: ClassVar[str] strategy_key: Tuple[Any, ...] - def _memoized_attr__wildcard_token(self): + _strategies: Dict[Tuple[Any, ...], LoaderStrategy] + + def _memoized_attr__wildcard_token(self) -> Tuple[str]: return ( f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", ) - def _memoized_attr__default_path_loader_key(self): + def _memoized_attr__default_path_loader_key( + self, + ) -> Tuple[str, Tuple[str]]: return ( "loader", (f"{self.strategy_wildcard_key}:{path_registry._DEFAULT_TOKEN}",), ) - def _get_context_loader(self, context, path): - load = None + def _get_context_loader( + self, context: ORMCompileState, path: AbstractEntityRegistry + ) -> Optional[_LoadElement]: + + load: Optional[_LoadElement] = None search_path = path[self] @@ -714,7 +806,7 @@ class StrategizedProperty(MapperProperty[_T]): return load - def _get_strategy(self, key): + def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy: try: return self._strategies[key] except KeyError: @@ -768,11 +860,13 @@ class StrategizedProperty(MapperProperty[_T]): ): self.strategy.init_class_attribute(mapper) - _all_strategies = collections.defaultdict(dict) + _all_strategies: collections.defaultdict[ + Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]] + ] = collections.defaultdict(dict) @classmethod - def strategy_for(cls, **kw): - def decorate(dec_cls): + def strategy_for(cls, **kw: Any) -> Callable[[_TLS], _TLS]: + def decorate(dec_cls: _TLS) -> _TLS: # ensure each subclass of the strategy has its # own _strategy_keys collection if "_strategy_keys" not in dec_cls.__dict__: @@ -785,7 +879,9 @@ class StrategizedProperty(MapperProperty[_T]): return decorate @classmethod - def _strategy_lookup(cls, requesting_property, *key): + def _strategy_lookup( + cls, requesting_property: MapperProperty[Any], *key: Any + ) -> Type[LoaderStrategy]: requesting_property.parent._with_polymorphic_mappers for prop_cls in cls.__mro__: @@ -984,10 +1080,10 @@ class MapperOption(ORMOption): """ - def process_query(self, query): + def process_query(self, query: Query[Any]) -> None: """Apply a modification to the given :class:`_query.Query`.""" - def process_query_conditionally(self, query): + def process_query_conditionally(self, query: Query[Any]) -> None: """same as process_query(), except that this option may not apply to the given query. @@ -1034,7 +1130,11 @@ class LoaderStrategy: "strategy_opts", ) - def __init__(self, parent, strategy_key): + _strategy_keys: ClassVar[List[Tuple[Any, ...]]] + + def __init__( + self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...] 
+ ): self.parent_property = parent self.is_class_level = False self.parent = self.parent_property.parent @@ -1042,12 +1142,18 @@ class LoaderStrategy: self.strategy_key = strategy_key self.strategy_opts = dict(strategy_key) - def init_class_attribute(self, mapper): + def init_class_attribute(self, mapper: Mapper[Any]) -> None: pass def setup_query( - self, compile_state, query_entity, path, loadopt, adapter, **kwargs - ): + self, + compile_state: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: """Establish column and other state for a given QueryContext. This method fulfills the contract specified by MapperProperty.setup(). @@ -1059,15 +1165,15 @@ class LoaderStrategy: def create_row_processor( self, - context, - query_entity, - path, - loadopt, - mapper, - result, - adapter, - populators, - ): + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + mapper: Mapper[Any], + result: Result, + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: """Establish row processing functions for a given QueryContext. This method fulfills the contract specified by diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index ae083054cd..d9949eb7a3 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -16,7 +16,9 @@ as well as some of the attribute loading strategies. from __future__ import annotations from typing import Any +from typing import Dict from typing import Iterable +from typing import List from typing import Mapping from typing import Optional from typing import Sequence @@ -65,6 +67,9 @@ _O = TypeVar("_O", bound=object) _new_runid = util.counter() +_PopulatorDict = Dict[str, List[Tuple[str, Any]]] + + def instances(cursor, context): """Return a :class:`.Result` given an ORM query context. @@ -383,7 +388,7 @@ def get_from_identity( mapper: Mapper[_O], key: _IdentityKeyType[_O], passive: PassiveFlag, -) -> Union[Optional[_O], LoaderCallableStatus]: +) -> Union[LoaderCallableStatus, Optional[_O]]: """Look up the given key in the given session's identity map, check the object for expired state if found. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index abe11cc68c..b37c080eaf 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -23,12 +23,23 @@ import sys import threading from typing import Any from typing import Callable +from typing import cast +from typing import Collection +from typing import Deque +from typing import Dict from typing import Generic +from typing import Iterable from typing import Iterator +from typing import List +from typing import Mapping from typing import Optional +from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from . import attributes @@ -39,8 +50,8 @@ from . import properties from . import util as orm_util from ._typing import _O from .base import _class_to_mapper +from .base import _parse_mapper_argument from .base import _state_mapper -from .base import class_mapper from .base import PassiveFlag from .base import state_str from .interfaces import _MappedAttribute @@ -58,6 +69,8 @@ from .. import log from .. import schema from .. import sql from .. 
import util +from ..event import dispatcher +from ..event import EventTarget from ..sql import base as sql_base from ..sql import coercions from ..sql import expression @@ -65,26 +78,68 @@ from ..sql import operators from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.schema import Table from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import HasMemoized +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _IdentityKeyType from ._typing import _InstanceDict + from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .decl_api import registry + from .dependency import DependencyProcessor + from .descriptor_props import Composite + from .descriptor_props import Synonym + from .events import MapperEvents from .instrumentation import ClassManager + from .path_registry import AbstractEntityRegistry + from .path_registry import CachingEntityRegistry + from .properties import ColumnProperty + from .relationships import Relationship from .state import InstanceState + from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement from ..sql.schema import Column + from ..sql.schema import Table + from ..sql.selectable import FromClause + from ..sql.selectable import TableClause + from ..sql.util import ColumnAdapter + from ..util import OrderedSet -_mapper_registries = weakref.WeakKeyDictionary() +_T = TypeVar("_T", bound=Any) +_MP = TypeVar("_MP", bound="MapperProperty[Any]") -def _all_registries(): +_WithPolymorphicArg = Union[ + Literal["*"], + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ], + Sequence[Union["Mapper[Any]", Type[Any]]], +] + + +_mapper_registries: weakref.WeakKeyDictionary[ + _RegistryType, bool +] = weakref.WeakKeyDictionary() + + +def _all_registries() -> Set[registry]: with _CONFIGURE_MUTEX: return set(_mapper_registries) -def _unconfigured_mappers(): +def _unconfigured_mappers() -> Iterator[Mapper[Any]]: for reg in _all_registries(): for mapper in reg._mappers_to_configure(): yield mapper @@ -107,9 +162,11 @@ _CONFIGURE_MUTEX = threading.RLock() class Mapper( ORMFromClauseRole, ORMEntityColumnsClauseRole, - sql_base.MemoizedHasCacheKey, + MemoizedHasCacheKey, InspectionAttr, log.Identified, + inspection.Inspectable["Mapper[_O]"], + EventTarget, Generic[_O], ): """Defines an association between a Python class and a database table or @@ -123,18 +180,11 @@ class Mapper( """ + dispatch: dispatcher[Mapper[_O]] + _dispose_called = False _ready_for_configure = False - class_: Type[_O] - """The class to which this :class:`_orm.Mapper` is mapped.""" - - _identity_class: Type[_O] - - always_refresh: bool - allow_partial_pks: bool - version_id_col: Optional[ColumnElement[Any]] - @util.deprecated_params( non_primary=( "1.3", @@ -148,33 +198,39 @@ class Mapper( def __init__( self, class_: Type[_O], - local_table=None, - properties=None, - primary_key=None, - non_primary=False, - inherits=None, - inherit_condition=None, - inherit_foreign_keys=None, - always_refresh=False, - version_id_col=None, - version_id_generator=None, - polymorphic_on=None, - _polymorphic_map=None, - 
polymorphic_identity=None, - concrete=False, - with_polymorphic=None, - polymorphic_load=None, - allow_partial_pks=True, - batch=True, - column_prefix=None, - include_properties=None, - exclude_properties=None, - passive_updates=True, - passive_deletes=False, - confirm_deleted_rows=True, - eager_defaults=False, - legacy_is_orphan=False, - _compiled_cache_size=100, + local_table: Optional[FromClause] = None, + properties: Optional[Mapping[str, MapperProperty[Any]]] = None, + primary_key: Optional[Iterable[_ORMColumnExprArgument[Any]]] = None, + non_primary: bool = False, + inherits: Optional[Union[Mapper[Any], Type[Any]]] = None, + inherit_condition: Optional[_ColumnExpressionArgument[bool]] = None, + inherit_foreign_keys: Optional[ + Sequence[_ORMColumnExprArgument[Any]] + ] = None, + always_refresh: bool = False, + version_id_col: Optional[_ORMColumnExprArgument[Any]] = None, + version_id_generator: Optional[ + Union[Literal[False], Callable[[Any], Any]] + ] = None, + polymorphic_on: Optional[ + Union[_ORMColumnExprArgument[Any], str, MapperProperty[Any]] + ] = None, + _polymorphic_map: Optional[Dict[Any, Mapper[Any]]] = None, + polymorphic_identity: Optional[Any] = None, + concrete: bool = False, + with_polymorphic: Optional[_WithPolymorphicArg] = None, + polymorphic_load: Optional[Literal["selectin", "inline"]] = None, + allow_partial_pks: bool = True, + batch: bool = True, + column_prefix: Optional[str] = None, + include_properties: Optional[Sequence[str]] = None, + exclude_properties: Optional[Sequence[str]] = None, + passive_updates: bool = True, + passive_deletes: bool = False, + confirm_deleted_rows: bool = True, + eager_defaults: bool = False, + legacy_is_orphan: bool = False, + _compiled_cache_size: int = 100, ): r"""Direct constructor for a new :class:`_orm.Mapper` object. 
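For orientation only: most of the Mapper() keyword arguments annotated above normally arrive by way of declarative's __mapper_args__ dictionary rather than by calling Mapper() directly. The Employee/Manager joined-table hierarchy below is hypothetical (not from the patch) and is reused by later sketches.

    from sqlalchemy import ForeignKey, String
    from sqlalchemy.orm import Mapped, declarative_base, mapped_column

    InhBase = declarative_base()


    class Employee(InhBase):
        __tablename__ = "employee"
        id: Mapped[int] = mapped_column(primary_key=True)
        type: Mapped[str] = mapped_column(String(32))

        __mapper_args__ = {
            # forwarded as Mapper(polymorphic_on=..., polymorphic_identity=...)
            "polymorphic_on": "type",
            "polymorphic_identity": "employee",
        }


    class Manager(Employee):
        __tablename__ = "manager"
        id: Mapped[int] = mapped_column(
            ForeignKey("employee.id"), primary_key=True
        )

        __mapper_args__ = {
            # declarative supplies inherits=Employee implicitly
            "polymorphic_identity": "manager",
        }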
@@ -593,8 +649,6 @@ class Mapper( self.class_.__name__, ) - self.class_manager = None - self._primary_key_argument = util.to_list(primary_key) self.non_primary = non_primary @@ -623,17 +677,36 @@ class Mapper( self.concrete = concrete self.single = False - self.inherits = inherits + + if inherits is not None: + self.inherits = _parse_mapper_argument(inherits) + else: + self.inherits = None + if local_table is not None: self.local_table = coercions.expect( roles.StrictFromClauseRole, local_table ) + elif self.inherits: + # note this is a new flow as of 2.0 so that + # .local_table need not be Optional + self.local_table = self.inherits.local_table + self.single = True else: - self.local_table = None + raise sa_exc.ArgumentError( + f"Mapper[{self.class_.__name__}(None)] has None for a " + "primary table argument and does not specify 'inherits'" + ) + + if inherit_condition is not None: + self.inherit_condition = coercions.expect( + roles.OnClauseRole, inherit_condition + ) + else: + self.inherit_condition = None - self.inherit_condition = inherit_condition self.inherit_foreign_keys = inherit_foreign_keys - self._init_properties = properties or {} + self._init_properties = dict(properties) if properties else {} self._delete_orphans = [] self.batch = batch self.eager_defaults = eager_defaults @@ -694,7 +767,10 @@ class Mapper( # while a configure_mappers() is occurring (and defer a # configure_mappers() until construction succeeds) with _CONFIGURE_MUTEX: - self.dispatch._events._new_mapper_instance(class_, self) + + cast("MapperEvents", self.dispatch._events)._new_mapper_instance( + class_, self + ) self._configure_inheritance() self._configure_class_instrumentation() self._configure_properties() @@ -704,16 +780,21 @@ class Mapper( self._log("constructed") self._expire_memoizations() - # major attributes initialized at the classlevel so that - # they can be Sphinx-documented. + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + # ### BEGIN + # ATTRIBUTE DECLARATIONS START HERE is_mapper = True """Part of the inspection API.""" represents_outer_join = False + registry: _RegistryType + @property - def mapper(self): + def mapper(self) -> Mapper[_O]: """Part of the inspection API. Returns self. @@ -721,9 +802,6 @@ class Mapper( """ return self - def _gen_cache_key(self, anon_map, bindparams): - return (self,) - @property def entity(self): r"""Part of the inspection API. @@ -733,49 +811,109 @@ class Mapper( """ return self.class_ - local_table = None - """The :class:`_expression.Selectable` which this :class:`_orm.Mapper` - manages. + class_: Type[_O] + """The class to which this :class:`_orm.Mapper` is mapped.""" + + _identity_class: Type[_O] + + _delete_orphans: List[Tuple[str, Type[Any]]] + _dependency_processors: List[DependencyProcessor] + _memoized_values: Dict[Any, Callable[[], Any]] + _inheriting_mappers: util.WeakSequence[Mapper[Any]] + _all_tables: Set[Table] + + _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] + _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] + + _props: util.OrderedDict[str, MapperProperty[Any]] + _init_properties: Dict[str, MapperProperty[Any]] + + _columntoproperty: _ColumnMapping + + _set_polymorphic_identity: Optional[Callable[[InstanceState[_O]], None]] + _validate_polymorphic_identity: Optional[ + Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None] + ] + + tables: Sequence[Table] + """A sequence containing the collection of :class:`_schema.Table` objects + which this :class:`_orm.Mapper` is aware of. 
+ + If the mapper is mapped to a :class:`_expression.Join`, or an + :class:`_expression.Alias` + representing a :class:`_expression.Select`, the individual + :class:`_schema.Table` + objects that comprise the full construct will be represented here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + validators: util.immutabledict[str, Tuple[str, Dict[str, Any]]] + """An immutable dictionary of attributes which have been decorated + using the :func:`_orm.validates` decorator. + + The dictionary contains string attribute names as keys + mapped to the actual validation method. + + """ + + always_refresh: bool + allow_partial_pks: bool + version_id_col: Optional[ColumnElement[Any]] + + with_polymorphic: Optional[ + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ] + ] + + version_id_generator: Optional[Union[Literal[False], Callable[[Any], Any]]] + + local_table: FromClause + """The immediate :class:`_expression.FromClause` which this + :class:`_orm.Mapper` refers towards. - Typically is an instance of :class:`_schema.Table` or - :class:`_expression.Alias`. - May also be ``None``. + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. The "local" table is the selectable that the :class:`_orm.Mapper` is directly responsible for managing from an attribute access and flush perspective. For - non-inheriting mappers, the local table is the same as the - "mapped" table. For joined-table inheritance mappers, local_table - will be the particular sub-table of the overall "join" which - this :class:`_orm.Mapper` represents. If this mapper is a - single-table inheriting mapper, local_table will be ``None``. + non-inheriting mappers, :attr:`.Mapper.local_table` will be the same + as :attr:`.Mapper.persist_selectable`. For inheriting mappers, + :attr:`.Mapper.local_table` refers to the specific portion of + :attr:`.Mapper.persist_selectable` that includes the columns to which + this :class:`.Mapper` is loading/persisting, such as a particular + :class:`.Table` within a join. .. seealso:: :attr:`_orm.Mapper.persist_selectable`. + :attr:`_orm.Mapper.selectable`. + """ - persist_selectable = None - """The :class:`_expression.Selectable` to which this :class:`_orm.Mapper` + persist_selectable: FromClause + """The :class:`_expression.FromClause` to which this :class:`_orm.Mapper` is mapped. - Typically an instance of :class:`_schema.Table`, - :class:`_expression.Join`, or :class:`_expression.Alias`. - - The :attr:`_orm.Mapper.persist_selectable` is separate from - :attr:`_orm.Mapper.selectable` in that the former represents columns - that are mapped on this class or its superclasses, whereas the - latter may be a "polymorphic" selectable that contains additional columns - which are in fact mapped on subclasses only. + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. - "persist selectable" is the "thing the mapper writes to" and - "selectable" is the "thing the mapper selects from". - - :attr:`_orm.Mapper.persist_selectable` is also separate from - :attr:`_orm.Mapper.local_table`, which represents the set of columns that - are locally mapped on this class directly. + The :attr:`_orm.Mapper.persist_selectable` is similar to + :attr:`.Mapper.local_table`, but represents the :class:`.FromClause` that + represents the inheriting class hierarchy overall in an inheritance + scenario. 
+ :attr.`.Mapper.persist_selectable` is also separate from the + :attr:`.Mapper.selectable` attribute, the latter of which may be an + alternate subquery used for selecting columns. + :attr.`.Mapper.persist_selectable` is oriented towards columns that + will be written on a persist operation. .. seealso:: @@ -785,16 +923,15 @@ class Mapper( """ - inherits = None + inherits: Optional[Mapper[Any]] """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper` inherits from, if any. - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - """ - configured = False + inherit_condition: Optional[ColumnElement[bool]] + + configured: bool = False """Represent ``True`` if this :class:`_orm.Mapper` has been configured. This is a *read only* attribute determined during mapper construction. @@ -806,7 +943,7 @@ class Mapper( """ - concrete = None + concrete: bool """Represent ``True`` if this :class:`_orm.Mapper` is a concrete inheritance mapper. @@ -815,21 +952,6 @@ class Mapper( """ - tables = None - """An iterable containing the collection of :class:`_schema.Table` objects - which this :class:`_orm.Mapper` is aware of. - - If the mapper is mapped to a :class:`_expression.Join`, or an - :class:`_expression.Alias` - representing a :class:`_expression.Select`, the individual - :class:`_schema.Table` - objects that comprise the full construct will be represented here. - - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - - """ - primary_key: Tuple[Column[Any], ...] """An iterable containing the collection of :class:`_schema.Column` objects @@ -854,14 +976,6 @@ class Mapper( """ - class_: Type[_O] - """The Python class which this :class:`_orm.Mapper` maps. - - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - - """ - class_manager: ClassManager[_O] """The :class:`.ClassManager` which maintains event listeners and class-bound descriptors for this :class:`_orm.Mapper`. @@ -871,7 +985,7 @@ class Mapper( """ - single = None + single: bool """Represent ``True`` if this :class:`_orm.Mapper` is a single table inheritance mapper. @@ -882,7 +996,7 @@ class Mapper( """ - non_primary = None + non_primary: bool """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary" mapper, e.g. a mapper that is used only to select rows but not for persistence management. @@ -892,7 +1006,7 @@ class Mapper( """ - polymorphic_on = None + polymorphic_on: Optional[ColumnElement[Any]] """The :class:`_schema.Column` or SQL expression specified as the ``polymorphic_on`` argument for this :class:`_orm.Mapper`, within an inheritance scenario. @@ -906,7 +1020,7 @@ class Mapper( """ - polymorphic_map = None + polymorphic_map: Dict[Any, Mapper[Any]] """A mapping of "polymorphic identity" identifiers mapped to :class:`_orm.Mapper` instances, within an inheritance scenario. @@ -922,7 +1036,7 @@ class Mapper( """ - polymorphic_identity = None + polymorphic_identity: Optional[Any] """Represent an identifier which is matched against the :attr:`_orm.Mapper.polymorphic_on` column during result row loading. @@ -935,7 +1049,7 @@ class Mapper( """ - base_mapper = None + base_mapper: Mapper[Any] """The base-most :class:`_orm.Mapper` in an inheritance chain. 
In a non-inheriting scenario, this attribute will always be this @@ -948,7 +1062,7 @@ class Mapper( """ - columns = None + columns: ReadOnlyColumnCollection[str, Column[Any]] """A collection of :class:`_schema.Column` or other scalar expression objects maintained by this :class:`_orm.Mapper`. @@ -965,25 +1079,16 @@ class Mapper( """ - validators = None - """An immutable dictionary of attributes which have been decorated - using the :func:`_orm.validates` decorator. - - The dictionary contains string attribute names as keys - mapped to the actual validation method. - - """ - - c = None + c: ReadOnlyColumnCollection[str, Column[Any]] """A synonym for :attr:`_orm.Mapper.columns`.""" - @property + @util.non_memoized_property @util.deprecated("1.3", "Use .persist_selectable") def mapped_table(self): return self.persist_selectable @util.memoized_property - def _path_registry(self) -> PathRegistry: + def _path_registry(self) -> CachingEntityRegistry: return PathRegistry.per_mapper(self) def _configure_inheritance(self): @@ -994,8 +1099,6 @@ class Mapper( self._inheriting_mappers = util.WeakSequence() if self.inherits: - if isinstance(self.inherits, type): - self.inherits = class_mapper(self.inherits, configure=False) if not issubclass(self.class_, self.inherits.class_): raise sa_exc.ArgumentError( "Class '%s' does not inherit from '%s'" @@ -1011,11 +1114,9 @@ class Mapper( "only allowed from a %s mapper" % (np, self.class_.__name__, np) ) - # inherit_condition is optional. - if self.local_table is None: - self.local_table = self.inherits.local_table + + if self.single: self.persist_selectable = self.inherits.persist_selectable - self.single = True elif self.local_table is not self.inherits.local_table: if self.concrete: self.persist_selectable = self.local_table @@ -1068,6 +1169,7 @@ class Mapper( self.local_table.description, ) ) from afe + assert self.inherits.persist_selectable is not None self.persist_selectable = sql.join( self.inherits.persist_selectable, self.local_table, @@ -1149,6 +1251,7 @@ class Mapper( else: self._all_tables = set() self.base_mapper = self + assert self.local_table is not None self.persist_selectable = self.local_table if self.polymorphic_identity is not None: self.polymorphic_map[self.polymorphic_identity] = self @@ -1160,21 +1263,34 @@ class Mapper( % self ) - def _set_with_polymorphic(self, with_polymorphic): + def _set_with_polymorphic( + self, with_polymorphic: Optional[_WithPolymorphicArg] + ) -> None: if with_polymorphic == "*": self.with_polymorphic = ("*", None) elif isinstance(with_polymorphic, (tuple, list)): if isinstance(with_polymorphic[0], (str, tuple, list)): - self.with_polymorphic = with_polymorphic + self.with_polymorphic = cast( + """Tuple[ + Union[ + Literal["*"], + Sequence[Union["Mapper[Any]", Type[Any]]], + ], + Optional["FromClause"], + ]""", + with_polymorphic, + ) else: self.with_polymorphic = (with_polymorphic, None) elif with_polymorphic is not None: - raise sa_exc.ArgumentError("Invalid setting for with_polymorphic") + raise sa_exc.ArgumentError( + f"Invalid setting for with_polymorphic: {with_polymorphic!r}" + ) else: self.with_polymorphic = None if self.with_polymorphic and self.with_polymorphic[1] is not None: - self.with_polymorphic = ( + self.with_polymorphic = ( # type: ignore self.with_polymorphic[0], coercions.expect( roles.StrictFromClauseRole, @@ -1191,6 +1307,7 @@ class Mapper( if self.with_polymorphic is None: self._set_with_polymorphic((subcl,)) elif self.with_polymorphic[0] != "*": + assert 
isinstance(self.with_polymorphic[0], tuple) self._set_with_polymorphic( (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1]) ) @@ -1241,7 +1358,7 @@ class Mapper( # we expect that declarative has applied the class manager # already and set up a registry. if this is None, # this raises as of 2.0. - manager = attributes.manager_of_class(self.class_) + manager = attributes.opt_manager_of_class(self.class_) if self.non_primary: if not manager or not manager.is_mapped: @@ -1251,6 +1368,8 @@ class Mapper( "Mapper." % self.class_ ) self.class_manager = manager + + assert manager.registry is not None self.registry = manager.registry self._identity_class = manager.mapper._identity_class manager.registry._add_non_primary_mapper(self) @@ -1275,7 +1394,7 @@ class Mapper( manager = instrumentation.register_class( self.class_, mapper=self, - expired_attribute_loader=util.partial( + expired_attribute_loader=util.partial( # type: ignore loading.load_scalar_attributes, self ), # finalize flag means instrument the __init__ method @@ -1284,6 +1403,8 @@ class Mapper( ) self.class_manager = manager + + assert manager.registry is not None self.registry = manager.registry # The remaining members can be added by any mapper, @@ -1315,15 +1436,25 @@ class Mapper( {name: (method, validation_opts)} ) - def _set_dispose_flags(self): + def _set_dispose_flags(self) -> None: self.configured = True self._ready_for_configure = True self._dispose_called = True self.__dict__.pop("_configure_failed", None) - def _configure_pks(self): - self.tables = sql_util.find_tables(self.persist_selectable) + def _configure_pks(self) -> None: + self.tables = cast( + "List[Table]", sql_util.find_tables(self.persist_selectable) + ) + for t in self.tables: + if not isinstance(t, Table): + raise sa_exc.ArgumentError( + f"ORM mappings can only be made against schema-level " + f"Table objects, not TableClause; got " + f"tableclause {t.name !r}" + ) + self._all_tables.update(t for t in self.tables if isinstance(t, Table)) self._pks_by_table = {} self._cols_by_table = {} @@ -1335,16 +1466,16 @@ class Mapper( pk_cols = util.column_set(c for c in all_cols if c.primary_key) # identify primary key columns which are also mapped by this mapper. 
- tables = set(self.tables + [self.persist_selectable]) - self._all_tables.update(tables) - for t in tables: - if t.primary_key and pk_cols.issuperset(t.primary_key): + for fc in set(self.tables).union([self.persist_selectable]): + if fc.primary_key and pk_cols.issuperset(fc.primary_key): # ordering is important since it determines the ordering of # mapper.primary_key (and therefore query.get()) - self._pks_by_table[t] = util.ordered_column_set( - t.primary_key - ).intersection(pk_cols) - self._cols_by_table[t] = util.ordered_column_set(t.c).intersection( + self._pks_by_table[fc] = util.ordered_column_set( # type: ignore # noqa: E501 + fc.primary_key + ).intersection( + pk_cols + ) + self._cols_by_table[fc] = util.ordered_column_set(fc.c).intersection( # type: ignore # noqa: E501 all_cols ) @@ -1386,10 +1517,15 @@ class Mapper( self.primary_key = self.inherits.primary_key else: # determine primary key from argument or persist_selectable pks + primary_key: Collection[ColumnElement[Any]] + if self._primary_key_argument: primary_key = [ - self.persist_selectable.corresponding_column(c) - for c in self._primary_key_argument + cc if cc is not None else c + for cc, c in ( + (self.persist_selectable.corresponding_column(c), c) + for c in self._primary_key_argument + ) ] else: # if heuristically determined PKs, reduce to the minimal set @@ -1413,7 +1549,7 @@ class Mapper( # determine cols that aren't expressed within our tables; mark these # as "read only" properties which are refreshed upon INSERT/UPDATE - self._readonly_props = set( + self._readonly_props = { self._columntoproperty[col] for col in self._columntoproperty if self._columntoproperty[col] not in self._identity_key_props @@ -1421,12 +1557,12 @@ class Mapper( not hasattr(col, "table") or col.table not in self._cols_by_table ) - ) + } - def _configure_properties(self): + def _configure_properties(self) -> None: # TODO: consider using DedupeColumnCollection - self.columns = self.c = sql_base.ColumnCollection() + self.columns = self.c = sql_base.ColumnCollection() # type: ignore # object attribute names mapped to MapperProperty objects self._props = util.OrderedDict() @@ -1454,7 +1590,6 @@ class Mapper( continue column_key = (self.column_prefix or "") + column.key - if self._should_exclude( column.key, column_key, @@ -1542,6 +1677,7 @@ class Mapper( col = self.polymorphic_on if isinstance(col, schema.Column) and ( self.with_polymorphic is None + or self.with_polymorphic[1] is None or self.with_polymorphic[1].corresponding_column(col) is None ): @@ -1763,8 +1899,8 @@ class Mapper( self.columns.add(col, key) for col in prop.columns + prop._orig_columns: - for col in col.proxy_set: - self._columntoproperty[col] = prop + for proxy_col in col.proxy_set: + self._columntoproperty[proxy_col] = prop prop.key = key @@ -2033,7 +2169,9 @@ class Mapper( self._check_configure() return iter(self._props.values()) - def _mappers_from_spec(self, spec, selectable): + def _mappers_from_spec( + self, spec: Any, selectable: Optional[FromClause] + ) -> Sequence[Mapper[Any]]: """given a with_polymorphic() argument, return the set of mappers it represents. 
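A hypothetical sketch of the with_polymorphic() "spec" shapes that _mappers_from_spec() above interprets, reusing the illustrative Employee/Manager hierarchy from the earlier sketch.

    from sqlalchemy import select
    from sqlalchemy.orm import with_polymorphic

    # spec="*": every mapper in the Employee hierarchy
    emp_all = with_polymorphic(Employee, "*")

    # spec as an explicit sequence of classes (or mappers)
    emp_mgr = with_polymorphic(Employee, [Manager])

    # the sub-entity namespace on the returned construct gives access to
    # subclass attributes
    stmt = select(emp_mgr).where(emp_mgr.Manager.id > 10)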
@@ -2044,7 +2182,7 @@ class Mapper( if spec == "*": mappers = list(self.self_and_descendants) elif spec: - mappers = set() + mapper_set = set() for m in util.to_list(spec): m = _class_to_mapper(m) if not m.isa(self): @@ -2053,10 +2191,10 @@ class Mapper( ) if selectable is None: - mappers.update(m.iterate_to_root()) + mapper_set.update(m.iterate_to_root()) else: - mappers.add(m) - mappers = [m for m in self.self_and_descendants if m in mappers] + mapper_set.add(m) + mappers = [m for m in self.self_and_descendants if m in mapper_set] else: mappers = [] @@ -2067,7 +2205,9 @@ class Mapper( mappers = [m for m in mappers if m.local_table in tables] return mappers - def _selectable_from_mappers(self, mappers, innerjoin): + def _selectable_from_mappers( + self, mappers: Iterable[Mapper[Any]], innerjoin: bool + ) -> FromClause: """given a list of mappers (assumed to be within this mapper's inheritance hierarchy), construct an outerjoin amongst those mapper's mapped tables. @@ -2098,13 +2238,13 @@ class Mapper( def _single_table_criterion(self): if self.single and self.inherits and self.polymorphic_on is not None: return self.polymorphic_on._annotate({"parentmapper": self}).in_( - m.polymorphic_identity for m in self.self_and_descendants + [m.polymorphic_identity for m in self.self_and_descendants] ) else: return None @HasMemoized.memoized_attribute - def _with_polymorphic_mappers(self): + def _with_polymorphic_mappers(self) -> Sequence[Mapper[Any]]: self._check_configure() if not self.with_polymorphic: @@ -2124,8 +2264,8 @@ class Mapper( """ self._check_configure() - @HasMemoized.memoized_attribute - def _with_polymorphic_selectable(self): + @HasMemoized_ro_memoized_attribute + def _with_polymorphic_selectable(self) -> FromClause: if not self.with_polymorphic: return self.persist_selectable @@ -2143,7 +2283,7 @@ class Mapper( """ - @HasMemoized.memoized_attribute + @HasMemoized_ro_memoized_attribute def _insert_cols_evaluating_none(self): return dict( ( @@ -2250,7 +2390,7 @@ class Mapper( @HasMemoized.memoized_instancemethod def __clause_element__(self): - annotations = { + annotations: Dict[str, Any] = { "entity_namespace": self, "parententity": self, "parentmapper": self, @@ -2290,7 +2430,7 @@ class Mapper( ) @property - def selectable(self): + def selectable(self) -> FromClause: """The :class:`_schema.FromClause` construct this :class:`_orm.Mapper` selects from by default. @@ -2302,8 +2442,11 @@ class Mapper( return self._with_polymorphic_selectable def _with_polymorphic_args( - self, spec=None, selectable=False, innerjoin=False - ): + self, + spec: Any = None, + selectable: Union[Literal[False, None], FromClause] = False, + innerjoin: bool = False, + ) -> Tuple[Sequence[Mapper[Any]], FromClause]: if selectable not in (None, False): selectable = coercions.expect( roles.StrictFromClauseRole, selectable, allow_select=True @@ -2357,7 +2500,7 @@ class Mapper( ] @HasMemoized.memoized_attribute - def _polymorphic_adapter(self): + def _polymorphic_adapter(self) -> Optional[sql_util.ColumnAdapter]: if self.with_polymorphic: return sql_util.ColumnAdapter( self.selectable, equivalents=self._equivalent_columns @@ -2394,7 +2537,7 @@ class Mapper( yield c @HasMemoized.memoized_attribute - def attrs(self) -> util.ReadOnlyProperties["MapperProperty"]: + def attrs(self) -> util.ReadOnlyProperties[MapperProperty[Any]]: """A namespace of all :class:`.MapperProperty` objects associated this mapper. 
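Reusing the hypothetical Employee/Manager joined-table hierarchy, a rough sketch of how the Mapper.local_table, Mapper.persist_selectable and Mapper.selectable attributes re-documented in this patch would be expected to differ for an inheriting mapper (illustrative expectations only).

    from sqlalchemy import inspect

    mgr_mapper = inspect(Manager)

    print(mgr_mapper.local_table)         # the "manager" Table only
    print(mgr_mapper.persist_selectable)  # the employee JOIN manager construct
    print(mgr_mapper.selectable)          # same as persist_selectable here,
                                          # since no with_polymorphic selectable
                                          # is configured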
@@ -2432,7 +2575,7 @@ class Mapper( return util.ReadOnlyProperties(self._props) @HasMemoized.memoized_attribute - def all_orm_descriptors(self): + def all_orm_descriptors(self) -> util.ReadOnlyProperties[InspectionAttr]: """A namespace of all :class:`.InspectionAttr` attributes associated with the mapped class. @@ -2503,7 +2646,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") - def synonyms(self): + def synonyms(self) -> util.ReadOnlyProperties[Synonym[Any]]: """Return a namespace of all :class:`.Synonym` properties maintained by this :class:`_orm.Mapper`. @@ -2523,7 +2666,7 @@ class Mapper( return self.class_ @HasMemoized.memoized_attribute - def column_attrs(self): + def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]: """Return a namespace of all :class:`.ColumnProperty` properties maintained by this :class:`_orm.Mapper`. @@ -2536,9 +2679,9 @@ class Mapper( """ return self._filter_properties(properties.ColumnProperty) - @util.preload_module("sqlalchemy.orm.relationships") @HasMemoized.memoized_attribute - def relationships(self): + @util.preload_module("sqlalchemy.orm.relationships") + def relationships(self) -> util.ReadOnlyProperties[Relationship[Any]]: """A namespace of all :class:`.Relationship` properties maintained by this :class:`_orm.Mapper`. @@ -2567,7 +2710,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") - def composites(self): + def composites(self) -> util.ReadOnlyProperties[Composite[Any]]: """Return a namespace of all :class:`.Composite` properties maintained by this :class:`_orm.Mapper`. @@ -2582,7 +2725,9 @@ class Mapper( util.preloaded.orm_descriptor_props.Composite ) - def _filter_properties(self, type_): + def _filter_properties( + self, type_: Type[_MP] + ) -> util.ReadOnlyProperties[_MP]: self._check_configure() return util.ReadOnlyProperties( util.OrderedDict( @@ -2610,7 +2755,7 @@ class Mapper( ) @HasMemoized.memoized_attribute - def _equivalent_columns(self): + def _equivalent_columns(self) -> _EquivalentColumnMap: """Create a map of all equivalent columns, based on the determination of column pairs that are equated to one another based on inherit condition. This is designed @@ -2630,18 +2775,18 @@ class Mapper( } """ - result = util.column_dict() + result: _EquivalentColumnMap = {} def visit_binary(binary): if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: - result[binary.left] = util.column_set((binary.right,)) + result[binary.left] = {binary.right} if binary.right in result: result[binary.right].add(binary.left) else: - result[binary.right] = util.column_set((binary.left,)) + result[binary.right] = {binary.left} for mapper in self.base_mapper.self_and_descendants: if mapper.inherit_condition is not None: @@ -2711,13 +2856,13 @@ class Mapper( return False - def common_parent(self, other): + def common_parent(self, other: Mapper[Any]) -> bool: """Return true if the given mapper shares a common inherited parent as this mapper.""" return self.base_mapper is other.base_mapper - def is_sibling(self, other): + def is_sibling(self, other: Mapper[Any]) -> bool: """return true if the other mapper is an inheriting sibling to this one. 
common parent but different branch @@ -2728,7 +2873,9 @@ class Mapper( and not other.isa(self) ) - def _canload(self, state, allow_subtypes): + def _canload( + self, state: InstanceState[Any], allow_subtypes: bool + ) -> bool: s = self.primary_mapper() if self.polymorphic_on is not None or allow_subtypes: return _state_mapper(state).isa(s) @@ -2738,19 +2885,19 @@ class Mapper( def isa(self, other: Mapper[Any]) -> bool: """Return True if the this mapper inherits from the given mapper.""" - m = self + m: Optional[Mapper[Any]] = self while m and m is not other: m = m.inherits return bool(m) - def iterate_to_root(self): - m = self + def iterate_to_root(self) -> Iterator[Mapper[Any]]: + m: Optional[Mapper[Any]] = self while m: yield m m = m.inherits @HasMemoized.memoized_attribute - def self_and_descendants(self): + def self_and_descendants(self) -> Sequence[Mapper[Any]]: """The collection including this mapper and all descendant mappers. This includes not just the immediately inheriting mappers but @@ -2765,7 +2912,7 @@ class Mapper( stack.extend(item._inheriting_mappers) return util.WeakSequence(descendants) - def polymorphic_iterator(self): + def polymorphic_iterator(self) -> Iterator[Mapper[Any]]: """Iterate through the collection including this mapper and all descendant mappers. @@ -2778,18 +2925,18 @@ class Mapper( """ return iter(self.self_and_descendants) - def primary_mapper(self): + def primary_mapper(self) -> Mapper[Any]: """Return the primary mapper corresponding to this mapper's class key (class).""" return self.class_manager.mapper @property - def primary_base_mapper(self): + def primary_base_mapper(self) -> Mapper[Any]: return self.class_manager.mapper.base_mapper def _result_has_identity_key(self, result, adapter=None): - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key if adapter: pk_cols = [adapter.columns[c] for c in pk_cols] rk = result.keys() @@ -2799,25 +2946,35 @@ class Mapper( else: return True - def identity_key_from_row(self, row, identity_token=None, adapter=None): + def identity_key_from_row( + self, + row: Optional[Union[Row, RowMapping]], + identity_token: Optional[Any] = None, + adapter: Optional[ColumnAdapter] = None, + ) -> _IdentityKeyType[_O]: """Return an identity-map key for use in storing/retrieving an item from the identity map. - :param row: A :class:`.Row` instance. The columns which are - mapped by this :class:`_orm.Mapper` should be locatable in the row, - preferably via the :class:`_schema.Column` - object directly (as is the case - when a :func:`_expression.select` construct is executed), or - via string names of the form ``_``. + :param row: A :class:`.Row` or :class:`.RowMapping` produced from a + result set that selected from the ORM mapped primary key columns. + + .. 
versionchanged:: 2.0 + :class:`.Row` or :class:`.RowMapping` are accepted + for the "row" argument """ - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key if adapter: pk_cols = [adapter.columns[c] for c in pk_cols] + if hasattr(row, "_mapping"): + mapping = row._mapping # type: ignore + else: + mapping = cast("Mapping[Any, Any]", row) + return ( self._identity_class, - tuple(row[column] for column in pk_cols), + tuple(mapping[column] for column in pk_cols), # type: ignore identity_token, ) @@ -2852,12 +3009,12 @@ class Mapper( """ state = attributes.instance_state(instance) - return self._identity_key_from_state(state, attributes.PASSIVE_OFF) + return self._identity_key_from_state(state, PassiveFlag.PASSIVE_OFF) def _identity_key_from_state( self, state: InstanceState[_O], - passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE, + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, ) -> _IdentityKeyType[_O]: dict_ = state.dict manager = state.manager @@ -2884,7 +3041,7 @@ class Mapper( """ state = attributes.instance_state(instance) identity_key = self._identity_key_from_state( - state, attributes.PASSIVE_OFF + state, PassiveFlag.PASSIVE_OFF ) return identity_key[1] @@ -2913,14 +3070,14 @@ class Mapper( @HasMemoized.memoized_attribute def _all_pk_cols(self): - collection = set() + collection: Set[ColumnClause[Any]] = set() for table in self.tables: collection.update(self._pks_by_table[table]) return collection @HasMemoized.memoized_attribute def _should_undefer_in_wildcard(self): - cols = set(self.primary_key) + cols: Set[ColumnElement[Any]] = set(self.primary_key) if self.polymorphic_on is not None: cols.add(self.polymorphic_on) return cols @@ -2951,11 +3108,11 @@ class Mapper( state = attributes.instance_state(obj) dict_ = attributes.instance_dict(obj) return self._get_committed_state_attr_by_column( - state, dict_, column, passive=attributes.PASSIVE_OFF + state, dict_, column, passive=PassiveFlag.PASSIVE_OFF ) def _get_committed_state_attr_by_column( - self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE + self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE ): prop = self._columntoproperty[column] @@ -2978,7 +3135,7 @@ class Mapper( col_attribute_names = set(attribute_names).intersection( state.mapper.column_attrs.keys() ) - tables = set( + tables: Set[FromClause] = set( chain( *[ sql_util.find_tables(c, check_columns=True) @@ -3002,7 +3159,7 @@ class Mapper( state, state.dict, leftcol, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if leftval in orm_util._none_set: raise _OptGetColumnsNotAvailable() @@ -3014,7 +3171,7 @@ class Mapper( state, state.dict, rightcol, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if rightval in orm_util._none_set: raise _OptGetColumnsNotAvailable() @@ -3022,7 +3179,7 @@ class Mapper( None, rightval, type_=binary.right.type ) - allconds = [] + allconds: List[ColumnElement[bool]] = [] start = False @@ -3035,6 +3192,9 @@ class Mapper( elif not isinstance(mapper.local_table, expression.TableClause): return None if start and not mapper.single: + assert mapper.inherits + assert not mapper.concrete + assert mapper.inherit_condition is not None allconds.append(mapper.inherit_condition) tables.add(mapper.local_table) @@ -3043,11 +3203,13 @@ class Mapper( # descendant-most class should all be present and joined to each # other. 
try: - allconds[0] = visitors.cloned_traverse( + _traversed = visitors.cloned_traverse( allconds[0], {}, {"binary": visit_binary} ) except _OptGetColumnsNotAvailable: return None + else: + allconds[0] = _traversed cond = sql.and_(*allconds) @@ -3145,6 +3307,8 @@ class Mapper( for pk in self.primary_key ] + in_expr: ColumnElement[Any] + if len(primary_key) > 1: in_expr = sql.tuple_(*primary_key) else: @@ -3209,11 +3373,22 @@ class Mapper( traverse all objects without relying on cascades. """ - visited_states = set() + visited_states: Set[InstanceState[Any]] = set() prp, mpp = object(), object() assert state.mapper.isa(self) + # this is actually a recursive structure, fully typing it seems + # a little too difficult for what it's worth here + visitables: Deque[ + Tuple[ + Deque[Any], + object, + Optional[InstanceState[Any]], + Optional[_InstanceDict], + ] + ] + visitables = deque( [(deque(state.mapper._props.values()), prp, state, state.dict)] ) @@ -3226,8 +3401,10 @@ class Mapper( if item_type is prp: prop = iterator.popleft() - if type_ not in prop.cascade: + if not prop.cascade or type_ not in prop.cascade: continue + assert parent_state is not None + assert parent_dict is not None queue = deque( prop.cascade_iterator( type_, @@ -3267,7 +3444,7 @@ class Mapper( @HasMemoized.memoized_attribute def _sorted_tables(self): - table_to_mapper = {} + table_to_mapper: Dict[Table, Mapper[Any]] = {} for mapper in self.base_mapper.self_and_descendants: for t in mapper.tables: @@ -3316,9 +3493,9 @@ class Mapper( ret[t] = table_to_mapper[t] return ret - def _memo(self, key, callable_): + def _memo(self, key: Any, callable_: Callable[[], _T]) -> _T: if key in self._memoized_values: - return self._memoized_values[key] + return cast(_T, self._memoized_values[key]) else: self._memoized_values[key] = value = callable_() return value @@ -3328,14 +3505,22 @@ class Mapper( """memoized map of tables to collections of columns to be synchronized upwards to the base mapper.""" - result = util.defaultdict(list) + result: util.defaultdict[ + Table, + List[ + Tuple[ + Mapper[Any], + List[Tuple[ColumnElement[Any], ColumnElement[Any]]], + ] + ], + ] = util.defaultdict(list) for table in self._sorted_tables: cols = set(table.c) for m in self.iterate_to_root(): if m._inherits_equated_pairs and cols.intersection( reduce( - set.union, + set.union, # type: ignore [l.proxy_set for l, r in m._inherits_equated_pairs], ) ): @@ -3440,7 +3625,7 @@ def _configure_registries(registries, cascade): else: return - Mapper.dispatch._for_class(Mapper).before_configured() + Mapper.dispatch._for_class(Mapper).before_configured() # type: ignore # noqa: E501 # initialize properties on all mappers # note that _mapper_registry is unordered, which # may randomly conceal/reveal issues related to @@ -3449,7 +3634,7 @@ def _configure_registries(registries, cascade): _do_configure_registries(registries, cascade) finally: _already_compiling = False - Mapper.dispatch._for_class(Mapper).after_configured() + Mapper.dispatch._for_class(Mapper).after_configured() # type: ignore @util.preload_module("sqlalchemy.orm.decl_api") @@ -3480,7 +3665,7 @@ def _do_configure_registries(registries, cascade): "Original exception was: %s" % (mapper, mapper._configure_failed) ) - e._configure_failed = mapper._configure_failed + e._configure_failed = mapper._configure_failed # type: ignore raise e if not mapper.configured: @@ -3636,7 +3821,7 @@ def _event_on_init(state, args, kwargs): instrumenting_mapper._set_polymorphic_identity(state) -class _ColumnMapping(dict): 
+class _ColumnMapping(Dict["ColumnElement[Any]", "MapperProperty[Any]"]): """Error reporting helper for mapper._columntoproperty.""" __slots__ = ("mapper",) diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index e2cf1d5b04..361cea9757 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -13,22 +13,70 @@ from __future__ import annotations from functools import reduce from itertools import chain import logging +import operator from typing import Any +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple +from typing import TYPE_CHECKING from typing import Union from . import base as orm_base +from ._typing import insp_is_mapper_property from .. import exc -from .. import inspection from .. import util from ..sql import visitors from ..sql.cache_key import HasCacheKey +if TYPE_CHECKING: + from ._typing import _InternalEntityType + from .interfaces import MapperProperty + from .mapper import Mapper + from .relationships import Relationship + from .util import AliasedInsp + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.elements import BindParameter + from ..sql.visitors import anon_map + from ..util.typing import TypeGuard + + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: + ... + + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: + ... + +else: + is_root = operator.attrgetter("is_root") + is_entity = operator.attrgetter("is_entity") + + +_SerializedPath = List[Any] + +_PathElementType = Union[ + str, "_InternalEntityType[Any]", "MapperProperty[Any]" +] + +# the representation is in fact +# a tuple with alternating: +# [_InternalEntityType[Any], Union[str, MapperProperty[Any]], +# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...] +# this might someday be a tuple of 2-tuples instead, but paths can be +# chopped at odd intervals as well so this is less flexible +_PathRepresentation = Tuple[_PathElementType, ...] 
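As a rough illustration of the alternating entity/property layout described in the comment above, a minimal sketch follows; the ``User``/``Address`` mapping below is hypothetical and not part of this patch, and string tokens may also appear in the odd positions::

    from sqlalchemy import Column, ForeignKey, Integer, inspect
    from sqlalchemy.orm import declarative_base, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        addresses = relationship("Address")

    class Address(Base):
        __tablename__ = "address"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("user.id"))

    user_mapper = inspect(User)
    address_mapper = inspect(Address)

    # a _PathRepresentation-style tuple alternates entity and property
    # elements; since paths may be sliced at odd boundaries, the tuple
    # can end on either kind of element
    path = (
        user_mapper,                      # _InternalEntityType[Any]
        user_mapper.attrs["addresses"],   # MapperProperty[Any]
        address_mapper,                   # _InternalEntityType[Any]
    )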
+ +_OddPathRepresentation = Sequence["_InternalEntityType[Any]"] +_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]] + + log = logging.getLogger(__name__) -def _unreduce_path(path): +def _unreduce_path(path: _SerializedPath) -> PathRegistry: return PathRegistry.deserialize(path) @@ -67,17 +115,18 @@ class PathRegistry(HasCacheKey): is_token = False is_root = False has_entity = False + is_entity = False - path: Tuple - natural_path: Tuple - parent: Union["PathRegistry", None] + path: _PathRepresentation + natural_path: _PathRepresentation + parent: Optional[PathRegistry] + root: RootRegistry - root: "PathRegistry" - _cache_key_traversal = [ + _cache_key_traversal: _CacheKeyTraversalType = [ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) ] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: try: return other is not None and self.path == other._path_for_compare except AttributeError: @@ -87,7 +136,7 @@ class PathRegistry(HasCacheKey): ) return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: try: return other is None or self.path != other._path_for_compare except AttributeError: @@ -98,74 +147,88 @@ class PathRegistry(HasCacheKey): return True @property - def _path_for_compare(self): + def _path_for_compare(self) -> Optional[_PathRepresentation]: return self.path - def set(self, attributes, key, value): + def set(self, attributes: Dict[Any, Any], key: Any, value: Any) -> None: log.debug("set '%s' on path '%s' to '%s'", key, self, value) attributes[(key, self.natural_path)] = value - def setdefault(self, attributes, key, value): + def setdefault( + self, attributes: Dict[Any, Any], key: Any, value: Any + ) -> None: log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value) attributes.setdefault((key, self.natural_path), value) - def get(self, attributes, key, value=None): + def get( + self, attributes: Dict[Any, Any], key: Any, value: Optional[Any] = None + ) -> Any: key = (key, self.natural_path) if key in attributes: return attributes[key] else: return value - def __len__(self): + def __len__(self) -> int: return len(self.path) - def __hash__(self): + def __hash__(self) -> int: return id(self) - def __getitem__(self, key: Any) -> "PathRegistry": + def __getitem__(self, key: Any) -> PathRegistry: raise NotImplementedError() + # TODO: what are we using this for? 
@property - def length(self): + def length(self) -> int: return len(self.path) - def pairs(self): - path = self.path - for i in range(0, len(path), 2): - yield path[i], path[i + 1] - - def contains_mapper(self, mapper): - for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]: + def pairs( + self, + ) -> Iterator[ + Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]] + ]: + odd_path = cast(_OddPathRepresentation, self.path) + even_path = cast(_EvenPathRepresentation, odd_path) + for i in range(0, len(odd_path), 2): + yield odd_path[i], even_path[i + 1] + + def contains_mapper(self, mapper: Mapper[Any]) -> bool: + _m_path = cast(_OddPathRepresentation, self.path) + for path_mapper in [_m_path[i] for i in range(0, len(_m_path), 2)]: if path_mapper.is_mapper and path_mapper.isa(mapper): return True else: return False - def contains(self, attributes, key): + def contains(self, attributes: Dict[Any, Any], key: Any) -> bool: return (key, self.path) in attributes - def __reduce__(self): + def __reduce__(self) -> Any: return _unreduce_path, (self.serialize(),) @classmethod - def _serialize_path(cls, path): + def _serialize_path(cls, path: _PathRepresentation) -> _SerializedPath: + _m_path = cast(_OddPathRepresentation, path) + _p_path = cast(_EvenPathRepresentation, path) + return list( zip( - [ + tuple( m.class_ if (m.is_mapper or m.is_aliased_class) else str(m) - for m in [path[i] for i in range(0, len(path), 2)] - ], - [ - path[i].key if (path[i].is_property) else str(path[i]) - for i in range(1, len(path), 2) - ] - + [None], + for m in [_m_path[i] for i in range(0, len(_m_path), 2)] + ), + tuple( + p.key if insp_is_mapper_property(p) else str(p) + for p in [_p_path[i] for i in range(1, len(_p_path), 2)] + ) + + (None,), ) ) @classmethod - def _deserialize_path(cls, path): - def _deserialize_mapper_token(mcls): + def _deserialize_path(cls, path: _SerializedPath) -> _PathRepresentation: + def _deserialize_mapper_token(mcls: Any) -> Any: return ( # note: we likely dont want configure=True here however # this is maintained at the moment for backwards compatibility @@ -174,15 +237,15 @@ class PathRegistry(HasCacheKey): else PathToken._intern[mcls] ) - def _deserialize_key_token(mcls, key): + def _deserialize_key_token(mcls: Any, key: Any) -> Any: if key is None: return None elif key in PathToken._intern: return PathToken._intern[key] else: - return orm_base._inspect_mapped_class( - mcls, configure=True - ).attrs[key] + mp = orm_base._inspect_mapped_class(mcls, configure=True) + assert mp is not None + return mp.attrs[key] p = tuple( chain( @@ -199,28 +262,63 @@ class PathRegistry(HasCacheKey): p = p[0:-1] return p - def serialize(self) -> Sequence[Any]: + def serialize(self) -> _SerializedPath: path = self.path return self._serialize_path(path) @classmethod - def deserialize(cls, path: Sequence[Any]) -> PathRegistry: + def deserialize(cls, path: _SerializedPath) -> PathRegistry: assert path is not None p = cls._deserialize_path(path) return cls.coerce(p) + @overload @classmethod - def per_mapper(cls, mapper): + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: + ... + + @overload + @classmethod + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: + ... 
+ + @classmethod + def per_mapper( + cls, mapper: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: if mapper.is_mapper: return CachingEntityRegistry(cls.root, mapper) else: return SlotsEntityRegistry(cls.root, mapper) @classmethod - def coerce(cls, raw): - return reduce(lambda prev, next: prev[next], raw, cls.root) + def coerce(cls, raw: _PathRepresentation) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] + + # can't quite get mypy to appreciate this one :) + return reduce(_red, raw, cls.root) # type: ignore + + def __add__(self, other: PathRegistry) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] - def token(self, token): + return reduce(_red, other.path, self) + + def __str__(self) -> str: + return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.path!r})" + + +class CreatesToken(PathRegistry): + __slots__ = () + + is_aliased_class: bool + is_root: bool + + def token(self, token: str) -> TokenRegistry: if token.endswith(f":{_WILDCARD_TOKEN}"): return TokenRegistry(self, token) elif token.endswith(f":{_DEFAULT_TOKEN}"): @@ -228,34 +326,47 @@ class PathRegistry(HasCacheKey): else: raise exc.ArgumentError(f"invalid token: {token}") - def __add__(self, other): - return reduce(lambda prev, next: prev[next], other.path, self) - - def __str__(self): - return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" - - def __repr__(self): - return f"{self.__class__.__name__}({self.path!r})" - -class RootRegistry(PathRegistry): +class RootRegistry(CreatesToken): """Root registry, defers to mappers so that paths are maintained per-root-mapper. """ + __slots__ = () + inherit_cache = True path = natural_path = () has_entity = False is_aliased_class = False is_root = True + is_unnatural = False + + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... 
- def __getitem__(self, entity): + def __getitem__( + self, entity: Union[str, _InternalEntityType[Any]] + ) -> Union[TokenRegistry, AbstractEntityRegistry]: if entity in PathToken._intern: + if TYPE_CHECKING: + assert isinstance(entity, str) return TokenRegistry(self, PathToken._intern[entity]) else: - return inspection.inspect(entity)._path_registry + try: + return entity._path_registry # type: ignore + except AttributeError: + raise IndexError( + f"invalid argument for RootRegistry.__getitem__: {entity}" + ) PathRegistry.root = RootRegistry() @@ -264,17 +375,19 @@ PathRegistry.root = RootRegistry() class PathToken(orm_base.InspectionAttr, HasCacheKey, str): """cacheable string token""" - _intern = {} + _intern: Dict[str, PathToken] = {} - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: return (str(self),) @property - def _path_for_compare(self): + def _path_for_compare(self) -> Optional[_PathRepresentation]: return None @classmethod - def intern(cls, strvalue): + def intern(cls, strvalue: str) -> PathToken: if strvalue in cls._intern: return cls._intern[strvalue] else: @@ -287,7 +400,10 @@ class TokenRegistry(PathRegistry): inherit_cache = True - def __init__(self, parent, token): + token: str + parent: CreatesToken + + def __init__(self, parent: CreatesToken, token: str): token = PathToken.intern(token) self.token = token @@ -299,21 +415,33 @@ class TokenRegistry(PathRegistry): is_token = True - def generate_for_superclasses(self): - if not self.parent.is_aliased_class and not self.parent.is_root: - for ent in self.parent.mapper.iterate_to_root(): - yield TokenRegistry(self.parent.parent[ent], self.token) + def generate_for_superclasses(self) -> Iterator[PathRegistry]: + parent = self.parent + if is_root(parent): + yield self + return + + if TYPE_CHECKING: + assert isinstance(parent, AbstractEntityRegistry) + if not parent.is_aliased_class: + for mp_ent in parent.mapper.iterate_to_root(): + yield TokenRegistry(parent.parent[mp_ent], self.token) elif ( - self.parent.is_aliased_class - and self.parent.entity._is_with_polymorphic + parent.is_aliased_class + and cast( + "AliasedInsp[Any]", + parent.entity, + )._is_with_polymorphic ): yield self - for ent in self.parent.entity._with_polymorphic_entities: - yield TokenRegistry(self.parent.parent[ent], self.token) + for ent in cast( + "AliasedInsp[Any]", parent.entity + )._with_polymorphic_entities: + yield TokenRegistry(parent.parent[ent], self.token) else: yield self - def __getitem__(self, entity): + def __getitem__(self, entity: Any) -> Any: try: return self.path[entity] except TypeError as err: @@ -321,23 +449,42 @@ class TokenRegistry(PathRegistry): class PropRegistry(PathRegistry): - is_unnatural = False + __slots__ = ( + "prop", + "parent", + "path", + "natural_path", + "has_entity", + "entity", + "mapper", + "_wildcard_path_loader_key", + "_default_path_loader_key", + "_loader_key", + "is_unnatural", + ) inherit_cache = True - def __init__(self, parent, prop): + prop: MapperProperty[Any] + mapper: Optional[Mapper[Any]] + entity: Optional[_InternalEntityType[Any]] + + def __init__( + self, parent: AbstractEntityRegistry, prop: MapperProperty[Any] + ): # restate this path in terms of the # given MapperProperty's parent. 
- insp = inspection.inspect(parent[-1]) - natural_parent = parent + insp = cast("_InternalEntityType[Any]", parent[-1]) + natural_parent: AbstractEntityRegistry = parent + self.is_unnatural = False - if not insp.is_aliased_class or insp._use_mapper_path: + if not insp.is_aliased_class or insp._use_mapper_path: # type: ignore parent = natural_parent = parent.parent[prop.parent] elif ( insp.is_aliased_class and insp.with_polymorphic_mappers and prop.parent in insp.with_polymorphic_mappers ): - subclass_entity = parent[-1]._entity_for_mapper(prop.parent) + subclass_entity: _InternalEntityType[Any] = parent[-1]._entity_for_mapper(prop.parent) # type: ignore # noqa: E501 parent = parent.parent[subclass_entity] # when building a path where with_polymorphic() is in use, @@ -388,43 +535,74 @@ class PropRegistry(PathRegistry): self.parent = parent self.path = parent.path + (prop,) self.natural_path = natural_parent.natural_path + (prop,) + self.has_entity = prop._links_to_entity + if prop._is_relationship: + if TYPE_CHECKING: + assert isinstance(prop, Relationship) + self.entity = prop.entity + self.mapper = prop.mapper + else: + self.entity = None + self.mapper = None self._wildcard_path_loader_key = ( "loader", - parent.path + self.prop._wildcard_token, + parent.path + self.prop._wildcard_token, # type: ignore ) self._default_path_loader_key = self.prop._default_path_loader_key self._loader_key = ("loader", self.natural_path) - @util.memoized_property - def has_entity(self): - return self.prop._links_to_entity + @property + def entity_path(self) -> AbstractEntityRegistry: + assert self.entity is not None + return self[self.entity] - @util.memoized_property - def entity(self): - return self.prop.entity + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... - @property - def mapper(self): - return self.prop.mapper + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... - @property - def entity_path(self): - return self[self.entity] + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... - def __getitem__(self, entity): + def __getitem__( + self, entity: Union[int, slice, _InternalEntityType[Any]] + ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]: if isinstance(entity, (int, slice)): return self.path[entity] else: return SlotsEntityRegistry(self, entity) -class AbstractEntityRegistry(PathRegistry): - __slots__ = () +class AbstractEntityRegistry(CreatesToken): + __slots__ = ( + "key", + "parent", + "is_aliased_class", + "path", + "entity", + "natural_path", + ) has_entity = True - - def __init__(self, parent, entity): + is_entity = True + + parent: Union[RootRegistry, PropRegistry] + key: _InternalEntityType[Any] + entity: _InternalEntityType[Any] + is_aliased_class: bool + + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): self.key = entity self.parent = parent self.is_aliased_class = entity.is_aliased_class @@ -447,11 +625,11 @@ class AbstractEntityRegistry(PathRegistry): if parent.path and (self.is_aliased_class or parent.is_unnatural): # this is an infrequent code path used only for loader strategies # that also make use of of_type(). 
- if entity.mapper.isa(parent.natural_path[-1].entity): + if entity.mapper.isa(parent.natural_path[-1].entity): # type: ignore # noqa: E501 self.natural_path = parent.natural_path + (entity.mapper,) else: self.natural_path = parent.natural_path + ( - parent.natural_path[-1].entity, + parent.natural_path[-1].entity, # type: ignore ) # it seems to make sense that since these paths get mixed up # with statements that are cached or not, we should make @@ -465,19 +643,35 @@ class AbstractEntityRegistry(PathRegistry): self.natural_path = self.path @property - def entity_path(self): + def entity_path(self) -> PathRegistry: return self @property - def mapper(self): - return inspection.inspect(self.entity).mapper + def mapper(self) -> Mapper[Any]: + return self.entity.mapper - def __bool__(self): + def __bool__(self) -> bool: return True - __nonzero__ = __bool__ + @overload + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: + ... + + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... - def __getitem__(self, entity): + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... + + def __getitem__( + self, entity: Any + ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]: if isinstance(entity, (int, slice)): return self.path[entity] elif entity in PathToken._intern: @@ -491,31 +685,40 @@ class SlotsEntityRegistry(AbstractEntityRegistry): # version inherit_cache = True - __slots__ = ( - "key", - "parent", - "is_aliased_class", - "entity", - "path", - "natural_path", - ) + +class _ERDict(Dict[Any, Any]): + def __init__(self, registry: CachingEntityRegistry): + self.registry = registry + + def __missing__(self, key: Any) -> PropRegistry: + self[key] = item = PropRegistry(self.registry, key) + + return item -class CachingEntityRegistry(AbstractEntityRegistry, dict): +class CachingEntityRegistry(AbstractEntityRegistry): # for long lived mapper, return dict based caching # version that creates reference cycles + __slots__ = ("_cache",) + inherit_cache = True - def __getitem__(self, entity): + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): + super().__init__(parent, entity) + self._cache = _ERDict(self) + + def pop(self, key: Any, default: Any) -> Any: + return self._cache.pop(key, default) + + def __getitem__(self, entity: Any) -> Any: if isinstance(entity, (int, slice)): return self.path[entity] elif isinstance(entity, PathToken): return TokenRegistry(self, entity) else: - return dict.__getitem__(self, entity) - - def __missing__(self, key): - self[key] = item = PropRegistry(self, key) - - return item + return self._cache[entity] diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c01825b6d6..9f37e84571 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -19,6 +19,8 @@ from typing import cast from typing import List from typing import Optional from typing import Set +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from . import attributes @@ -38,17 +40,22 @@ from .util import _orm_full_deannotate from .. import exc as sa_exc from .. import ForeignKey from .. import log -from .. import sql from .. 
import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes from ..sql.schema import Column +from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import NoneType +if TYPE_CHECKING: + from ._typing import _ORMColumnExprArgument + from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -78,6 +85,10 @@ class ColumnProperty( inherit_cache = True _links_to_entity = False + columns: List[ColumnElement[Any]] + + _is_polymorphic_discriminator: bool + __slots__ = ( "_orig_columns", "columns", @@ -99,7 +110,19 @@ class ColumnProperty( ) def __init__( - self, column: sql.ColumnElement[_T], *additional_columns, **kwargs + self, + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator]] = None, + descriptor: Optional[Any] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + _instrument: bool = True, ): super(ColumnProperty, self).__init__() columns = (column,) + additional_columns @@ -112,23 +135,24 @@ class ColumnProperty( ) for c in columns ] - self.parent = self.key = None - self.group = kwargs.pop("group", None) - self.deferred = kwargs.pop("deferred", False) - self.raiseload = kwargs.pop("raiseload", False) - self.instrument = kwargs.pop("_instrument", True) - self.comparator_factory = kwargs.pop( - "comparator_factory", self.__class__.Comparator + self.group = group + self.deferred = deferred + self.raiseload = raiseload + self.instrument = _instrument + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator ) - self.descriptor = kwargs.pop("descriptor", None) - self.active_history = kwargs.pop("active_history", False) - self.expire_on_flush = kwargs.pop("expire_on_flush", True) + self.descriptor = descriptor + self.active_history = active_history + self.expire_on_flush = expire_on_flush - if "info" in kwargs: - self.info = kwargs.pop("info") + if info is not None: + self.info = info - if "doc" in kwargs: - self.doc = kwargs.pop("doc") + if doc is not None: + self.doc = doc else: for col in reversed(self.columns): doc = getattr(col, "doc", None) @@ -138,12 +162,6 @@ class ColumnProperty( else: self.doc = None - if kwargs: - raise TypeError( - "%s received unexpected keyword argument(s): %s" - % (self.__class__.__name__, ", ".join(sorted(kwargs.keys()))) - ) - util.set_creation_order(self) self.strategy_key = ( @@ -445,7 +463,10 @@ class MappedColumn( self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys - self._has_nullable = "nullable" in kw + self._has_nullable = "nullable" in kw and kw.get("nullable") not in ( + None, + SchemaConst.NULL_UNSPECIFIED, + ) util.set_creation_order(self) def _copy(self, **kw): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a754bd4f2a..395d01a1ea 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -30,6 +30,7 @@ from typing import Optional from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . 
import exc as orm_exc from . import interfaces @@ -77,6 +78,8 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL if TYPE_CHECKING: from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import Alias + from ..sql.selectable import Subquery __all__ = ["Query", "QueryContext"] @@ -2769,14 +2772,14 @@ class AliasOption(interfaces.LoaderOption): "for entities to be matched up to a query that is established " "via :meth:`.Query.from_statement` and now does nothing.", ) - def __init__(self, alias): + def __init__(self, alias: Union[Alias, Subquery]): r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. """ - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState): pass diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 58c7c4efd5..66021c9c20 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -21,7 +21,10 @@ import re import typing from typing import Any from typing import Callable +from typing import Dict from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Type from typing import TypeVar from typing import Union @@ -30,6 +33,7 @@ import weakref from . import attributes from . import strategy_options from .base import _is_mapped_class +from .base import class_mapper from .base import state_str from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY @@ -53,7 +57,9 @@ from ..sql import expression from ..sql import operators from ..sql import roles from ..sql import visitors -from ..sql.elements import SQLCoreOperations +from ..sql._typing import _ColumnExpressionArgument +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnClause from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -61,11 +67,14 @@ from ..sql.util import ClauseAdapter from ..sql.util import join_condition from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +from ..util.typing import Literal if typing.TYPE_CHECKING: + from ._typing import _EntityType from .mapper import Mapper from .util import AliasedClass from .util import AliasedInsp + from ..sql.elements import ColumnElement _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -81,6 +90,34 @@ _RelationshipArgumentType = Union[ Callable[[], "AliasedClass[_T]"], ] +_LazyLoadArgumentType = Literal[ + "select", + "joined", + "selectin", + "subquery", + "raise", + "raise_on_sql", + "noload", + "immediate", + "dynamic", + True, + False, + None, +] + + +_RelationshipJoinConditionArgument = Union[ + str, _ColumnExpressionArgument[bool] +] +_ORMOrderByArgument = Union[ + Literal[False], str, _ColumnExpressionArgument[Any] +] +_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] +_ORMColCollectionArgument = Union[ + str, + Sequence[Union[ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole]], +] + def remote(expr): """Annotate a portion of a primaryjoin expression @@ -144,6 +181,7 @@ class Relationship( inherit_cache = True _links_to_entity = True + _is_relationship = True _persistence_only = dict( passive_deletes=False, @@ -159,38 +197,39 @@ class Relationship( self, argument: Optional[_RelationshipArgumentType[_T]] = None, secondary=None, + *, + uselist=None, + collection_class=None, primaryjoin=None, 
secondaryjoin=None, - foreign_keys=None, - uselist=None, + back_populates=None, order_by=False, backref=None, - back_populates=None, + cascade_backrefs=False, overlaps=None, post_update=False, - cascade=False, + cascade="save-update, merge", viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=_persistence_only["passive_deletes"], - passive_updates=_persistence_only["passive_updates"], + lazy: _LazyLoadArgumentType = "select", + passive_deletes=False, + passive_updates=True, + active_history=False, + enable_typechecks=True, + foreign_keys=None, remote_side=None, - enable_typechecks=_persistence_only["enable_typechecks"], join_depth=None, comparator_factory=None, single_parent=False, innerjoin=False, distinct_target_key=None, - doc=None, - active_history=_persistence_only["active_history"], - cascade_backrefs=_persistence_only["cascade_backrefs"], load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, query_class=None, info=None, omit_join=None, sync_backref=None, + doc=None, + bake_queries=True, + _local_remote_pairs=None, _legacy_inactive_history_style=False, ): super(Relationship, self).__init__() @@ -250,7 +289,6 @@ class Relationship( self.omit_join = omit_join self.local_remote_pairs = _local_remote_pairs - self.bake_queries = bake_queries self.load_on_pending = load_on_pending self.comparator_factory = comparator_factory or Relationship.Comparator self.comparator = self.comparator_factory(self, None) @@ -267,12 +305,7 @@ class Relationship( else: self._overlaps = () - if cascade is not False: - self.cascade = cascade - elif self.viewonly: - self.cascade = "none" - else: - self.cascade = "save-update, merge" + self.cascade = cascade self.order_by = order_by @@ -539,9 +572,9 @@ class Relationship( def _criterion_exists( self, - criterion: Optional[SQLCoreOperations[Any]] = None, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> Exists[bool]: + ) -> Exists: if getattr(self, "_of_type", None): info = inspect(self._of_type) target_mapper, to_selectable, is_aliased_class = ( @@ -898,7 +931,12 @@ class Relationship( comparator: Comparator[_T] - def _with_parent(self, instance, alias_secondary=True, from_entity=None): + def _with_parent( + self, + instance: object, + alias_secondary: bool = True, + from_entity: Optional[_EntityType[Any]] = None, + ) -> ColumnElement[bool]: assert instance is not None adapt_source = None if from_entity is not None: @@ -1502,7 +1540,7 @@ class Relationship( argument = argument if isinstance(argument, type): - entity = mapperlib.class_mapper(argument, configure=False) + entity = class_mapper(argument, configure=False) else: try: entity = inspect(argument) @@ -1568,7 +1606,7 @@ class Relationship( """Test that this relationship is legal, warn about inheritance conflicts.""" mapperlib = util.preloaded.orm_mapper - if self.parent.non_primary and not mapperlib.class_mapper( + if self.parent.non_primary and not class_mapper( self.parent.class_, configure=False ).has_property(self.key): raise sa_exc.ArgumentError( @@ -1585,29 +1623,23 @@ class Relationship( ) @property - def cascade(self): + def cascade(self) -> CascadeOptions: """Return the current cascade setting for this :class:`.Relationship`. 
""" return self._cascade @cascade.setter - def cascade(self, cascade): + def cascade(self, cascade: Union[str, CascadeOptions]): self._set_cascade(cascade) - def _set_cascade(self, cascade): - cascade = CascadeOptions(cascade) + def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]): + cascade = CascadeOptions(cascade_arg) if self.viewonly: - non_viewonly = set(cascade).difference( - CascadeOptions._viewonly_cascades + cascade = CascadeOptions( + cascade.intersection(CascadeOptions._viewonly_cascades) ) - if non_viewonly: - raise sa_exc.ArgumentError( - 'Cascade settings "%s" apply to persistence operations ' - "and should not be combined with a viewonly=True " - "relationship." % (", ".join(sorted(non_viewonly))) - ) if "mapper" in self.__dict__: self._check_cascade_settings(cascade) @@ -1754,8 +1786,8 @@ class Relationship( relationship = Relationship( parent, self.secondary, - pj, - sj, + primaryjoin=pj, + secondaryjoin=sj, foreign_keys=foreign_keys, back_populates=self.key, **kwargs, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5b1d0bb087..74035ec0aa 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -39,6 +39,7 @@ from . import persistence from . import query from . import state as statelib from ._typing import _O +from ._typing import insp_is_mapper from ._typing import is_composite_class from ._typing import is_user_defined_option from .base import _class_to_mapper @@ -69,12 +70,14 @@ from ..engine.util import TransactionalContext from ..event import dispatcher from ..event import EventTarget from ..inspection import inspect +from ..inspection import Inspectable from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import Select from ..sql import visitors from ..sql.base import CompileState +from ..sql.schema import Table from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import IdentitySet @@ -90,6 +93,7 @@ if typing.TYPE_CHECKING: from .path_registry import PathRegistry from ..engine import Result from ..engine import Row + from ..engine import RowMapping from ..engine.base import Transaction from ..engine.base import TwoPhaseTransaction from ..engine.interfaces import _CoreAnyExecuteParams @@ -103,6 +107,7 @@ if typing.TYPE_CHECKING: from ..sql.base import Executable from ..sql.elements import ClauseElement from ..sql.schema import Table + from ..sql.selectable import TableClause __all__ = [ "Session", @@ -184,7 +189,7 @@ class _SessionClassMethods: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row, RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: """Return an identity key. @@ -2050,9 +2055,12 @@ class Session(_SessionClassMethods, EventTarget): else: self.__binds[key] = bind else: - if insp.is_selectable: + if TYPE_CHECKING: + assert isinstance(insp, Inspectable) + + if isinstance(insp, Table): self.__binds[insp] = bind - elif insp.is_mapper: + elif insp_is_mapper(insp): self.__binds[insp.class_] = bind for _selectable in insp._all_tables: self.__binds[_selectable] = bind @@ -2211,7 +2219,7 @@ class Session(_SessionClassMethods, EventTarget): # we don't have self.bind and either have self.__binds # or we don't have self.__binds (which is legacy). 
Look at the # mapper and the clause - if mapper is clause is None: + if mapper is None and clause is None: if self.bind: return self.bind else: @@ -2350,7 +2358,10 @@ class Session(_SessionClassMethods, EventTarget): key = mapper.identity_key_from_primary_key( primary_key_identity, identity_token=identity_token ) - return loading.get_from_identity(self, mapper, key, passive) + + # work around: https://github.com/python/typing/discussions/1143 + return_value = loading.get_from_identity(self, mapper, key, passive) + return return_value @util.non_memoized_property @contextlib.contextmanager diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 85e0151937..2d85ba7f64 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -2162,10 +2162,11 @@ class JoinedLoader(AbstractRelationshipLoader): else: to_adapt = self._gen_pooled_aliased_class(compile_state) - clauses = inspect(to_adapt)._memo( + to_adapt_insp = inspect(to_adapt) + clauses = to_adapt_insp._memo( ("joinedloader_ormadapter", self), orm_util.ORMAdapter, - to_adapt, + to_adapt_insp, equivalents=self.mapper._equivalent_columns, adapt_required=True, allow_label_resolve=False, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4699781a42..3934de5355 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -11,8 +11,16 @@ import re import types import typing from typing import Any +from typing import cast +from typing import Dict +from typing import FrozenSet from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Match from typing import Optional +from typing import Sequence from typing import Tuple from typing import Type from typing import TypeVar @@ -20,32 +28,35 @@ from typing import Union import weakref from . 
import attributes # noqa -from .base import _class_to_mapper # noqa -from .base import _never_set # noqa -from .base import _none_set # noqa -from .base import attribute_str # noqa -from .base import class_mapper # noqa -from .base import InspectionAttr # noqa -from .base import instance_str # noqa -from .base import object_mapper # noqa -from .base import object_state # noqa -from .base import state_attribute_str # noqa -from .base import state_class_str # noqa -from .base import state_str # noqa +from ._typing import _O +from ._typing import insp_is_aliased_class +from ._typing import insp_is_mapper +from ._typing import prop_is_relationship +from .base import _class_to_mapper as _class_to_mapper +from .base import _never_set as _never_set +from .base import _none_set as _none_set +from .base import attribute_str as attribute_str +from .base import class_mapper as class_mapper +from .base import InspectionAttr as InspectionAttr +from .base import instance_str as instance_str +from .base import object_mapper as object_mapper +from .base import object_state as object_state +from .base import state_attribute_str as state_attribute_str +from .base import state_class_str as state_class_str +from .base import state_str as state_str from .interfaces import CriteriaOption -from .interfaces import MapperProperty # noqa +from .interfaces import MapperProperty as MapperProperty from .interfaces import ORMColumnsClauseRole from .interfaces import ORMEntityColumnsClauseRole from .interfaces import ORMFromClauseRole -from .interfaces import PropComparator # noqa -from .path_registry import PathRegistry # noqa +from .interfaces import PropComparator as PropComparator +from .path_registry import PathRegistry as PathRegistry from .. import event from .. import exc as sa_exc from .. import inspection from .. import sql from .. 
import util from ..engine.result import result_tuple -from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import lambdas @@ -54,19 +65,39 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection +from ..sql.cache_key import HasCacheKey +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import ColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation from ..util.typing import is_origin_of +from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InternalEntityType + from ._typing import _ORMColumnExprArgument + from .context import _MapperEntity + from .context import ORMCompileState from .mapper import Mapper + from .relationships import Relationship from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _FromClauseArgument + from ..sql._typing import _OnClauseArgument from ..sql._typing import _PropagateAttrsType + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import BindParameter + from ..sql.selectable import _ColumnsClauseElement from ..sql.selectable import Alias + from ..sql.selectable import Subquery + from ..sql.visitors import _ET + from ..sql.visitors import anon_map + from ..sql.visitors import ExternallyTraversible _T = TypeVar("_T", bound=Any) @@ -84,7 +115,7 @@ all_cascades = frozenset( ) -class CascadeOptions(frozenset): +class CascadeOptions(FrozenSet[str]): """Keeps track of the options sent to :paramref:`.relationship.cascade`""" @@ -104,6 +135,13 @@ class CascadeOptions(frozenset): "delete_orphan", ) + save_update: bool + delete: bool + refresh_expire: bool + merge: bool + expunge: bool + delete_orphan: bool + def __new__(cls, value_list): if isinstance(value_list, str) or value_list is None: return cls.from_string(value_list) @@ -127,7 +165,7 @@ class CascadeOptions(frozenset): values.clear() values.discard("all") - self = frozenset.__new__(CascadeOptions, values) + self = super().__new__(cls, values) # type: ignore self.save_update = "save-update" in values self.delete = "delete" in values self.refresh_expire = "refresh-expire" in values @@ -238,7 +276,7 @@ def polymorphic_union( """ - colnames = util.OrderedSet() + colnames: util.OrderedSet[str] = util.OrderedSet() colnamemaps = {} types = {} for key in table_map: @@ -299,13 +337,13 @@ def polymorphic_union( def identity_key( - class_: Optional[Type[Any]] = None, + class_: Optional[Type[_T]] = None, ident: Union[Any, Tuple[Any, ...]] = None, *, - instance: Optional[Any] = None, - row: Optional[Row] = None, + instance: Optional[_T] = None, + row: Optional[Union[Row, RowMapping]] = None, identity_token: Optional[Any] = None, -) -> _IdentityKeyType: +) -> _IdentityKeyType[_T]: r"""Generate "identity key" tuples, as are used as keys in the :attr:`.Session.identity_map` dictionary. @@ -351,7 +389,7 @@ def identity_key( * ``identity_key(class, row=row, identity_token=token)`` This form is similar to the class/tuple form, except is passed a - database result row as a :class:`.Row` object. + database result row as a :class:`.Row` or :class:`.RowMapping` object. 
E.g.:: @@ -375,7 +413,7 @@ def identity_key( if ident is None: raise sa_exc.ArgumentError("ident or row is required") return mapper.identity_key_from_primary_key( - util.to_list(ident), identity_token=identity_token + tuple(util.to_list(ident)), identity_token=identity_token ) else: return mapper.identity_key_from_row( @@ -394,24 +432,26 @@ class ORMAdapter(sql_util.ColumnAdapter): """ - is_aliased_class = False - aliased_insp = None + is_aliased_class: bool + aliased_insp: Optional[AliasedInsp[Any]] def __init__( self, - entity, - equivalents=None, - adapt_required=False, - allow_label_resolve=True, - anonymize_labels=False, + entity: _InternalEntityType[Any], + equivalents: Optional[_EquivalentColumnMap] = None, + adapt_required: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, ): - info = inspection.inspect(entity) - self.mapper = info.mapper - selectable = info.selectable - if info.is_aliased_class: + self.mapper = entity.mapper + selectable = entity.selectable + if insp_is_aliased_class(entity): self.is_aliased_class = True - self.aliased_insp = info + self.aliased_insp = entity + else: + self.is_aliased_class = False + self.aliased_insp = None sql_util.ColumnAdapter.__init__( self, @@ -428,7 +468,7 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): +class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]): r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -489,19 +529,20 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): def __init__( self, - mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, - # TODO: None for default here? 
- with_polymorphic_mappers=(), - with_polymorphic_discriminator=None, - base_alias=None, - use_mapper_path=False, - represents_outer_join=False, + mapped_class_or_ac: _EntityType[_O], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]] = None, + with_polymorphic_discriminator: Optional[ColumnElement[Any]] = None, + base_alias: Optional[AliasedInsp[Any]] = None, + use_mapper_path: bool = False, + represents_outer_join: bool = False, ): - insp = inspection.inspect(mapped_class_or_ac) + insp = cast( + "_InternalEntityType[_O]", inspection.inspect(mapped_class_or_ac) + ) mapper = insp.mapper nest_adapters = False @@ -519,6 +560,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): elif insp.is_aliased_class: nest_adapters = True + assert alias is not None self._aliased_insp = AliasedInsp( self, insp, @@ -540,7 +582,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): self.__name__ = f"aliased({mapper.class_.__name__})" @classmethod - def _reconstitute_from_aliased_insp(cls, aliased_insp): + def _reconstitute_from_aliased_insp( + cls, aliased_insp: AliasedInsp[_O] + ) -> AliasedClass[_O]: obj = cls.__new__(cls) obj.__name__ = f"aliased({aliased_insp.mapper.class_.__name__})" obj._aliased_insp = aliased_insp @@ -555,7 +599,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return obj - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: try: _aliased_insp = self.__dict__["_aliased_insp"] except KeyError: @@ -584,7 +628,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return attr - def _get_from_serialized(self, key, mapped_class, aliased_insp): + def _get_from_serialized( + self, key: str, mapped_class: _O, aliased_insp: AliasedInsp[_O] + ) -> Any: # this method is only used in terms of the # sqlalchemy.ext.serializer extension attr = getattr(mapped_class, key) @@ -605,23 +651,25 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return attr - def __repr__(self): + def __repr__(self) -> str: return "" % ( id(self), self._aliased_insp._target.__name__, ) - def __str__(self): + def __str__(self) -> str: return str(self._aliased_insp) +@inspection._self_inspects class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, - sql_base.HasCacheKey, + HasCacheKey, InspectionAttr, MemoizedSlots, - Generic[_T], + inspection.Inspectable["AliasedInsp[_O]"], + Generic[_O], ): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -685,19 +733,36 @@ class AliasedInsp( "_nest_adapters", ) + mapper: Mapper[_O] + selectable: FromClause + _adapter: sql_util.ColumnAdapter + with_polymorphic_mappers: Sequence[Mapper[Any]] + _with_polymorphic_entities: Sequence[AliasedInsp[Any]] + + _weak_entity: weakref.ref[AliasedClass[_O]] + """the AliasedClass that refers to this AliasedInsp""" + + _target: Union[_O, AliasedClass[_O]] + """the thing referred towards by the AliasedClass/AliasedInsp. + + In the vast majority of cases, this is the mapped class. However + it may also be another AliasedClass (alias of alias). 
+ + """ + def __init__( self, - entity: _EntityType, - inspected: _InternalEntityType, - selectable, - name, - with_polymorphic_mappers, - polymorphic_on, - _base_alias, - _use_mapper_path, - adapt_on_names, - represents_outer_join, - nest_adapters, + entity: AliasedClass[_O], + inspected: _InternalEntityType[_O], + selectable: FromClause, + name: Optional[str], + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]], + polymorphic_on: Optional[ColumnElement[Any]], + _base_alias: Optional[AliasedInsp[Any]], + _use_mapper_path: bool, + adapt_on_names: bool, + represents_outer_join: bool, + nest_adapters: bool, ): mapped_class_or_ac = inspected.entity @@ -752,23 +817,22 @@ class AliasedInsp( ) if nest_adapters: + # supports "aliased class of aliased class" use case + assert isinstance(inspected, AliasedInsp) self._adapter = inspected._adapter.wrap(self._adapter) self._adapt_on_names = adapt_on_names self._target = mapped_class_or_ac - # self._target = mapper.class_ # mapped_class_or_ac @classmethod def _alias_factory( cls, - element: Union[ - Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]" - ], - alias=None, - name=None, - flat=False, - adapt_on_names=False, - ) -> Union["AliasedClass[_T]", "Alias"]: + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + ) -> Union[AliasedClass[_O], FromClause]: if isinstance(element, FromClause): if adapt_on_names: @@ -793,16 +857,16 @@ class AliasedInsp( @classmethod def _with_polymorphic_factory( cls, - base, - classes, - selectable=False, - flat=False, - polymorphic_on=None, - aliased=False, - innerjoin=False, - adapt_on_names=False, - _use_mapper_path=False, - ): + base: Union[_O, Mapper[_O]], + classes: Iterable[Type[Any]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, + ) -> AliasedClass[_O]: primary_mapper = _class_to_mapper(base) @@ -816,7 +880,9 @@ class AliasedInsp( classes, selectable, innerjoin=innerjoin ) if aliased or flat: + assert selectable is not None selectable = selectable._anonymous_fromclause(flat=flat) + return AliasedClass( base, selectable, @@ -828,7 +894,7 @@ class AliasedInsp( ) @property - def entity(self): + def entity(self) -> AliasedClass[_O]: # to eliminate reference cycles, the AliasedClass is held weakly. 
# this produces some situations where the AliasedClass gets lost, # particularly when one is created internally and only the AliasedInsp @@ -844,7 +910,7 @@ class AliasedInsp( is_aliased_class = True "always returns True" - def _memoized_method___clause_element__(self): + def _memoized_method___clause_element__(self) -> FromClause: return self.selectable._annotate( { "parentmapper": self.mapper, @@ -856,7 +922,7 @@ class AliasedInsp( ) @property - def entity_namespace(self): + def entity_namespace(self) -> AliasedClass[_O]: return self.entity _cache_key_traversal = [ @@ -866,7 +932,7 @@ class AliasedInsp( ] @property - def class_(self): + def class_(self) -> Type[_O]: """Return the mapped class ultimately represented by this :class:`.AliasedInsp`.""" return self.mapper.class_ @@ -878,7 +944,7 @@ class AliasedInsp( else: return PathRegistry.per_mapper(self) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "entity": self.entity, "mapper": self.mapper, @@ -893,8 +959,8 @@ class AliasedInsp( "nest_adapters": self._nest_adapters, } - def __setstate__(self, state): - self.__init__( + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__init__( # type: ignore state["entity"], state["mapper"], state["alias"], @@ -908,7 +974,7 @@ class AliasedInsp( state["nest_adapters"], ) - def _merge_with(self, other): + def _merge_with(self, other: AliasedInsp[_O]) -> AliasedInsp[_O]: # assert self._is_with_polymorphic # assert other._is_with_polymorphic @@ -929,7 +995,6 @@ class AliasedInsp( classes, None, innerjoin=not other.represents_outer_join ) selectable = selectable._anonymous_fromclause(flat=True) - return AliasedClass( primary_mapper, selectable, @@ -937,10 +1002,13 @@ class AliasedInsp( with_polymorphic_discriminator=other.polymorphic_on, use_mapper_path=other._use_mapper_path, represents_outer_join=other.represents_outer_join, - ) + )._aliased_insp - def _adapt_element(self, elem, key=None): - d = { + def _adapt_element( + self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None + ) -> _ORMColumnExprArgument[_T]: + assert isinstance(elem, ColumnElement) + d: Dict[str, Any] = { "parententity": self, "parentmapper": self.mapper, } @@ -1084,35 +1152,45 @@ class LoaderCriteriaOption(CriteriaOption): ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), ] + root_entity: Optional[Type[Any]] + entity: Optional[_InternalEntityType[Any]] + where_criteria: Union[ColumnElement[bool], lambdas.DeferredLambdaElement] + deferred_where_criteria: bool + include_aliases: bool + propagate_to_loaders: bool + def __init__( self, - entity_or_base, - where_criteria, - loader_only=False, - include_aliases=False, - propagate_to_loaders=True, - track_closure_variables=True, + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, ): - entity = inspection.inspect(entity_or_base, False) + entity = cast( + "_InternalEntityType[Any]", + inspection.inspect(entity_or_base, False), + ) if entity is None: - self.root_entity = entity_or_base + self.root_entity = cast("Type[Any]", entity_or_base) self.entity = None else: self.root_entity = None self.entity = entity if callable(where_criteria): + if self.root_entity is not None: + wrap_entity = self.root_entity + else: + assert entity is not None + wrap_entity = entity.entity + self.deferred_where_criteria = True self.where_criteria = 
lambdas.DeferredLambdaElement( - where_criteria, + where_criteria, # type: ignore roles.WhereHavingRole, - lambda_args=( - _WrapUserEntity( - self.root_entity - if self.root_entity is not None - else self.entity.entity, - ), - ), + lambda_args=(_WrapUserEntity(wrap_entity),), opts=lambdas.LambdaOptions( track_closure_variables=track_closure_variables ), @@ -1126,22 +1204,27 @@ class LoaderCriteriaOption(CriteriaOption): self.include_aliases = include_aliases self.propagate_to_loaders = propagate_to_loaders - def _all_mappers(self): + def _all_mappers(self) -> Iterator[Mapper[Any]]: + if self.entity: - for ent in self.entity.mapper.self_and_descendants: - yield ent + for mp_ent in self.entity.mapper.self_and_descendants: + yield mp_ent else: + assert self.root_entity stack = list(self.root_entity.__subclasses__()) while stack: subclass = stack.pop(0) - ent = inspection.inspect(subclass, raiseerr=False) + ent = cast( + "_InternalEntityType[Any]", + inspection.inspect(subclass, raiseerr=False), + ) if ent: for mp in ent.mapper.self_and_descendants: yield mp else: stack.extend(subclass.__subclasses__()) - def _should_include(self, compile_state): + def _should_include(self, compile_state: ORMCompileState) -> bool: if ( compile_state.select_statement._annotations.get( "for_loader_criteria", None @@ -1151,21 +1234,29 @@ class LoaderCriteriaOption(CriteriaOption): return False return True - def _resolve_where_criteria(self, ext_info): + def _resolve_where_criteria( + self, ext_info: _InternalEntityType[Any] + ) -> ColumnElement[bool]: if self.deferred_where_criteria: - crit = self.where_criteria._resolve_with_args(ext_info.entity) + crit = cast( + "ColumnElement[bool]", + self.where_criteria._resolve_with_args(ext_info.entity), + ) else: - crit = self.where_criteria + crit = self.where_criteria # type: ignore + assert isinstance(crit, ColumnElement) return sql_util._deep_annotate( crit, {"for_loader_criteria": self}, detect_subquery_cols=True ) def process_compile_state_replaced_entities( - self, compile_state, mapper_entities - ): - return self.process_compile_state(compile_state) + self, + compile_state: ORMCompileState, + mapper_entities: Iterable[_MapperEntity], + ) -> None: + self.process_compile_state(compile_state) - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState) -> None: """Apply a modification to a given :class:`.CompileState`.""" # if options to limit the criteria to immediate query only, @@ -1173,7 +1264,7 @@ class LoaderCriteriaOption(CriteriaOption): self.get_global_criteria(compile_state.global_attributes) - def get_global_criteria(self, attributes): + def get_global_criteria(self, attributes: Dict[Any, Any]) -> None: for mp in self._all_mappers(): load_criteria = attributes.setdefault( ("additional_entity_criteria", mp), [] @@ -1183,14 +1274,14 @@ class LoaderCriteriaOption(CriteriaOption): inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) -inspection._inspects(AliasedInsp)(lambda target: target) @inspection._self_inspects class Bundle( ORMColumnsClauseRole, SupportsCloneAnnotations, - sql_base.MemoizedHasCacheKey, + MemoizedHasCacheKey, + inspection.Inspectable["Bundle"], InspectionAttr, ): """A grouping of SQL expressions that are returned by a :class:`.Query` @@ -1227,7 +1318,11 @@ class Bundle( _propagate_attrs: _PropagateAttrsType = util.immutabledict() - def __init__(self, name, *exprs, **kw): + exprs: List[_ColumnsClauseElement] + + def __init__( + self, name: str, *exprs: 
_ColumnExpressionArgument[Any], **kw: Any + ): r"""Construct a new :class:`.Bundle`. e.g.:: @@ -1246,37 +1341,43 @@ class Bundle( """ self.name = self._label = name - self.exprs = exprs = [ + coerced_exprs = [ coercions.expect( roles.ColumnsClauseRole, expr, apply_propagate_attrs=self ) for expr in exprs ] + self.exprs = coerced_exprs self.c = self.columns = ColumnCollection( (getattr(col, "key", col._label), col) - for col in [e._annotations.get("bundle", e) for e in exprs] - ) + for col in [e._annotations.get("bundle", e) for e in coerced_exprs] + ).as_readonly() self.single_entity = kw.pop("single_entity", self.single_entity) - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: return (self.__class__, self.name, self.single_entity) + tuple( [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs] ) @property - def mapper(self): + def mapper(self) -> Mapper[Any]: return self.exprs[0]._annotations.get("parentmapper", None) @property - def entity(self): + def entity(self) -> _InternalEntityType[Any]: return self.exprs[0]._annotations.get("parententity", None) @property - def entity_namespace(self): + def entity_namespace( + self, + ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: return self.c - columns = None + columns: ReadOnlyColumnCollection[str, ColumnElement[Any]] + """A namespace of SQL expressions referred to by this :class:`.Bundle`. e.g.:: @@ -1301,7 +1402,7 @@ class Bundle( """ - c = None + c: ReadOnlyColumnCollection[str, ColumnElement[Any]] """An alias for :attr:`.Bundle.columns`.""" def _clone(self): @@ -1400,32 +1501,30 @@ class _ORMJoin(expression.Join): def __init__( self, - left, - right, - onclause=None, - isouter=False, - full=False, - _left_memo=None, - _right_memo=None, - _extra_criteria=(), + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + _left_memo: Optional[Any] = None, + _right_memo: Optional[Any] = None, + _extra_criteria: Sequence[ColumnElement[bool]] = (), ): - left_info = inspection.inspect(left) + left_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(left), + ) - right_info = inspection.inspect(right) + right_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(right), + ) adapt_to = right_info.selectable # used by joined eager loader self._left_memo = _left_memo self._right_memo = _right_memo - # legacy, for string attr name ON clause. 
if that's removed - # then the "_joined_from_info" concept can go - left_orm_info = getattr(left, "_joined_from_info", left_info) - self._joined_from_info = right_info - if isinstance(onclause, str): - onclause = getattr(left_orm_info.entity, onclause) - # #### - if isinstance(onclause, attributes.QueryableAttribute): on_selectable = onclause.comparator._source_selectable() prop = onclause.property @@ -1477,20 +1576,23 @@ class _ORMJoin(expression.Join): augment_onclause = onclause is None and _extra_criteria expression.Join.__init__(self, left, right, onclause, isouter, full) + assert self.onclause is not None + if augment_onclause: self.onclause &= sql.and_(*_extra_criteria) if ( not prop and getattr(right_info, "mapper", None) - and right_info.mapper.single + and right_info.mapper.single # type: ignore ): + right_info = cast("_InternalEntityType[Any]", right_info) # if single inheritance target and we are using a manual # or implicit ON clause, augment it the same way we'd augment the # WHERE. single_crit = right_info.mapper._single_table_criterion if single_crit is not None: - if right_info.is_aliased_class: + if insp_is_aliased_class(right_info): single_crit = right_info._adapter.traverse(single_crit) self.onclause = self.onclause & single_crit @@ -1525,19 +1627,27 @@ class _ORMJoin(expression.Join): def join( self, - right, - onclause=None, - isouter=False, - full=False, - join_to_left=None, - ): + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ) -> _ORMJoin: return _ORMJoin(self, right, onclause, full=full, isouter=isouter) - def outerjoin(self, right, onclause=None, full=False, join_to_left=None): + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, + ) -> _ORMJoin: return _ORMJoin(self, right, onclause, isouter=True, full=full) -def with_parent(instance, prop, from_entity=None): +def with_parent( + instance: object, + prop: attributes.QueryableAttribute[Any], + from_entity: Optional[_EntityType[Any]] = None, +) -> ColumnElement[bool]: """Create filtering criterion that relates this query's primary entity to the given related instance, using established :func:`_orm.relationship()` @@ -1588,6 +1698,8 @@ def with_parent(instance, prop, from_entity=None): .. versionadded:: 1.2 """ + prop_t: Relationship[Any] + if isinstance(prop, str): raise sa_exc.ArgumentError( "with_parent() accepts class-bound mapped attributes, not strings" @@ -1595,12 +1707,19 @@ def with_parent(instance, prop, from_entity=None): elif isinstance(prop, attributes.QueryableAttribute): if prop._of_type: from_entity = prop._of_type - prop = prop.property + if not prop_is_relationship(prop.property): + raise sa_exc.ArgumentError( + f"Expected relationship property for with_parent(), " + f"got {prop.property}" + ) + prop_t = prop.property + else: + prop_t = prop - return prop._with_parent(instance, from_entity=from_entity) + return prop_t._with_parent(instance, from_entity=from_entity) -def has_identity(object_): +def has_identity(object_: object) -> bool: """Return True if the given object has a database identity. @@ -1616,7 +1735,7 @@ def has_identity(object_): return state.has_identity -def was_deleted(object_): +def was_deleted(object_: object) -> bool: """Return True if the given object was deleted within a session flush. 
@@ -1633,27 +1752,32 @@ def was_deleted(object_): return state.was_deleted -def _entity_corresponds_to(given, entity): +def _entity_corresponds_to( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: """determine if 'given' corresponds to 'entity', in terms of an entity passed to Query that would match the same entity being referred to elsewhere in the query. """ - if entity.is_aliased_class: - if given.is_aliased_class: + if insp_is_aliased_class(entity): + if insp_is_aliased_class(given): if entity._base_alias() is given._base_alias(): return True return False - elif given.is_aliased_class: + elif insp_is_aliased_class(given): if given._use_mapper_path: return entity in given.with_polymorphic_mappers else: return entity is given + assert insp_is_mapper(given) return entity.common_parent(given) -def _entity_corresponds_to_use_path_impl(given, entity): +def _entity_corresponds_to_use_path_impl( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: """determine if 'given' corresponds to 'entity', in terms of a path of loader options where a mapped attribute is taken to be a member of a parent entity. @@ -1673,13 +1797,13 @@ def _entity_corresponds_to_use_path_impl(given, entity): """ - if given.is_aliased_class: + if insp_is_aliased_class(given): return ( - entity.is_aliased_class + insp_is_aliased_class(entity) and not entity._use_mapper_path and (given is entity or entity in given._with_polymorphic_entities) ) - elif not entity.is_aliased_class: + elif not insp_is_aliased_class(entity): return given.isa(entity.mapper) else: return ( @@ -1688,7 +1812,7 @@ def _entity_corresponds_to_use_path_impl(given, entity): ) -def _entity_isa(given, mapper): +def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: """determine if 'given' "is a" mapper, in terms of the given would load rows of type 'mapper'. @@ -1703,42 +1827,6 @@ def _entity_isa(given, mapper): return given.isa(mapper) -def randomize_unitofwork(): - """Use random-ordering sets within the unit of work in order - to detect unit of work sorting issues. - - This is a utility function that can be used to help reproduce - inconsistent unit of work sorting issues. For example, - if two kinds of objects A and B are being inserted, and - B has a foreign key reference to A - the A must be inserted first. - However, if there is no relationship between A and B, the unit of work - won't know to perform this sorting, and an operation may or may not - fail, depending on how the ordering works out. Since Python sets - and dictionaries have non-deterministic ordering, such an issue may - occur on some runs and not on others, and in practice it tends to - have a great dependence on the state of the interpreter. This leads - to so-called "heisenbugs" where changing entirely irrelevant aspects - of the test program still cause the failure behavior to change. - - By calling ``randomize_unitofwork()`` when a script first runs, the - ordering of a key series of sets within the unit of work implementation - are randomized, so that the script can be minimized down to the - fundamental mapping and operation that's failing, while still reproducing - the issue on at least some runs. - - This utility is also available when running the test suite via the - ``--reversetop`` flag. 
- - """ - from sqlalchemy.orm import unitofwork, session, mapper, dependency - from sqlalchemy.util import topological - from sqlalchemy.testing.util import RandomSet - - topological.set = ( - unitofwork.set - ) = session.set = mapper.set = dependency.set = RandomSet - - def _getitem(iterable_query, item): """calculate __getitem__ in terms of an iterable query object that also has a slice() method. @@ -1780,16 +1868,21 @@ def _getitem(iterable_query, item): return list(iterable_query[item : item + 1])[0] -def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type): +def _is_mapped_annotation( + raw_annotation: Union[type, str], cls: Type[Any] +) -> bool: annotated = de_stringify_annotation(cls, raw_annotation) return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") -def _cleanup_mapped_str_annotation(annotation): +def _cleanup_mapped_str_annotation(annotation: str) -> str: # fix up an annotation that comes in as the form: # 'Mapped[List[Address]]' so that it instead looks like: # 'Mapped[List["Address"]]' , which will allow us to get # "Address" as a string + + inner: Optional[Match[str]] + mm = re.match(r"^(.+?)\[(.+)\]$", annotation) if mm and mm.group(1) == "Mapped": stack = [] @@ -1839,8 +1932,8 @@ def _extract_mapped_subtype( else: if ( not hasattr(annotated, "__origin__") - or not issubclass(annotated.__origin__, attr_cls) - and not issubclass(attr_cls, annotated.__origin__) + or not issubclass(annotated.__origin__, attr_cls) # type: ignore + and not issubclass(attr_cls, annotated.__origin__) # type: ignore ): our_annotated_str = ( annotated.__name__ @@ -1853,9 +1946,9 @@ def _extract_mapped_subtype( f'"{attr_cls.__name__}[{our_annotated_str}]".' ) - if len(annotated.__args__) != 1: + if len(annotated.__args__) != 1: # type: ignore raise sa_exc.ArgumentError( "Expected sub-type for Mapped[] annotation" ) - return annotated.__args__[0] + return annotated.__args__[0] # type: ignore diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index ea21e01c66..605f75ec4f 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -389,7 +389,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: def bindparam( - key: str, + key: Optional[str], value: Any = _NoArg.NO_ARG, type_: Optional[TypeEngine[_T]] = None, unique: bool = False, @@ -521,6 +521,11 @@ def bindparam( key, or if its length is too long and truncation is required. + If omitted, an "anonymous" name is generated for the bound parameter; + when given a value to bind, the end result is equivalent to calling upon + the :func:`.literal` function with a value to bind, particularly + if the :paramref:`.bindparam.unique` parameter is also provided. + :param value: Initial value for this bind param. Will be used at statement execution time as the value for this parameter passed to the diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b0a717a1a3..53d29b628f 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -2,13 +2,14 @@ from __future__ import annotations import operator from typing import Any +from typing import Callable from typing import Dict +from typing import Set from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.sql.base import Executable from . import roles from .. 
import util from ..inspection import Inspectable @@ -16,6 +17,7 @@ from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from .base import Executable from .compiler import Compiled from .compiler import DDLCompiler from .compiler import SQLCompiler @@ -27,17 +29,20 @@ if TYPE_CHECKING: from .elements import quoted_name from .elements import SQLCoreOperations from .elements import TextClause + from .lambdas import LambdaElement from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column from .schema import DefaultGenerator from .schema import Sequence + from .schema import Table from .selectable import Alias from .selectable import FromClause from .selectable import Join from .selectable import NamedFromClause from .selectable import ReturnsRows from .selectable import Select + from .selectable import Selectable from .selectable import SelectBase from .selectable import Subquery from .selectable import TableClause @@ -46,7 +51,6 @@ if TYPE_CHECKING: from .type_api import TypeEngine from ..util.typing import TypeGuard - _T = TypeVar("_T", bound=Any) @@ -89,7 +93,11 @@ sets; select(...), insert().returning(...), etc. """ _ColumnExpressionArgument = Union[ - "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] + "ColumnElement[_T]", + _HasClauseElement, + roles.ExpressionElementRole[_T], + Callable[[], "ColumnElement[_T]"], + "LambdaElement", ] """narrower "column expression" argument. @@ -103,6 +111,7 @@ overall which brings in the TextClause object also. """ + _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" @@ -169,6 +178,8 @@ _PropagateAttrsType = util.immutabledict[str, Any] _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] +_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]] + if TYPE_CHECKING: def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: @@ -195,6 +206,9 @@ if TYPE_CHECKING: def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: ... + def is_selectable(t: Any) -> TypeGuard[Selectable]: + ... 
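# An illustrative, self-contained sketch of the narrowing-helper pattern used
# in this module: under TYPE_CHECKING the helper is declared with a TypeGuard
# return type so the checker narrows the argument, while at runtime it is just
# a cheap attrgetter on a class-level flag.  Widget, SelectableWidget and the
# _is_selectable flag below are hypothetical stand-ins, not SQLAlchemy API.
import operator
from typing import TYPE_CHECKING

from typing_extensions import TypeGuard


class Widget:
    _is_selectable = False


class SelectableWidget(Widget):
    _is_selectable = True

    def select(self) -> str:
        return "SELECT ..."


if TYPE_CHECKING:

    def is_selectable_widget(w: Widget) -> TypeGuard[SelectableWidget]:
        ...

else:
    is_selectable_widget = operator.attrgetter("_is_selectable")


def render(w: Widget) -> str:
    # within this branch the type checker treats w as SelectableWidget, so
    # calling .select() type-checks; at runtime the attrgetter simply
    # returns the boolean class-level flag
    if is_selectable_widget(w):
        return w.select()
    return "<not selectable>"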
+ def is_select_base( t: Union[Executable, ReturnsRows] ) -> TypeGuard[SelectBase]: @@ -224,6 +238,7 @@ else: is_from_clause = operator.attrgetter("_is_from_clause") is_tuple_type = operator.attrgetter("_is_tuple_type") is_table_value_type = operator.attrgetter("_is_table_value") + is_selectable = operator.attrgetter("is_selectable") is_select_base = operator.attrgetter("_is_select_base") is_select_statement = operator.attrgetter("_is_select_statement") is_table = operator.attrgetter("_is_table") diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f7692dbc2a..f81878d55d 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -218,7 +218,7 @@ def _generative(fn: _Fn) -> _Fn: """ - @util.decorator + @util.decorator # type: ignore def _generative( fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any ) -> _SelfGenerativeType: @@ -244,7 +244,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: for name in names ] - @util.decorator + @util.decorator # type: ignore def check(fn, *args, **kw): # make pylance happy by not including "self" in the argument # list @@ -260,7 +260,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: raise exc.InvalidRequestError(msg) return fn(self, *args, **kw) - return check + return check # type: ignore def _clone(element, **kw): @@ -1750,15 +1750,14 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._collection.append((k, col)) self._colset.update(c for (k, c) in self._collection) - # https://github.com/python/mypy/issues/12610 self._index.update( - (idx, c) for idx, (k, c) in enumerate(self._collection) # type: ignore # noqa: E501 + (idx, c) for idx, (k, c) in enumerate(self._collection) ) for col in replace_col: self.replace(col) def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: - self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501 + self._populate_separate_keys((col.key, col) for col in iter_) def remove(self, column: _NAMEDCOL) -> None: if column not in self._colset: @@ -1772,9 +1771,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): (k, c) for (k, c) in self._collection if c is not column ] - # https://github.com/python/mypy/issues/12610 self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501 + {idx: col for idx, (k, col) in enumerate(self._collection)} ) # delete higher index del self._index[len(self._collection)] @@ -1827,9 +1825,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._index.clear() - # https://github.com/python/mypy/issues/12610 self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501 + {idx: col for idx, (k, col) in enumerate(self._collection)} ) self._index.update(self._collection) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 4bf45da9cb..0659709ab4 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -214,6 +214,7 @@ def expect( Type[roles.ExpressionElementRole[Any]], Type[roles.LimitOffsetRole], Type[roles.WhereHavingRole], + Type[roles.OnClauseRole], ], element: Any, **kw: Any, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 938be0f817..c524a2602c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1078,7 +1078,7 @@ class SQLCompiler(Compiled): return list(self.insert_prefetch) + 
list(self.update_prefetch) @util.memoized_property - def _global_attributes(self): + def _global_attributes(self) -> Dict[Any, Any]: return {} @util.memoized_instancemethod diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 6ac7c24483..052af6ac9d 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -14,6 +14,7 @@ from __future__ import annotations import typing from typing import Any from typing import Callable +from typing import Iterable from typing import List from typing import Optional from typing import Sequence as typing_Sequence @@ -1143,7 +1144,7 @@ class SchemaDropper(InvokeDDLBase): def sort_tables( - tables: typing_Sequence["Table"], + tables: Iterable["Table"], skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, extra_dependencies: Optional[ typing_Sequence[Tuple["Table", "Table"]] diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ea0fa79962..34d5127ab7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -293,11 +293,18 @@ class ClauseElement( __visit_name__ = "clause" - _propagate_attrs: _PropagateAttrsType = util.immutabledict() - """like annotations, however these propagate outwards liberally - as SQL constructs are built, and are set up at construction time. + if TYPE_CHECKING: - """ + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. + + """ + ... + + else: + _propagate_attrs = util.EMPTY_DICT @util.ro_memoized_property def description(self) -> Optional[str]: @@ -343,7 +350,9 @@ class ClauseElement( def _from_objects(self) -> List[FromClause]: return [] - def _set_propagate_attrs(self, values): + def _set_propagate_attrs( + self: SelfClauseElement, values: Mapping[str, Any] + ) -> SelfClauseElement: # usually, self._propagate_attrs is empty here. one case where it's # not is a subquery against ORM select, that is then pulled as a # property of an aliased class. should all be good @@ -526,13 +535,10 @@ class ClauseElement( if unique: bind._convert_to_unique() - return cast( - SelfClauseElement, - cloned_traverse( - self, - {"maintain_key": True, "detect_subquery_cols": True}, - {"bindparam": visit_bindparam}, - ), + return cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, ) def compare(self, other, **kw): @@ -730,7 +736,9 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): # redefined with the specific types returned by ColumnElement hierarchies if typing.TYPE_CHECKING: - _propagate_attrs: _PropagateAttrsType + @util.non_memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + ... def operate( self, op: OperatorType, *other: Any, **kwargs: Any @@ -2064,10 +2072,11 @@ class TextClause( roles.OrderByRole, roles.FromClauseRole, roles.SelectStatementRole, - roles.BinaryElementRole[Any], roles.InElementRole, Executable, DQLDMLClauseElement, + roles.BinaryElementRole[Any], + inspection.Inspectable["TextClause"], ): """Represent a literal SQL text fragment. 
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index da15c305fc..4b220188f7 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -444,7 +444,7 @@ class DeferredLambdaElement(LambdaElement): def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) - def _resolve_with_args(self, *lambda_args): + def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement: assert isinstance(self._rec, AnalyzedFunction) tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) @@ -478,7 +478,7 @@ class DeferredLambdaElement(LambdaElement): for deferred_copy_internals in self._transforms: expr = deferred_copy_internals(expr) - return expr + return expr # type: ignore def _copy_internals( self, clone=_clone, deferred_copy_internals=None, **kw diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 577d868fdb..231c70a5ba 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -22,9 +22,7 @@ if TYPE_CHECKING: from .base import _EntityNamespace from .base import ColumnCollection from .base import ReadOnlyColumnCollection - from .elements import ClauseElement from .elements import ColumnClause - from .elements import ColumnElement from .elements import Label from .elements import NamedColumn from .selectable import _SelectIterable @@ -271,7 +269,14 @@ class StatementRole(SQLRole): __slots__ = () _role_name = "Executable SQL or text() construct" - _propagate_attrs: _PropagateAttrsType = util.immutabledict() + if TYPE_CHECKING: + + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + ... + + else: + _propagate_attrs = util.EMPTY_DICT class SelectStatementRole(StatementRole, ReturnsRowsRole): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 92b9cc62c2..52ba60a62c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -144,9 +144,9 @@ class SchemaConst(Enum): NULL_UNSPECIFIED = 3 """Symbol indicating the "nullable" keyword was not passed to a Column. - Normally we would expect None to be acceptable for this but some backends - such as that of SQL Server place special signficance on a "nullability" - value of None. + This is used to distinguish between the use case of passing + ``nullable=None`` to a :class:`.Column`, which has special meaning + on some backends such as SQL Server. """ @@ -308,7 +308,9 @@ class HasSchemaAttr(SchemaItem): schema: Optional[str] -class Table(DialectKWArgs, HasSchemaAttr, TableClause): +class Table( + DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"] +): r"""Represent a table in a database. e.g.:: @@ -1318,117 +1320,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): inherit_cache = True key: str - @overload - def __init__( - self, - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... 
- - @overload - def __init__( - self, - __name: str, - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload def __init__( self, - __type: _TypeEngineArgument[_T], - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload - def __init__( - self, - __name: str, - __type: _TypeEngineArgument[_T], + __name_pos: Optional[ + Union[str, _TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[_T], SchemaEventTarget] + ] = None, *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - def __init__( - self, - *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget], name: Optional[str] = None, type_: Optional[_TypeEngineArgument[_T]] = None, autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", @@ -1440,7 +1340,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): info: Optional[_InfoType] = None, nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, + ] = SchemaConst.NULL_UNSPECIFIED, onupdate: Optional[Any] = None, primary_key: bool = False, server_default: Optional[_ServerDefaultType] = None, @@ -1953,7 +1853,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ # noqa: E501, RST201, RST202 - l_args = list(args) + l_args = [__name_pos, __type_pos] + list(args) del args if l_args: @@ -1963,6 +1863,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "May not pass name positionally and as a keyword." 
) name = l_args.pop(0) # type: ignore + elif l_args[0] is None: + l_args.pop(0) if l_args: coltype = l_args[0] @@ -1972,6 +1874,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "May not pass type_ positionally and as a keyword." ) type_ = l_args.pop(0) # type: ignore + elif l_args[0] is None: + l_args.pop(0) if name is not None: name = quoted_name(name, quote) @@ -1989,7 +1893,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): self.primary_key = primary_key self._user_defined_nullable = udn = nullable - if udn is not NULL_UNSPECIFIED: self.nullable = udn else: @@ -5128,7 +5031,7 @@ class MetaData(HasSchemaAttr): def clear(self) -> None: """Clear all Table objects from this MetaData.""" - dict.clear(self.tables) + dict.clear(self.tables) # type: ignore self._schemas.clear() self._fk_memos.clear() diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index aab3c678c5..9d4d1d6c79 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1223,7 +1223,9 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns = [c for c in self.left.c] + [c for c in self.right.c] + columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ + c for c in self.right.c + ] self.primary_key.extend( # type: ignore sqlutil.reduce_columns( diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 2843431549..d08fef60a9 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -17,7 +17,9 @@ from typing import AbstractSet from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import Optional @@ -32,15 +34,15 @@ from . import coercions from . import operators from . import roles from . import visitors +from ._typing import is_text_clause from .annotation import _deep_annotate as _deep_annotate from .annotation import _deep_deannotate as _deep_deannotate from .annotation import _shallow_annotate as _shallow_annotate from .base import _expand_cloned from .base import _from_objects -from .base import ColumnSet -from .cache_key import HasCacheKey # noqa -from .ddl import sort_tables # noqa -from .elements import _find_columns +from .cache_key import HasCacheKey as HasCacheKey +from .ddl import sort_tables as sort_tables +from .elements import _find_columns as _find_columns from .elements import _label_reference from .elements import _textual_label_reference from .elements import BindParameter @@ -67,10 +69,13 @@ from ..util.typing import Protocol if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument + from .elements import TextClause from .roles import FromClauseRole from .selectable import _JoinTargetElement from .selectable import _OnClauseElement + from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType from .visitors import ExternallyTraversible @@ -752,7 +757,29 @@ def splice_joins( return ret -def reduce_columns(columns, *clauses, **kw): +@overload +def reduce_columns( + columns: Iterable[ColumnElement[Any]], + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[ColumnElement[Any]]: + ... 
+ + +@overload +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[Union[ColumnElement[Any], TextClause]]: + ... + + +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Collection[Union[ColumnElement[Any], TextClause]]: r"""given a list of columns, return a 'reduced' set based on natural equivalents. @@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw): ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) only_synonyms = kw.pop("only_synonyms", False) - columns = util.ordered_column_set(columns) + column_set = util.OrderedSet(columns) + cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference( + c for c in column_set if is_text_clause(c) # type: ignore + ) omit = util.column_set() - for col in columns: + for col in cset_no_text: for fk in chain(*[c.foreign_keys for c in col.proxy_set]): - for c in columns: + for c in cset_no_text: if c is col: continue try: @@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw): def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)]) + chain( + *[c.proxy_set for c in cset_no_text.difference(omit)] + ) ) if binary.left in cols and binary.right in cols: - for c in reversed(columns): + for c in reversed(cset_no_text): if c.shares_lineage(binary.right) and ( not only_synonyms or c.name == binary.left.name ): @@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw): if clause is not None: visitors.traverse(clause, {}, {"binary": visit_binary}) - return ColumnSet(columns.difference(omit)) + return column_set.difference(omit) def criterion_as_pairs( @@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, adapt_on_names: bool = False, @@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): class _ColumnLookup(Protocol): - def __getitem__( - self, key: ColumnElement[Any] - ) -> Optional[ColumnElement[Any]]: + @overload + def __getitem__(self, key: None) -> None: + ... + + @overload + def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: + ... + + @overload + def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: + ... + + @overload + def __getitem__(self, key: _ET) -> _ET: + ... + + def __getitem__(self, key: Any) -> Any: ... @@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, adapt_required: bool = False, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, @@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter): return ac - def traverse(self, obj): + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload + def traverse(self, obj: _ET) -> _ET: + ... 
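    # the two overloads above give callers precise signatures: passing None
    # returns None, and passing a traversible element returns an element of
    # the same type; the untyped runtime implementation that follows simply
    # looks the object up in the adapter's column lookup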
+ + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: return self.columns[obj] def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: @@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter): adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process - def adapt_check_present(self, col): + def adapt_check_present( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: newcol = self.columns[col] if newcol is col and self._corresponding_column(col, True) is None: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7363f9ddc2..e0a66fbcf4 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -961,12 +961,16 @@ def cloned_traverse( ... +# a bit of controversy here, as the clone of the lead element +# *could* in theory replace with an entirely different kind of element. +# however this is really not how cloned_traverse is ever used internally +# at least. @overload def cloned_traverse( - obj: ExternallyTraversible, + obj: _ET, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: +) -> _ET: ... diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index b908585128..16924a0a1b 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -166,14 +166,6 @@ def setup_options(make_option): help="write out generated follower idents to , " "when -n is used", ) - make_option( - "--reversetop", - action="store_true", - dest="reversetop", - default=False, - help="Use a random-ordering set implementation in the ORM " - "(helps reveal dependency issues)", - ) make_option( "--requirements", action="callback", @@ -475,14 +467,6 @@ def _prep_testing_database(options, file_config): provision.drop_all_schema_objects(cfg, cfg.db) -@post -def _reverse_topological(options, file_config): - if options.reversetop: - from sqlalchemy.orm.util import randomize_unitofwork - - randomize_unitofwork() - - @post def _post_setup_options(opt, file_config): from sqlalchemy.testing import config diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 086b008de6..ed69450903 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -26,6 +26,7 @@ from typing import Mapping from typing import NoReturn from typing import Optional from typing import overload +from typing import Sequence from typing import Set from typing import Tuple from typing import TypeVar @@ -287,8 +288,8 @@ OrderedDict = dict sort_dictionary = _ordered_dictionary_sort -class WeakSequence: - def __init__(self, __elements=()): +class WeakSequence(Sequence[_T]): + def __init__(self, __elements: Sequence[_T] = ()): # adapted from weakref.WeakKeyDictionary, prevent reference # cycles in the collection itself def _remove(item, selfref=weakref.ref(self)): diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 88deac28f8..b02bca28f8 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -92,7 +92,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): new = dict.__new__(self.__class__) dict.__init__(new, self) - dict.update(new, __d) + dict.update(new, __d) # type: ignore return new def _union_w_kw( @@ -105,7 +105,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): new = dict.__new__(self.__class__) dict.__init__(new, 
self) if __d: - dict.update(new, __d) + dict.update(new, __d) # type: ignore dict.update(new, kw) # type: ignore return new @@ -118,7 +118,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): if new is None: new = dict.__new__(self.__class__) dict.__init__(new, self) - dict.update(new, d) + dict.update(new, d) # type: ignore if new is None: return self diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 5c536b675f..7c80ef4e02 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -13,7 +13,6 @@ from __future__ import annotations import re from typing import Any from typing import Callable -from typing import cast from typing import Dict from typing import Match from typing import Optional @@ -79,7 +78,7 @@ def warn_deprecated_limited( def deprecated_cls( - version: str, message: str, constructor: str = "__init__" + version: str, message: str, constructor: Optional[str] = "__init__" ) -> Callable[[Type[_T]], Type[_T]]: header = ".. deprecated:: %s %s" % (version, (message or "")) @@ -288,7 +287,9 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]: check_any_kw = spec.varkw - @decorator + # latest mypy has opinions here, not sure if they implemented + # Concatenate or something + @decorator # type: ignore def warned(fn: _F, *args: Any, **kwargs: Any) -> _F: for m in check_defaults: if (defaults[m] is None and kwargs[m] is not None) or ( @@ -332,7 +333,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]: for param, (version, message) in specs.items() }, ) - decorated = cast(_F, warned)(fn) + decorated = warned(fn) # type: ignore decorated.__doc__ = doc return decorated # type: ignore[no-any-return] @@ -352,7 +353,7 @@ def _sanitize_restructured_text(text: str) -> str: def _decorate_cls_with_warning( cls: Type[_T], - constructor: str, + constructor: Optional[str], wtype: Type[exc.SADeprecationWarning], message: str, version: str, @@ -418,7 +419,7 @@ def _decorate_with_warning( else: doc_only = "" - @decorator + @decorator # type: ignore def warned(fn: _F, *args: Any, **kwargs: Any) -> _F: skip_warning = not enable_warnings or kwargs.pop( "_sa_skip_warning", False @@ -435,9 +436,9 @@ def _decorate_with_warning( doc = inject_docstring_text(doc, docstring_header, 1) - decorated = cast(_F, warned)(func) + decorated = warned(func) # type: ignore decorated.__doc__ = doc - decorated._sa_warn = lambda: _warn_with_version( + decorated._sa_warn = lambda: _warn_with_version( # type: ignore message, version, wtype, stacklevel=3 ) return decorated # type: ignore[no-any-return] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 9b3692d595..49c5d693af 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -238,6 +238,8 @@ def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]: _Fn = TypeVar("_Fn", bound="Callable[..., Any]") +# this seems to be in flux in recent mypy versions + def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]: """A signature-matching decorator factory.""" diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index 260250b2cf..ee3227d775 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -23,6 +23,8 @@ _FN = TypeVar("_FN", bound=Callable[..., Any]) if TYPE_CHECKING: from sqlalchemy.engine import default as engine_default + from sqlalchemy.orm import descriptor_props as orm_descriptor_props + from 
sqlalchemy.orm import relationships as orm_relationships from sqlalchemy.orm import session as orm_session from sqlalchemy.orm import util as orm_util from sqlalchemy.sql import dml as sql_dml diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index b3f3b93870..d192dc06bf 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -187,7 +187,9 @@ def is_union(type_): return is_origin_of(type_, "Union") -def is_origin_of(type_, *names, module=None): +def is_origin_of( + type_: Any, *names: str, module: Optional[str] = None +) -> bool: """return True if the given type has an __origin__ with the given name and optional module.""" @@ -200,7 +202,7 @@ def is_origin_of(type_, *names, module=None): ) -def _get_type_name(type_): +def _get_type_name(type_: Type[Any]) -> str: if compat.py310: return type_.__name__ else: @@ -208,4 +210,4 @@ def _get_type_name(type_): if typ_name is None: typ_name = getattr(type_, "_name", None) - return typ_name + return typ_name # type: ignore diff --git a/pyproject.toml b/pyproject.toml index e727ee1e49..d16f03c032 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ incremental = true [[tool.mypy.overrides]] -# ad-hoc ignores +##################################################################### +# modules / packages explicitly not checked by Mypy at all right now. module = [ "sqlalchemy.engine.reflection", # interim, should be strict @@ -86,7 +87,8 @@ module = [ warn_unused_ignores = false ignore_errors = true -# strict checking +################################################ +# modules explicitly for Mypy strict checking [[tool.mypy.overrides]] module = [ @@ -98,6 +100,11 @@ module = [ "sqlalchemy.engine.*", "sqlalchemy.pool.*", + # uncomment, trying to make sure mypy + # is at a baseline + # "sqlalchemy.orm._orm_constructors", + + "sqlalchemy.orm.path_registry", "sqlalchemy.orm.scoping", "sqlalchemy.orm.session", "sqlalchemy.orm.state", @@ -114,7 +121,8 @@ warn_unused_ignores = false ignore_errors = false strict = true -# partial checking +################################################ +# modules explicitly for Mypy non-strict checking [[tool.mypy.overrides]] module = [ @@ -135,6 +143,12 @@ module = [ "sqlalchemy.sql.traversals", "sqlalchemy.sql.util", + "sqlalchemy.orm._orm_constructors", + + "sqlalchemy.orm.interfaces", + "sqlalchemy.orm.mapper", + "sqlalchemy.orm.util", + "sqlalchemy.util.*", ] diff --git a/test/ext/mypy/plain_files/association_proxy_one.py b/test/ext/mypy/plain_files/association_proxy_one.py index c5c897956a..e8b57a0c02 100644 --- a/test/ext/mypy/plain_files/association_proxy_one.py +++ b/test/ext/mypy/plain_files/association_proxy_one.py @@ -40,8 +40,8 @@ class Address(Base): u1 = User() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*\[builtins.str\]\] + # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\] reveal_type(User.email_addresses) - # EXPECTED_TYPE: builtins.set\*\[builtins.str\] + # EXPECTED_TYPE: builtins.set\*?\[builtins.str\] reveal_type(u1.email_addresses) diff --git a/test/ext/mypy/plain_files/experimental_relationship.py b/test/ext/mypy/plain_files/experimental_relationship.py index e97a9598b0..fe2742072c 100644 --- a/test/ext/mypy/plain_files/experimental_relationship.py +++ b/test/ext/mypy/plain_files/experimental_relationship.py @@ -8,7 +8,6 @@ from typing import Set from sqlalchemy import ForeignKey from sqlalchemy import Integer -from 
sqlalchemy import String from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -42,8 +41,8 @@ class Address(Base): id = mapped_column(Integer, primary_key=True) user_id = mapped_column(ForeignKey("user.id")) - email = mapped_column(String, nullable=False) - email_name = mapped_column("email_name", String, nullable=False) + email: Mapped[str] + email_name: Mapped[str] = mapped_column("email_name") user_style_one: Mapped[User] = relationship() user_style_two: Mapped["User"] = relationship() @@ -56,14 +55,14 @@ if typing.TYPE_CHECKING: # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] reveal_type(User.extra_name) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] reveal_type(Address.email) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] reveal_type(Address.email_name) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[experimental_relationship.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*\[experimental_relationship.Address\]\] + # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\] reveal_type(User.addresses_style_two) diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/ext/mypy/plain_files/hybrid_one.py index 7d97024afe..d9f97ebcff 100644 --- a/test/ext/mypy/plain_files/hybrid_one.py +++ b/test/ext/mypy/plain_files/hybrid_one.py @@ -47,7 +47,7 @@ expr2 = Interval.contains(7) expr3 = Interval.intersects(i2) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: builtins.int\* + # EXPECTED_TYPE: builtins.int\*? reveal_type(i1.length) # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] diff --git a/test/ext/mypy/plain_files/hybrid_two.py b/test/ext/mypy/plain_files/hybrid_two.py index 6bfabbd30a..ab2970656e 100644 --- a/test/ext/mypy/plain_files/hybrid_two.py +++ b/test/ext/mypy/plain_files/hybrid_two.py @@ -69,10 +69,10 @@ expr3 = Interval.radius.in_([0.5, 5.2]) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: builtins.int\* + # EXPECTED_TYPE: builtins.int\*? reveal_type(i1.length) - # EXPECTED_TYPE: builtins.float\* + # EXPECTED_TYPE: builtins.float\*? reveal_type(i2.radius) # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] diff --git a/test/ext/mypy/plain_files/mapped_column.py b/test/ext/mypy/plain_files/mapped_column.py index b20beeb3a3..14f4ad845a 100644 --- a/test/ext/mypy/plain_files/mapped_column.py +++ b/test/ext/mypy/plain_files/mapped_column.py @@ -14,68 +14,67 @@ class Base(DeclarativeBase): class X(Base): __tablename__ = "x" + # these are fine - pk, column is not null, have the attribute be + # non-optional, fine id: Mapped[int] = mapped_column(primary_key=True) int_id: Mapped[int] = mapped_column(Integer, primary_key=True) - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + # but this is also "fine" because the developer may wish to have the object + # in a pending state with None for the id for some period of time. 
+ # "primary_key=True" will still be interpreted correctly in DDL err_int_id: Mapped[Optional[int]] = mapped_column( Integer, primary_key=True ) - id_name: Mapped[int] = mapped_column("id_name", primary_key=True) - int_id_name: Mapped[int] = mapped_column( - "int_id_name", Integer, primary_key=True - ) - - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + # also fine, X(err_int_id_name) is None when you first make the + # object err_int_id_name: Mapped[Optional[int]] = mapped_column( "err_int_id_name", Integer, primary_key=True ) - # note we arent getting into primary_key=True / nullable=True here. - # leaving that as undefined for now + id_name: Mapped[int] = mapped_column("id_name", primary_key=True) + int_id_name: Mapped[int] = mapped_column( + "int_id_name", Integer, primary_key=True + ) a: Mapped[str] = mapped_column() b: Mapped[Optional[str]] = mapped_column() - # can't detect error because no SQL type is present + # this can't be detected because we don't know the type c: Mapped[str] = mapped_column(nullable=True) d: Mapped[str] = mapped_column(nullable=False) e: Mapped[Optional[str]] = mapped_column(nullable=True) - # can't detect error because no SQL type is present f: Mapped[Optional[str]] = mapped_column(nullable=False) g: Mapped[str] = mapped_column(String) h: Mapped[Optional[str]] = mapped_column(String) - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + # this probably is wrong. however at the moment it seems better to + # decouple the right hand arguments from declaring things about the + # left side since it mostly doesn't work in any case. i: Mapped[str] = mapped_column(String, nullable=True) j: Mapped[str] = mapped_column(String, nullable=False) k: Mapped[Optional[str]] = mapped_column(String, nullable=True) - # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type l: Mapped[Optional[str]] = mapped_column(String, nullable=False) a_name: Mapped[str] = mapped_column("a_name") b_name: Mapped[Optional[str]] = mapped_column("b_name") - # can't detect error because no SQL type is present c_name: Mapped[str] = mapped_column("c_name", nullable=True) d_name: Mapped[str] = mapped_column("d_name", nullable=False) e_name: Mapped[Optional[str]] = mapped_column("e_name", nullable=True) - # can't detect error because no SQL type is present f_name: Mapped[Optional[str]] = mapped_column("f_name", nullable=False) g_name: Mapped[str] = mapped_column("g_name", String) h_name: Mapped[Optional[str]] = mapped_column("h_name", String) - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types i_name: Mapped[str] = mapped_column("i_name", String, nullable=True) j_name: Mapped[str] = mapped_column("j_name", String, nullable=False) @@ -86,7 +85,6 @@ class X(Base): l_name: Mapped[Optional[str]] = mapped_column( "l_name", - # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type String, nullable=False, ) diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py index 78b0a467ce..f9b9b2ffe5 100644 --- a/test/ext/mypy/plain_files/sql_operations.py +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -3,6 +3,7 @@ import typing from sqlalchemy import Boolean from sqlalchemy import column from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String @@ -32,6 +33,7 @@ expr7 = c1 + "x" expr8 = c2 + 10 +stmt = select(column("q")).where(lambda: column("g") > 5).where(c2 == 5) if typing.TYPE_CHECKING: diff --git 
a/test/ext/mypy/plain_files/trad_relationship_uselist.py b/test/ext/mypy/plain_files/trad_relationship_uselist.py index b43dcd594b..af7d292be7 100644 --- a/test/ext/mypy/plain_files/trad_relationship_uselist.py +++ b/test/ext/mypy/plain_files/trad_relationship_uselist.py @@ -101,45 +101,45 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\] reveal_type(User.addresses_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[Any\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_three) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_three_cast) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[Any\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_four) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] reveal_type(Address.user_style_one_typed) # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] reveal_type(Address.user_style_two_typed) # reveal_type(Address.user_style_six) # reveal_type(Address.user_style_seven) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_eight) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_nine) # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_ten) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*\[builtins.str, trad_relationship_uselist.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\] reveal_type(Address.user_style_ten_typed) diff --git a/test/ext/mypy/plain_files/traditional_relationship.py b/test/ext/mypy/plain_files/traditional_relationship.py index 473ccb2824..ce131dd004 100644 --- a/test/ext/mypy/plain_files/traditional_relationship.py +++ b/test/ext/mypy/plain_files/traditional_relationship.py @@ -60,29 +60,29 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.Address\]\] + # EXPECTED_TYPE: 
sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\] reveal_type(User.addresses_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] reveal_type(Address.user_style_one_typed) # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] reveal_type(Address.user_style_two_typed) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] reveal_type(Address.user_style_three) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] reveal_type(Address.user_style_four) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_five) diff --git a/test/ext/mypy/plugin_files/relationship_6255_one.py b/test/ext/mypy/plugin_files/relationship_6255_one.py index 0c8e3c4f64..15961c703a 100644 --- a/test/ext/mypy/plugin_files/relationship_6255_one.py +++ b/test/ext/mypy/plugin_files/relationship_6255_one.py @@ -17,7 +17,7 @@ class User(Base): __tablename__ = "user" id = mapped_column(Integer, primary_key=True) - name = mapped_column(String, nullable=True) + name: Mapped[Optional[str]] = mapped_column(String, nullable=True) addresses: Mapped[List["Address"]] = relationship( "Address", back_populates="user" diff --git a/test/ext/mypy/plugin_files/typing_err3.py b/test/ext/mypy/plugin_files/typing_err3.py index 466e636a78..d29909c3c9 100644 --- a/test/ext/mypy/plugin_files/typing_err3.py +++ b/test/ext/mypy/plugin_files/typing_err3.py @@ -43,11 +43,11 @@ class Address(Base): @declared_attr def email_address(cls) -> Column[String]: - # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa + # EXPECTED_MYPY: Argument 1 to "Column" has incompatible type "bool"; return Column(True) @declared_attr # EXPECTED_MYPY: Invalid type comment or annotation def thisisweird(cls) -> Column(String): - # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa + # EXPECTED_MYPY: Argument 1 to "Column" has incompatible type "bool"; return Column(False) diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index 43443b7f64..7830fcee68 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -18,6 +18,7 @@ from sqlalchemy.orm.instrumentation import register_class from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from 
sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not @@ -110,7 +111,13 @@ class DisposeTest(_ExtBase, fixtures.TestBase): class MyClass: __sa_instrumentation_manager__ = MyClassState - assert attributes.manager_of_class(MyClass) is None + assert attributes.opt_manager_of_class(MyClass) is None + + with expect_raises_message( + sa.orm.exc.UnmappedClassError, + r"Can't locate an instrumentation manager for class .*MyClass", + ): + attributes.manager_of_class(MyClass) t = Table( "my_table", @@ -120,7 +127,7 @@ class DisposeTest(_ExtBase, fixtures.TestBase): registry.map_imperatively(MyClass, t) - manager = attributes.manager_of_class(MyClass) + manager = attributes.opt_manager_of_class(MyClass) is_not(manager, None) is_(manager, MyClass.xyz) @@ -128,7 +135,7 @@ class DisposeTest(_ExtBase, fixtures.TestBase): registry.dispose() - manager = attributes.manager_of_class(MyClass) + manager = attributes.opt_manager_of_class(MyClass) is_(manager, None) assert not hasattr(MyClass, "xyz") @@ -532,9 +539,9 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): register_class(Known) k, u = Known(), Unknown() - assert instrumentation.manager_of_class(Unknown) is None - assert instrumentation.manager_of_class(Known) is not None - assert instrumentation.manager_of_class(None) is None + assert instrumentation.opt_manager_of_class(Unknown) is None + assert instrumentation.opt_manager_of_class(Known) is not None + assert instrumentation.opt_manager_of_class(None) is None assert attributes.instance_state(k) is not None assert_raises((AttributeError, KeyError), attributes.instance_state, u) @@ -583,7 +590,10 @@ class FinderTest(_ExtBase, fixtures.ORMTest): ) register_class(A) - ne_(type(manager_of_class(A)), instrumentation.ClassManager) + ne_( + type(attributes.opt_manager_of_class(A)), + instrumentation.ClassManager, + ) def test_nativeext_submanager(self): class Mine(instrumentation.ClassManager): diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 67abc8971b..b50cbc2bae 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -391,6 +391,33 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): polymorphic_identity=0, ) + def test_polymorphic_on_not_present_col_partial_wpoly(self): + """fix for partial with_polymorphic(). 
+ + found_during_type_annotation + + """ + t2, t1 = self.tables.t2, self.tables.t1 + Parent = self.classes.Parent + t1t2_join = select(t1.c.x).select_from(t1.join(t2)).alias() + + def go(): + t1t2_join_2 = select(t1.c.q).select_from(t1.join(t2)).alias() + self.mapper_registry.map_imperatively( + Parent, + t2, + polymorphic_on=t1t2_join.c.x, + with_polymorphic=("*", None), + polymorphic_identity=0, + ) + + assert_raises_message( + sa_exc.InvalidRequestError, + "Could not map polymorphic_on column 'x' to the mapped table - " + "polymorphic loads will not function properly", + go, + ) + def test_polymorphic_on_not_present_col(self): t2, t1 = self.tables.t2, self.tables.t1 Parent = self.classes.Parent diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index 51f37f0286..5a171e3722 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -9,6 +9,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import attributes from sqlalchemy.orm import backref +from sqlalchemy.orm import CascadeOptions from sqlalchemy.orm import class_mapper from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import exc as orm_exc @@ -4217,11 +4218,12 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest): eq_(s.query(Language).count(), 0) -class ViewonlyFlagWarningTest(fixtures.MappedTest): - """test for #4993. +class ViewonlyCascadeUpdate(fixtures.MappedTest): + """Test that cascades are trimmed accordingly when viewonly is set. - In 1.4, this moves to test/orm/test_cascade, deprecation warnings - become errors, will then be for #4994. + Originally #4993 and #4994 this was raising an error for invalid + cascades. in 2.0 this is simplified to just remove the write + cascades, allows the default cascade to be reasonable. 
""" @@ -4250,21 +4252,17 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): pass @testing.combinations( - ({"delete"}, {"delete"}), + ({"delete"}, {"none"}), ( {"all, delete-orphan"}, - {"delete", "delete-orphan", "merge", "save-update"}, + {"refresh-expire", "expunge"}, ), - ({"save-update, expunge"}, {"save-update"}), + ({"save-update, expunge"}, {"expunge"}), ) - def test_write_cascades(self, setting, settings_that_warn): + def test_write_cascades(self, setting, expected): Order = self.classes.Order - assert_raises_message( - sa_exc.ArgumentError, - 'Cascade settings "%s" apply to persistence ' - "operations" % (", ".join(sorted(settings_that_warn))), - relationship, + r = relationship( Order, primaryjoin=( self.tables.users.c.id == foreign(self.tables.orders.c.user_id) @@ -4272,6 +4270,7 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): cascade=", ".join(sorted(setting)), viewonly=True, ) + eq_(r.cascade, CascadeOptions(expected)) def test_expunge_cascade(self): User, Order, orders, users = ( @@ -4425,23 +4424,6 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): eq_(umapper.attrs["orders"].cascade, set()) - def test_write_cascade_disallowed_w_viewonly(self): - - Order = self.classes.Order - - assert_raises_message( - sa_exc.ArgumentError, - 'Cascade settings "delete, delete-orphan, merge, save-update" ' - "apply to persistence operations", - relationship, - Order, - primaryjoin=( - self.tables.users.c.id == foreign(self.tables.orders.c.user_id) - ), - cascade="all, delete, delete-orphan", - viewonly=True, - ) - class CollectionCascadesNoBackrefTest(fixtures.TestBase): """test the removal of cascade_backrefs behavior diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index 3b10103001..437129af16 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_warns_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ from sqlalchemy.testing.fixtures import fixture_session @@ -739,9 +740,15 @@ class MiscTest(fixtures.MappedTest): assert instrumentation.manager_of_class(A) is manager instrumentation.unregister_class(A) - assert instrumentation.manager_of_class(A) is None + assert instrumentation.opt_manager_of_class(A) is None assert not hasattr(A, "x") + with expect_raises_message( + sa.orm.exc.UnmappedClassError, + r"Can't locate an instrumentation manager for class .*A", + ): + instrumentation.manager_of_class(A) + assert A.__init__ == object.__init__ def test_compileonattr_rel_backref_a(self): diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index f71ab30327..43a34eae4c 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -1374,6 +1374,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL): [User(name="fred")], ) + def test_str_not_accepted_orm_join(self): + User, Address = self.classes.User, self.classes.Address + + with expect_raises_message( + sa.exc.ArgumentError, + "ON clause, typically a SQL expression or ORM " + "relationship attribute expected, got 'addresses'.", + ): + outerjoin(User, Address, "addresses") + def test_aliased_classes(self): User, Address = self.classes.User, self.classes.Address @@ -1409,7 +1419,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): eq_(result, [(user8, address3)]) result = ( - 
q.select_from(outerjoin(User, AdAlias, "addresses")) + q.select_from(outerjoin(User, AdAlias, User.addresses)) .filter(AdAlias.email_address == "ed@bettyboop.com") .all() ) @@ -1504,7 +1514,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): q = sess.query(Order) q = ( q.add_entity(Item) - .select_from(join(Order, Item, "items")) + .select_from(join(Order, Item, Order.items)) .order_by(Order.id, Item.id) ) result = q.all() @@ -1513,7 +1523,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): IAlias = aliased(Item) q = ( sess.query(Order, IAlias) - .select_from(join(Order, IAlias, "items")) + .select_from(join(Order, IAlias, Order.items)) .filter(IAlias.description == "item 3") ) result = q.all() @@ -2569,18 +2579,6 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): s.query(Node).join(Node.children)._compile_context, ) - def test_explicit_join_1(self): - Node = self.classes.Node - n1 = aliased(Node) - n2 = aliased(Node) - - self.assert_compile( - join(Node, n1, "children").join(n2, "children"), - "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " - "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id", - use_default_dialect=True, - ) - def test_explicit_join_2(self): Node = self.classes.Node n1 = aliased(Node) @@ -2598,12 +2596,8 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n1 = aliased(Node) n2 = aliased(Node) - # the join_to_left=False here is unfortunate. the default on this - # flag should be False. self.assert_compile( - join(Node, n1, Node.children).join( - n2, Node.children, join_to_left=False - ), + join(Node, n1, Node.children).join(n2, Node.children), "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id", use_default_dialect=True, @@ -2646,7 +2640,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) - .select_from(join(Node, n1, "children")) + .select_from(join(Node, n1, n1.children)) .filter(n1.data == "n122") .first() ) @@ -2660,7 +2654,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) - .select_from(join(Node, n1, "children").join(n2, "children")) + .select_from(join(Node, n1, Node.children).join(n2, n1.children)) .filter(n2.data == "n122") .first() ) @@ -2676,7 +2670,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) .select_from( - join(Node, n1, Node.id == n1.parent_id).join(n2, "children") + join(Node, n1, Node.id == n1.parent_id).join(n2, n1.children) ) .filter(n2.data == "n122") .first() @@ -2691,7 +2685,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) - .select_from(join(Node, n1, "parent").join(n2, "parent")) + .select_from(join(Node, n1, Node.parent).join(n2, n1.parent)) .filter( and_(Node.data == "n122", n1.data == "n12", n2.data == "n1") ) @@ -2708,7 +2702,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): eq_( list( sess.query(Node) - .select_from(join(Node, n1, "parent").join(n2, "parent")) + .select_from(join(Node, n1, Node.parent).join(n2, n1.parent)) .filter( and_( Node.data == "n122", n1.data == "n12", n2.data == "n1" @@ -3085,7 +3079,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest): n1 = aliased(Node) eq_( sess.query(Node) - .select_from(join(Node, n1, "children")) + .select_from(join(Node, n1, Node.children)) .filter(n1.data.in_(["n3", "n7"])) .order_by(Node.id) .all(), diff --git a/test/orm/test_mapper.py 
b/test/orm/test_mapper.py index 980c82fbe2..d8cc48939b 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2,6 +2,7 @@ import logging import logging.handlers import sqlalchemy as sa +from sqlalchemy import column from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer @@ -9,6 +10,7 @@ from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import String +from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import util from sqlalchemy.engine import default @@ -132,6 +134,22 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): ): self.mapper(User, users) + def test_no_table(self): + """test new error condition raised for table=None + + found_during_type_annotation + + """ + + User = self.classes.User + + with expect_raises_message( + sa.exc.ArgumentError, + r"Mapper\[User\(None\)\] has None for a primary table " + r"argument and does not specify 'inherits'", + ): + self.mapper(User, None) + def test_cant_call_legacy_constructor_directly(self): users, User = ( self.tables.users, @@ -341,6 +359,34 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): s, ) + def test_no_tableclause(self): + """It's not tested for a Mapper to have lower-case table() objects + as part of its collection of tables, and in particular these objects + won't report on constraints or primary keys, which while this doesn't + necessarily disqualify them from being part of a mapper, we don't + have assumptions figured out right now to accommodate them. + + found_during_type_annotation + + """ + User = self.classes.User + users = self.tables.users + + address = table( + "address", + column("address_id", Integer), + column("user_id", Integer), + ) + + with expect_raises_message( + sa.exc.ArgumentError, + "ORM mappings can only be made against schema-level Table " + "objects, not TableClause; got tableclause 'address'", + ): + self.mapper_registry.map_imperatively( + User, users.join(address, users.c.id == address.c.user_id) + ) + def test_reconfigure_on_other_mapper(self): """A configure trigger on an already-configured mapper still triggers a check against all mappers.""" @@ -666,7 +712,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): (column_property, (users.c.name,)), (relationship, (Address,)), (composite, (MyComposite, "id", "name")), - (synonym, "foo"), + (synonym, ("foo",)), ]: obj = constructor(info={"x": "y"}, *args) eq_(obj.info, {"x": "y"}) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 96759e3889..d6fadc449c 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -11,7 +11,6 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes -from sqlalchemy.orm import class_mapper from sqlalchemy.orm import column_property from sqlalchemy.orm import contains_eager from sqlalchemy.orm import defaultload @@ -82,8 +81,7 @@ class PathTest: r = [] for i, item in enumerate(path): if i % 2 == 0: - if isinstance(item, type): - item = class_mapper(item) + item = inspect(item) else: if isinstance(item, str): item = inspect(r[-1]).mapper.attrs[item] diff --git a/test/orm/test_query.py b/test/orm/test_query.py index d0c8f41084..55414364c5 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -6389,6 +6389,23 @@ class ParentTest(QueryTest, AssertsCompiledSQL): Order(description="order 5"), ] == o + def test_invalid_property(self): + 
"""Test if with_parent is passed a non-relationship + + found_during_type_annotation + + """ + User, Address = self.classes.User, self.classes.Address + + sess = fixture_session() + u1 = sess.get(User, 7) + with expect_raises_message( + sa_exc.ArgumentError, + r"Expected relationship property for with_parent\(\), " + "got User.name", + ): + with_parent(u1, User.name) + def test_select_from(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index 03c31dc0ff..c829582fdf 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -5,6 +5,7 @@ from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import Table from sqlalchemy import testing +from sqlalchemy.engine import result from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased @@ -465,8 +466,7 @@ class IdentityKeyTest(_fixtures.FixtureTest): def _cases(): return testing.combinations( - (orm_util,), - (Session,), + (orm_util,), (Session,), argnames="ormutil" ) @_cases() @@ -504,12 +504,29 @@ class IdentityKeyTest(_fixtures.FixtureTest): eq_(key, (User, (u.id,), None)) @_cases() - def test_identity_key_3(self, ormutil): + @testing.combinations("dict", "row", "mapping", argnames="rowtype") + def test_identity_key_3(self, ormutil, rowtype): + """test a real Row works with identity_key. + + this was broken w/ 1.4 future mode as we are assuming a mapping + here. to prevent regressions, identity_key now accepts any of + dict, RowMapping, Row for the "row". + + found_during_type_annotation + + + """ User, users = self.classes.User, self.tables.users self.mapper_registry.map_imperatively(User, users) - row = {users.c.id: 1, users.c.name: "Frank"} + if rowtype == "dict": + row = {users.c.id: 1, users.c.name: "Frank"} + elif rowtype in ("mapping", "row"): + row = result.result_tuple([users.c.id, users.c.name])((1, "Frank")) + if rowtype == "mapping": + row = row._mapping + key = ormutil.identity_key(User, row=row) eq_(key, (User, (1,), None)) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 9fdc519389..7fa39825c4 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -776,6 +776,28 @@ class SelectableTest( "table1.col3, table1.colx FROM table1) AS anon_1", ) + def test_reduce_cols_odd_expressions(self): + """test util.reduce_columns() works with text, non-col expressions + in a SELECT. + + found_during_type_annotation + + """ + + stmt = select( + table1.c.col1, + table1.c.col3 * 5, + text("some_expr"), + table2.c.col2, + func.foo(), + ).join(table2) + self.assert_compile( + stmt.reduce_columns(only_synonyms=False), + "SELECT table1.col1, table1.col3 * :col3_1 AS anon_1, " + "some_expr, foo() AS foo_1 FROM table1 JOIN table2 " + "ON table1.col1 = table2.col2", + ) + def test_with_only_generative_no_list(self): s1 = table1.select().scalar_subquery() -- 2.47.2