From: Mike Bayer Date: Fri, 6 May 2022 20:09:52 +0000 (-0400) Subject: revenge of pep 484 X-Git-Tag: rel_2_0_0b1~311^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=18a73fb1d1c267842ead5dacd05a49f4344d8b22;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git revenge of pep 484 trying to get remaining must-haves for ORM Change-Id: I66a3ecbbb8e5ba37c818c8a92737b576ecf012f7 --- diff --git a/doc/build/changelog/unreleased_20/map_decl.rst b/doc/build/changelog/unreleased_20/map_decl.rst new file mode 100644 index 0000000000..9e27f5d8b3 --- /dev/null +++ b/doc/build/changelog/unreleased_20/map_decl.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, orm + + Fixed issue where the :meth:`_orm.registry.map_declaratively` method + would return an internal "mapper config" object and not the + :class:`.Mapper` object as stated in the API documentation. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index f4e22df2db..d5f0d81263 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -21,6 +21,7 @@ from typing import ClassVar from typing import Dict from typing import Iterator from typing import List +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Tuple @@ -53,7 +54,11 @@ _UNPICKLED = util.symbol("unpickled") if typing.TYPE_CHECKING: + from .base import Connection + from .default import DefaultExecutionContext from .interfaces import _DBAPICursorDescription + from .interfaces import DBAPICursor + from .interfaces import Dialect from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType @@ -61,6 +66,7 @@ if typing.TYPE_CHECKING: from .result import _ProcessorsType from ..sql.type_api import _ResultProcessorType + _T = TypeVar("_T", bound=Any) # metadata entry tuple indexes. @@ -235,7 +241,7 @@ class CursorResultMetaData(ResultMetaData): ) = context.result_column_struct num_ctx_cols = len(result_columns) else: - result_columns = ( + result_columns = ( # type: ignore cols_are_ordered ) = ( num_ctx_cols @@ -776,25 +782,53 @@ class ResultFetchStrategy: alternate_cursor_description: Optional[_DBAPICursorDescription] = None - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: raise NotImplementedError() - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: raise NotImplementedError() - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: return - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: raise NotImplementedError() - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: raise NotImplementedError() - def fetchall(self, result): + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: raise NotImplementedError() - def handle_exception(self, result, dbapi_cursor, err): + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: raise err @@ -882,18 +916,32 @@ class CursorFetchStrategy(ResultFetchStrategy): __slots__ = () - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: result.cursor_strategy = _NO_CURSOR_DQL - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: result.cursor_strategy = _NO_CURSOR_DQL - def handle_exception(self, result, dbapi_cursor, err): + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: result.connection._handle_dbapi_exception( err, None, None, dbapi_cursor, result.context ) - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: result.cursor_strategy = BufferedRowCursorFetchStrategy( dbapi_cursor, {"max_row_buffer": num}, @@ -901,7 +949,12 @@ class CursorFetchStrategy(ResultFetchStrategy): growth_factor=0, ) - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: try: row = dbapi_cursor.fetchone() if row is None: @@ -910,7 +963,12 @@ class CursorFetchStrategy(ResultFetchStrategy): except BaseException as e: self.handle_exception(result, dbapi_cursor, e) - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: try: if size is None: l = dbapi_cursor.fetchmany() @@ -923,7 +981,11 @@ class CursorFetchStrategy(ResultFetchStrategy): except BaseException as e: self.handle_exception(result, dbapi_cursor, e) - def fetchall(self, result, dbapi_cursor): + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: try: rows = dbapi_cursor.fetchall() result._soft_close() @@ -1163,6 +1225,9 @@ class _NoResultMetaData(ResultMetaData): _NO_RESULT_METADATA = _NoResultMetaData() +SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]") + + class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. @@ -1199,7 +1264,17 @@ class CursorResult(Result[_T]): closed: bool = False _is_cursor = True - def __init__(self, context, cursor_strategy, cursor_description): + context: DefaultExecutionContext + dialect: Dialect + cursor_strategy: ResultFetchStrategy + connection: Connection + + def __init__( + self, + context: DefaultExecutionContext, + cursor_strategy: ResultFetchStrategy, + cursor_description: Optional[_DBAPICursorDescription], + ): self.context = context self.dialect = context.dialect self.cursor = context.cursor @@ -1333,7 +1408,7 @@ class CursorResult(Result[_T]): if not self._soft_closed: cursor = self.cursor - self.cursor = None + self.cursor = None # type: ignore self.connection._safe_close_cursor(cursor) self._soft_closed = True @@ -1605,7 +1680,7 @@ class CursorResult(Result[_T]): return self.dialect.supports_sane_multi_rowcount @util.memoized_property - def rowcount(self): + def rowcount(self) -> int: """Return the 'rowcount' for this result. The 'rowcount' reports the number of rows *matched* @@ -1655,6 +1730,7 @@ class CursorResult(Result[_T]): return self.context.rowcount except BaseException as e: self.cursor_strategy.handle_exception(self, self.cursor, e) + raise # not called @property def lastrowid(self): @@ -1749,7 +1825,7 @@ class CursorResult(Result[_T]): ) return merged_result - def close(self): + def close(self) -> Any: """Close this :class:`_engine.CursorResult`. This closes out the underlying DBAPI cursor corresponding to the @@ -1772,7 +1848,7 @@ class CursorResult(Result[_T]): self._soft_close(hard=True) @_generative - def yield_per(self, num): + def yield_per(self: SelfCursorResult, num: int) -> SelfCursorResult: self._yield_per = num self.cursor_strategy.yield_per(self, self.cursor, num) return self diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 6094ad0fbb..fc114efa3a 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -64,6 +64,7 @@ if typing.TYPE_CHECKING: from .base import Engine from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPICursorDescription from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _IsolationLevel @@ -1285,8 +1286,8 @@ class DefaultExecutionContext(ExecutionContext): def handle_dbapi_exception(self, e): pass - @property - def rowcount(self): + @util.non_memoized_property + def rowcount(self) -> int: return self.cursor.rowcount def supports_sane_rowcount(self): @@ -1304,7 +1305,7 @@ class DefaultExecutionContext(ExecutionContext): strategy = _cursor.BufferedRowCursorFetchStrategy( self.cursor, self.execution_options ) - cursor_description = ( + cursor_description: _DBAPICursorDescription = ( strategy.alternate_cursor_description or self.cursor.description ) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 6410246039..e5414b70f3 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -133,17 +133,7 @@ class DBAPICursor(Protocol): @property def description( self, - ) -> Sequence[ - Tuple[ - str, - "DBAPIType", - Optional[int], - Optional[int], - Optional[int], - Optional[int], - Optional[bool], - ] - ]: + ) -> _DBAPICursorDescription: """The description attribute of the Cursor. .. seealso:: @@ -217,7 +207,15 @@ _DBAPIMultiExecuteParams = Union[ _DBAPIAnyExecuteParams = Union[ _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams ] -_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any] +_DBAPICursorDescription = Tuple[ + str, + "DBAPIType", + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[bool], +] _AnySingleExecuteParams = _DBAPISingleExecuteParams _AnyMultiExecuteParams = _DBAPIMultiExecuteParams @@ -2297,6 +2295,9 @@ class ExecutionContext: """ + engine: Engine + """engine which the Connection is associated with""" + connection: Connection """Connection object which can be freely used by default value generators to execute SQL. This Connection should reference the diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 420ba5c8c3..7db95eac9b 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -53,7 +53,6 @@ from ..orm import ORMDescriptor from ..orm.base import SQLORMOperations from ..sql import operators from ..sql import or_ -from ..sql.elements import SQLCoreOperations from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self @@ -64,8 +63,10 @@ if typing.TYPE_CHECKING: from ..orm.interfaces import MapperProperty from ..orm.interfaces import PropComparator from ..orm.mapper import Mapper + from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _InfoType + _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _T_con = TypeVar("_T_con", bound=Any, contravariant=True) @@ -631,7 +632,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]): @property def _comparator(self) -> PropComparator[Any]: - return self._get_property().comparator + return getattr( # type: ignore + self.owning_class, self.target_collection + ).comparator def __clause_element__(self) -> NoReturn: raise NotImplementedError( @@ -957,7 +960,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]): proxy.setter = setter def _criterion_exists( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[bool]: is_has = kwargs.pop("is_has", None) @@ -969,8 +974,8 @@ class AssociationProxyInstance(SQLORMOperations[_T]): return self._comparator._criterion_exists(inner) if self._target_is_object: - prop = getattr(self.target_class, self.value_attr) - value_expr = prop._criterion_exists(criterion, **kwargs) + attr = getattr(self.target_class, self.value_attr) + value_expr = attr.comparator._criterion_exists(criterion, **kwargs) else: if kwargs: raise exc.ArgumentError( @@ -988,8 +993,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]): return self._comparator._criterion_exists(value_expr) def any( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any - ) -> SQLCoreOperations[Any]: + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: """Produce a proxied 'any' expression using EXISTS. This expression will be a composed product @@ -1010,8 +1017,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]): ) def has( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any - ) -> SQLCoreOperations[Any]: + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: """Produce a proxied 'has' expression using EXISTS. This expression will be a composed product @@ -1069,12 +1078,16 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]): self._ambiguous() def any( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> NoReturn: self._ambiguous() def has( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> NoReturn: self._ambiguous() diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py index 9faf2ed51f..22fa83c58f 100644 --- a/lib/sqlalchemy/ext/declarative/extensions.py +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -8,7 +8,10 @@ """Public API functions and helpers for declarative.""" +from __future__ import annotations +from typing import Callable +from typing import TYPE_CHECKING from ... import inspection from ...orm import exc as orm_exc @@ -20,6 +23,10 @@ from ...orm.util import polymorphic_union from ...schema import Table from ...util import OrderedDict +if TYPE_CHECKING: + from ...engine.reflection import Inspector + from ...sql.schema import MetaData + class ConcreteBase: """A helper class for 'concrete' declarative mappings. @@ -380,31 +387,36 @@ class DeferredReflection: mapper = thingy.cls.__mapper__ metadata = mapper.class_.metadata for rel in mapper._props.values(): + if ( isinstance(rel, relationships.Relationship) - and rel.secondary is not None + and rel._init_args.secondary._is_populated() ): - if isinstance(rel.secondary, Table): - cls._reflect_table(rel.secondary, insp) - elif isinstance(rel.secondary, str): + + secondary_arg = rel._init_args.secondary + + if isinstance(secondary_arg.argument, Table): + cls._reflect_table(secondary_arg.argument, insp) + elif isinstance(secondary_arg.argument, str): _, resolve_arg = _resolver(rel.parent.class_, rel) - rel.secondary = resolve_arg(rel.secondary) - rel.secondary._resolvers += ( + resolver = resolve_arg( + secondary_arg.argument, True + ) + resolver._resolvers += ( cls._sa_deferred_table_resolver( insp, metadata ), ) - # controversy! do we resolve it here? or leave - # it deferred? I think doing it here is necessary - # so the connection does not leak. - rel.secondary = rel.secondary() + secondary_arg.argument = resolver() @classmethod - def _sa_deferred_table_resolver(cls, inspector, metadata): - def _resolve(key): + def _sa_deferred_table_resolver( + cls, inspector: Inspector, metadata: MetaData + ) -> Callable[[str], Table]: + def _resolve(key: str) -> Table: t1 = Table(key, metadata) cls._reflect_table(t1, inspector) return t1 diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index ea558495b4..accfa8949c 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -1305,8 +1305,8 @@ class Comparator(interfaces.PropComparator[_T]): return ret_expr @util.non_memoized_property - def property(self) -> Optional[interfaces.MapperProperty[_T]]: - return None + def property(self) -> interfaces.MapperProperty[_T]: + raise NotImplementedError() def adapt_to_entity( self, adapt_to_entity: AliasedInsp[Any] @@ -1344,7 +1344,7 @@ class ExprComparator(Comparator[_T]): return [(self.expression, value)] @util.non_memoized_property - def property(self) -> Optional[MapperProperty[_T]]: + def property(self) -> MapperProperty[_T]: # this accessor is not normally used, however is accessed by things # like ORM synonyms if the hybrid is used in this context; the # .property attribute is not necessarily accessible diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 560db9817b..18a18bd800 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls from __future__ import annotations @@ -12,6 +11,8 @@ import typing from typing import Any from typing import Callable from typing import Collection +from typing import Iterable +from typing import NoReturn from typing import Optional from typing import overload from typing import Type @@ -45,6 +46,7 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _EntityType from ._typing import _ORMColumnExprArgument + from .descriptor_props import _CC from .descriptor_props import _CompositeAttrType from .interfaces import PropComparator from .mapper import Mapper @@ -54,14 +56,19 @@ if TYPE_CHECKING: from .relationships import _ORMColCollectionArgument from .relationships import _ORMOrderByArgument from .relationships import _RelationshipJoinConditionArgument + from .session import _SessionBind from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _FromClauseArgument from ..sql._typing import _InfoType + from ..sql._typing import _OnClauseArgument from ..sql._typing import _TypeEngineArgument + from ..sql.elements import ColumnElement from ..sql.schema import _ServerDefaultType from ..sql.schema import FetchedValue from ..sql.selectable import Alias from ..sql.selectable import Subquery + _T = typing.TypeVar("_T") @@ -424,10 +431,10 @@ def column_property( @overload def composite( - class_: Type[_T], + class_: Type[_CC], *attrs: _CompositeAttrType[Any], **kwargs: Any, -) -> Composite[_T]: +) -> Composite[_CC]: ... @@ -680,7 +687,7 @@ def with_loader_criteria( def relationship( argument: Optional[_RelationshipArgumentType[Any]] = None, - secondary: Optional[FromClause] = None, + secondary: Optional[Union[FromClause, str]] = None, *, uselist: Optional[bool] = None, collection_class: Optional[ @@ -696,14 +703,14 @@ def relationship( cascade: str = "save-update, merge", viewonly: bool = False, lazy: _LazyLoadArgumentType = "select", - passive_deletes: bool = False, + passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, active_history: bool = False, enable_typechecks: bool = True, foreign_keys: Optional[_ORMColCollectionArgument] = None, remote_side: Optional[_ORMColCollectionArgument] = None, join_depth: Optional[int] = None, - comparator_factory: Optional[Type[PropComparator[Any]]] = None, + comparator_factory: Optional[Type[Relationship.Comparator[Any]]] = None, single_parent: bool = False, innerjoin: bool = False, distinct_target_key: Optional[bool] = None, @@ -1660,10 +1667,19 @@ def synonym( than can be achieved with synonyms. """ - return Synonym(name, map_column, descriptor, comparator_factory, doc, info) + return Synonym( + name, + map_column=map_column, + descriptor=descriptor, + comparator_factory=comparator_factory, + doc=doc, + info=info, + ) -def create_session(bind=None, **kwargs): +def create_session( + bind: Optional[_SessionBind] = None, **kwargs: Any +) -> Session: r"""Create a new :class:`.Session` with no automation enabled by default. @@ -1699,7 +1715,7 @@ def create_session(bind=None, **kwargs): return Session(bind=bind, **kwargs) -def _mapper_fn(*arg, **kw): +def _mapper_fn(*arg: Any, **kw: Any) -> NoReturn: """Placeholder for the now-removed ``mapper()`` function. Classical mappings should be performed using the @@ -1726,7 +1742,9 @@ def _mapper_fn(*arg, **kw): ) -def dynamic_loader(argument, **kw): +def dynamic_loader( + argument: Optional[_RelationshipArgumentType[Any]] = None, **kw: Any +) -> Relationship[Any]: """Construct a dynamically-loading mapper property. This is essentially the same as @@ -1746,7 +1764,7 @@ def dynamic_loader(argument, **kw): return relationship(argument, **kw) -def backref(name, **kwargs): +def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument: """Create a back reference with explicit keyword arguments, which are the same arguments one can send to :func:`relationship`. @@ -1765,7 +1783,11 @@ def backref(name, **kwargs): return (name, kwargs) -def deferred(*columns, **kw): +def deferred( + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + **kw: Any, +) -> ColumnProperty[_T]: r"""Indicate a column-based mapped attribute that by default will not load unless accessed. @@ -1791,7 +1813,8 @@ def deferred(*columns, **kw): :ref:`deferred` """ - return ColumnProperty(deferred=True, *columns, **kw) + kw["deferred"] = True + return ColumnProperty(column, *additional_columns, **kw) def query_expression( @@ -1824,7 +1847,7 @@ def query_expression( return prop -def clear_mappers(): +def clear_mappers() -> None: """Remove all mappers from all classes. .. versionchanged:: 1.4 This function now locates all @@ -2003,16 +2026,16 @@ def aliased( def with_polymorphic( - base, - classes, - selectable=False, - flat=False, - polymorphic_on=None, - aliased=False, - adapt_on_names=False, - innerjoin=False, - _use_mapper_path=False, -): + base: Union[_O, Mapper[_O]], + classes: Iterable[Type[Any]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, +) -> AliasedClass[_O]: """Produce an :class:`.AliasedClass` construct which specifies columns for descendant mappers of the given base. @@ -2096,7 +2119,13 @@ def with_polymorphic( ) -def join(left, right, onclause=None, isouter=False, full=False): +def join( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, +) -> _ORMJoin: r"""Produce an inner join between left and right clauses. :func:`_orm.join` is an extension to the core join interface @@ -2135,7 +2164,12 @@ def join(left, right, onclause=None, isouter=False, full=False): return _ORMJoin(left, right, onclause, isouter, full) -def outerjoin(left, right, onclause=None, full=False): +def outerjoin( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, +) -> _ORMJoin: """Produce a left outer join between left and right clauses. This is the "outer join" version of the :func:`_orm.join` function, diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 29d82340ab..0e624afe2a 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -2,8 +2,8 @@ from __future__ import annotations import operator from typing import Any -from typing import Callable from typing import Dict +from typing import Mapping from typing import Optional from typing import Tuple from typing import Type @@ -20,9 +20,12 @@ from ..util.typing import TypeGuard if TYPE_CHECKING: from .attributes import AttributeImpl from .attributes import CollectionAttributeImpl + from .attributes import HasCollectionAdapter + from .attributes import QueryableAttribute from .base import PassiveFlag from .decl_api import registry as _registry_type from .descriptor_props import _CompositeClassProto + from .interfaces import InspectionAttr from .interfaces import MapperProperty from .interfaces import UserDefinedOption from .mapper import Mapper @@ -30,11 +33,14 @@ if TYPE_CHECKING: from .state import InstanceState from .util import AliasedClass from .util import AliasedInsp + from ..sql._typing import _CE from ..sql.base import ExecutableOption _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + # I would have preferred this were bound=object however it seems # to not travel in all situations when defined in that way. _O = TypeVar("_O", bound=Any) @@ -42,6 +48,12 @@ _O = TypeVar("_O", bound=Any) """ +_OO = TypeVar("_OO", bound=object) +"""The 'ORM mapped object, that's definitely object' type. + +""" + + if TYPE_CHECKING: _RegistryType = _registry_type @@ -54,6 +66,7 @@ _EntityType = Union[ ] +_ClassDict = Mapping[str, Any] _InstanceDict = Dict[str, Any] _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]] @@ -64,10 +77,19 @@ _ORMColumnExprArgument = Union[ roles.ExpressionElementRole[_T], ] -# somehow Protocol didn't want to work for this one -_ORMAdapterProto = Callable[ - [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T] -] + +_ORMCOLEXPR = TypeVar("_ORMCOLEXPR", bound=ColumnElement[Any]) + + +class _ORMAdapterProto(Protocol): + """protocol for the :class:`.AliasedInsp._orm_adapt_element` method + which is a synonym for :class:`.AliasedInsp._adapt_element`. + + + """ + + def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: + ... class _LoaderCallable(Protocol): @@ -96,6 +118,16 @@ if TYPE_CHECKING: def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: ... + def insp_is_attribute( + obj: InspectionAttr, + ) -> TypeGuard[QueryableAttribute[Any]]: + ... + + def attr_is_internal_proxy( + obj: InspectionAttr, + ) -> TypeGuard[QueryableAttribute[Any]]: + ... + def prop_is_relationship( prop: MapperProperty[Any], ) -> TypeGuard[Relationship[Any]]: @@ -106,9 +138,19 @@ if TYPE_CHECKING: ) -> TypeGuard[CollectionAttributeImpl]: ... + def is_has_collection_adapter( + impl: AttributeImpl, + ) -> TypeGuard[HasCollectionAdapter]: + ... + else: insp_is_mapper_property = operator.attrgetter("is_property") insp_is_mapper = operator.attrgetter("is_mapper") insp_is_aliased_class = operator.attrgetter("is_aliased_class") + insp_is_attribute = operator.attrgetter("is_attribute") + attr_is_internal_proxy = operator.attrgetter("_is_internal_proxy") is_collection_impl = operator.attrgetter("collection") prop_is_relationship = operator.attrgetter("_is_relationship") + is_has_collection_adapter = operator.attrgetter( + "_is_has_collection_adapter" + ) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 9aeaeaa272..b5faa7cbf1 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -117,7 +117,9 @@ class NoKey(str): pass -_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]] +_AllPendingType = Sequence[ + Tuple[Optional["InstanceState[Any]"], Optional[object]] +] NO_KEY = NoKey("no name") @@ -798,6 +800,8 @@ class AttributeImpl: supports_population: bool dynamic: bool + _is_has_collection_adapter = False + _replace_token: AttributeEventToken _remove_token: AttributeEventToken _append_token: AttributeEventToken @@ -1140,7 +1144,7 @@ class AttributeImpl: state: InstanceState[Any], dict_: _InstanceDict, value: Any, - initiator: Optional[AttributeEventToken], + initiator: Optional[AttributeEventToken] = None, passive: PassiveFlag = PASSIVE_OFF, check_old: Any = None, pop: bool = False, @@ -1236,7 +1240,7 @@ class ScalarAttributeImpl(AttributeImpl): state: InstanceState[Any], dict_: Dict[str, Any], value: Any, - initiator: Optional[AttributeEventToken], + initiator: Optional[AttributeEventToken] = None, passive: PassiveFlag = PASSIVE_OFF, check_old: Optional[object] = None, pop: bool = False, @@ -1402,7 +1406,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): state: InstanceState[Any], dict_: _InstanceDict, value: Any, - initiator: Optional[AttributeEventToken], + initiator: Optional[AttributeEventToken] = None, passive: PassiveFlag = PASSIVE_OFF, check_old: Any = None, pop: bool = False, @@ -1494,6 +1498,9 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): class HasCollectionAdapter: __slots__ = () + collection: bool + _is_has_collection_adapter = True + def _dispose_previous_collection( self, state: InstanceState[Any], @@ -1508,7 +1515,7 @@ class HasCollectionAdapter: self, state: InstanceState[Any], dict_: _InstanceDict, - user_data: Optional[_AdaptedCollectionProtocol] = None, + user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., ) -> CollectionAdapter: ... @@ -1518,8 +1525,18 @@ class HasCollectionAdapter: self, state: InstanceState[Any], dict_: _InstanceDict, - user_data: Optional[_AdaptedCollectionProtocol] = None, - passive: PassiveFlag = PASSIVE_OFF, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = ..., ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter ]: @@ -1530,12 +1547,25 @@ class HasCollectionAdapter: state: InstanceState[Any], dict_: _InstanceDict, user_data: Optional[_AdaptedCollectionProtocol] = None, - passive: PassiveFlag = PASSIVE_OFF, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter ]: raise NotImplementedError() + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: + raise NotImplementedError() + if TYPE_CHECKING: @@ -1790,7 +1820,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): initiator: Optional[AttributeEventToken], passive: PassiveFlag = PASSIVE_OFF, ) -> None: - collection = self.get_collection(state, dict_, passive=passive) + collection = self.get_collection( + state, dict_, user_data=None, passive=passive + ) if collection is PASSIVE_NO_RESULT: value = self.fire_append_event(state, dict_, value, initiator) assert ( @@ -1810,7 +1842,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): initiator: Optional[AttributeEventToken], passive: PassiveFlag = PASSIVE_OFF, ) -> None: - collection = self.get_collection(state, state.dict, passive=passive) + collection = self.get_collection( + state, state.dict, user_data=None, passive=passive + ) if collection is PASSIVE_NO_RESULT: self.fire_remove_event(state, dict_, value, initiator) assert ( @@ -1844,7 +1878,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, value: Any, initiator: Optional[AttributeEventToken] = None, - passive: PassiveFlag = PASSIVE_OFF, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, check_old: Any = None, pop: bool = False, _adapt: bool = True, @@ -1963,7 +1997,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): self, state: InstanceState[Any], dict_: _InstanceDict, - user_data: Optional[_AdaptedCollectionProtocol] = None, + user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., ) -> CollectionAdapter: ... @@ -1973,7 +2007,17 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): self, state: InstanceState[Any], dict_: _InstanceDict, - user_data: Optional[_AdaptedCollectionProtocol] = None, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., passive: PassiveFlag = PASSIVE_OFF, ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter @@ -2490,7 +2534,7 @@ def register_attribute_impl( impl_class: Optional[Type[AttributeImpl]] = None, backref: Optional[str] = None, **kw: Any, -) -> InstrumentedAttribute[Any]: +) -> QueryableAttribute[Any]: manager = manager_of_class(class_) if uselist: @@ -2599,7 +2643,7 @@ def init_state_collection( attr._dispose_previous_collection(state, old, old_collection, False) user_data = attr._default_value(state, dict_) - adapter = attr.get_collection(state, dict_, user_data) + adapter: CollectionAdapter = attr.get_collection(state, dict_, user_data) adapter._reset_empty() return adapter diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 0ace9b1cb6..63f873fd0e 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -18,6 +18,7 @@ from typing import Any from typing import Callable from typing import Dict from typing import Generic +from typing import no_type_check from typing import Optional from typing import overload from typing import Type @@ -35,17 +36,20 @@ from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly from ..util.typing import Literal -from ..util.typing import Self if typing.TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _ExternalEntityType from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute from .instrumentation import ClassManager + from .interfaces import PropComparator from .mapper import Mapper from .state import InstanceState from .util import AliasedClass + from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement _T = TypeVar("_T", bound=Any) @@ -191,35 +195,34 @@ EXT_CONTINUE = util.symbol("EXT_CONTINUE") EXT_STOP = util.symbol("EXT_STOP") EXT_SKIP = util.symbol("EXT_SKIP") -ONETOMANY = util.symbol( - "ONETOMANY", + +class RelationshipDirection(Enum): + ONETOMANY = 1 """Indicates the one-to-many direction for a :func:`_orm.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """, -) + """ -MANYTOONE = util.symbol( - "MANYTOONE", + MANYTOONE = 2 """Indicates the many-to-one direction for a :func:`_orm.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """, -) + """ -MANYTOMANY = util.symbol( - "MANYTOMANY", + MANYTOMANY = 3 """Indicates the many-to-many direction for a :func:`_orm.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """, -) + """ + + +ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection) class InspectionAttrExtensionType(Enum): @@ -249,7 +252,7 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_F = TypeVar("_F", bound=Callable) +_F = TypeVar("_F", bound=Callable[..., Any]) _Self = TypeVar("_Self") @@ -397,29 +400,34 @@ def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]: return None -def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]: +def _class_to_mapper( + class_or_mapper: Union[Mapper[_T], Type[_T]] +) -> Mapper[_T]: + # can't get mypy to see an overload for this insp = inspection.inspect(class_or_mapper, False) if insp is not None: - return insp.mapper + return insp.mapper # type: ignore else: + assert isinstance(class_or_mapper, type) raise exc.UnmappedClassError(class_or_mapper) def _mapper_or_none( - entity: Union[_T, _InternalEntityType[_T]] + entity: Union[Type[_T], _InternalEntityType[_T]] ) -> Optional[Mapper[_T]]: """Return the :class:`_orm.Mapper` for the given class or None if the class is not mapped. """ + # can't get mypy to see an overload for this insp = inspection.inspect(entity, False) if insp is not None: - return insp.mapper + return insp.mapper # type: ignore else: return None -def _is_mapped_class(entity): +def _is_mapped_class(entity: Any) -> bool: """Return True if the given object is a mapped class, :class:`_orm.Mapper`, or :class:`.AliasedClass`. """ @@ -432,20 +440,13 @@ def _is_mapped_class(entity): ) -def _orm_columns(entity): - insp = inspection.inspect(entity, False) - if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"): - return [c for c in insp.selectable.c] - else: - return [entity] - - -def _is_aliased_class(entity): +def _is_aliased_class(entity: Any) -> bool: insp = inspection.inspect(entity, False) return insp is not None and getattr(insp, "is_aliased_class", False) -def _entity_descriptor(entity, key): +@no_type_check +def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any: """Return a class attribute given an entity and string name. May return :class:`.InstrumentedAttribute` or user-defined @@ -651,16 +652,26 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): if typing.TYPE_CHECKING: - def of_type(self, class_): + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: ... - def and_(self, *criteria): + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[bool]: ... - def any(self, criterion=None, **kwargs): # noqa: A001 + def any( # noqa: A001 + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: ... - def has(self, criterion=None, **kwargs): + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: ... @@ -673,7 +684,9 @@ class ORMDescriptor(Generic[_T], TypingOnly): if typing.TYPE_CHECKING: @overload - def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self: + def __get__( + self, instance: Any, owner: Literal[None] + ) -> ORMDescriptor[_T]: ... @overload diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 473468c6cd..b3fcd29ea3 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """Routines to handle the string class registry used by declarative. @@ -16,7 +15,22 @@ This system allows specification of classes and expressions used in from __future__ import annotations import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Mapping from typing import MutableMapping +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -29,6 +43,14 @@ from .. import exc from .. import inspection from .. import util from ..sql.schema import _get_table_key +from ..util.typing import CallableReference + +if TYPE_CHECKING: + from .relationships import Relationship + from ..sql.schema import MetaData + from ..sql.schema import Table + +_T = TypeVar("_T", bound=Any) _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] @@ -36,10 +58,12 @@ _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] # the _decl_class_registry, which is usually weak referencing. # the internal registries here link to classes with weakrefs and remove # themselves when all references to contained classes are removed. -_registries = set() +_registries: Set[ClsRegistryToken] = set() -def add_class(classname, cls, decl_class_registry): +def add_class( + classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType +) -> None: """Add a class to the _decl_class_registry associated with the given declarative class. @@ -49,13 +73,15 @@ def add_class(classname, cls, decl_class_registry): existing = decl_class_registry[classname] if not isinstance(existing, _MultipleClassMarker): existing = decl_class_registry[classname] = _MultipleClassMarker( - [cls, existing] + [cls, cast("Type[Any]", existing)] ) else: decl_class_registry[classname] = cls try: - root_module = decl_class_registry["_sa_module_registry"] + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) except KeyError: decl_class_registry[ "_sa_module_registry" @@ -79,7 +105,9 @@ def add_class(classname, cls, decl_class_registry): module.add_class(classname, cls) -def remove_class(classname, cls, decl_class_registry): +def remove_class( + classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType +) -> None: if classname in decl_class_registry: existing = decl_class_registry[classname] if isinstance(existing, _MultipleClassMarker): @@ -88,7 +116,9 @@ def remove_class(classname, cls, decl_class_registry): del decl_class_registry[classname] try: - root_module = decl_class_registry["_sa_module_registry"] + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) except KeyError: return @@ -102,7 +132,11 @@ def remove_class(classname, cls, decl_class_registry): module.remove_class(classname, cls) -def _key_is_empty(key, decl_class_registry, test): +def _key_is_empty( + key: str, + decl_class_registry: _ClsRegistryType, + test: Callable[[Any], bool], +) -> bool: """test if a key is empty of a certain object. used for unit tests against the registry to see if garbage collection @@ -124,6 +158,8 @@ def _key_is_empty(key, decl_class_registry, test): for sub_thing in thing.contents: if test(sub_thing): return False + else: + raise NotImplementedError("unknown codepath") else: return not test(thing) @@ -142,20 +178,27 @@ class _MultipleClassMarker(ClsRegistryToken): __slots__ = "on_remove", "contents", "__weakref__" - def __init__(self, classes, on_remove=None): + contents: Set[weakref.ref[Type[Any]]] + on_remove: CallableReference[Optional[Callable[[], None]]] + + def __init__( + self, + classes: Iterable[Type[Any]], + on_remove: Optional[Callable[[], None]] = None, + ): self.on_remove = on_remove self.contents = set( [weakref.ref(item, self._remove_item) for item in classes] ) _registries.add(self) - def remove_item(self, cls): + def remove_item(self, cls: Type[Any]) -> None: self._remove_item(weakref.ref(cls)) - def __iter__(self): + def __iter__(self) -> Generator[Optional[Type[Any]], None, None]: return (ref() for ref in self.contents) - def attempt_get(self, path, key): + def attempt_get(self, path: List[str], key: str) -> Type[Any]: if len(self.contents) > 1: raise exc.InvalidRequestError( 'Multiple classes found for path "%s" ' @@ -170,14 +213,14 @@ class _MultipleClassMarker(ClsRegistryToken): raise NameError(key) return cls - def _remove_item(self, ref): + def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None: self.contents.discard(ref) if not self.contents: _registries.discard(self) if self.on_remove: self.on_remove() - def add_item(self, item): + def add_item(self, item: Type[Any]) -> None: # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, # [ticket:3208] @@ -206,7 +249,12 @@ class _ModuleMarker(ClsRegistryToken): __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" - def __init__(self, name, parent): + parent: Optional[_ModuleMarker] + contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]] + mod_ns: _ModNS + path: List[str] + + def __init__(self, name: str, parent: Optional[_ModuleMarker]): self.parent = parent self.name = name self.contents = {} @@ -217,51 +265,53 @@ class _ModuleMarker(ClsRegistryToken): self.path = [] _registries.add(self) - def __contains__(self, name): + def __contains__(self, name: str) -> bool: return name in self.contents - def __getitem__(self, name): + def __getitem__(self, name: str) -> ClsRegistryToken: return self.contents[name] - def _remove_item(self, name): + def _remove_item(self, name: str) -> None: self.contents.pop(name, None) if not self.contents and self.parent is not None: self.parent._remove_item(self.name) _registries.discard(self) - def resolve_attr(self, key): - return getattr(self.mod_ns, key) + def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]: + return self.mod_ns.__getattr__(key) - def get_module(self, name): + def get_module(self, name: str) -> _ModuleMarker: if name not in self.contents: marker = _ModuleMarker(name, self) self.contents[name] = marker else: - marker = self.contents[name] + marker = cast(_ModuleMarker, self.contents[name]) return marker - def add_class(self, name, cls): + def add_class(self, name: str, cls: Type[Any]) -> None: if name in self.contents: - existing = self.contents[name] + existing = cast(_MultipleClassMarker, self.contents[name]) existing.add_item(cls) else: existing = self.contents[name] = _MultipleClassMarker( [cls], on_remove=lambda: self._remove_item(name) ) - def remove_class(self, name, cls): + def remove_class(self, name: str, cls: Type[Any]) -> None: if name in self.contents: - existing = self.contents[name] + existing = cast(_MultipleClassMarker, self.contents[name]) existing.remove_item(cls) class _ModNS: __slots__ = ("__parent",) - def __init__(self, parent): + __parent: _ModuleMarker + + def __init__(self, parent: _ModuleMarker): self.__parent = parent - def __getattr__(self, key): + def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]: try: value = self.__parent.contents[key] except KeyError: @@ -282,10 +332,12 @@ class _ModNS: class _GetColumns: __slots__ = ("cls",) - def __init__(self, cls): + cls: Type[Any] + + def __init__(self, cls: Type[Any]): self.cls = cls - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: mp = class_mapper(self.cls, configure=False) if mp: if key not in mp.all_orm_descriptors: @@ -296,6 +348,7 @@ class _GetColumns: desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION: + assert isinstance(desc, attributes.QueryableAttribute) prop = desc.property if isinstance(prop, Synonym): key = prop.name @@ -316,15 +369,18 @@ inspection._inspects(_GetColumns)( class _GetTable: __slots__ = "key", "metadata" - def __init__(self, key, metadata): + key: str + metadata: MetaData + + def __init__(self, key: str, metadata: MetaData): self.key = key self.metadata = metadata - def __getattr__(self, key): + def __getattr__(self, key: str) -> Table: return self.metadata.tables[_get_table_key(key, self.key)] -def _determine_container(key, value): +def _determine_container(key: str, value: Any) -> _GetColumns: if isinstance(value, _MultipleClassMarker): value = value.attempt_get([], key) return _GetColumns(value) @@ -341,7 +397,21 @@ class _class_resolver: "favor_tables", ) - def __init__(self, cls, prop, fallback, arg, favor_tables=False): + cls: Type[Any] + prop: Relationship[Any] + fallback: Mapping[str, Any] + arg: str + favor_tables: bool + _resolvers: Tuple[Callable[[str], Any], ...] + + def __init__( + self, + cls: Type[Any], + prop: Relationship[Any], + fallback: Mapping[str, Any], + arg: str, + favor_tables: bool = False, + ): self.cls = cls self.prop = prop self.arg = arg @@ -350,11 +420,12 @@ class _class_resolver: self._resolvers = () self.favor_tables = favor_tables - def _access_cls(self, key): + def _access_cls(self, key: str) -> Any: cls = self.cls manager = attributes.manager_of_class(cls) decl_base = manager.registry + assert decl_base is not None decl_class_registry = decl_base._class_registry metadata = decl_base.metadata @@ -362,7 +433,7 @@ class _class_resolver: if key in metadata.tables: return metadata.tables[key] elif key in metadata._schemas: - return _GetTable(key, cls.metadata) + return _GetTable(key, getattr(cls, "metadata", metadata)) if key in decl_class_registry: return _determine_container(key, decl_class_registry[key]) @@ -371,13 +442,14 @@ class _class_resolver: if key in metadata.tables: return metadata.tables[key] elif key in metadata._schemas: - return _GetTable(key, cls.metadata) + return _GetTable(key, getattr(cls, "metadata", metadata)) - if ( - "_sa_module_registry" in decl_class_registry - and key in decl_class_registry["_sa_module_registry"] + if "_sa_module_registry" in decl_class_registry and key in cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] ): - registry = decl_class_registry["_sa_module_registry"] + registry = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) return registry.resolve_attr(key) elif self._resolvers: for resolv in self._resolvers: @@ -387,7 +459,7 @@ class _class_resolver: return self.fallback[key] - def _raise_for_name(self, name, err): + def _raise_for_name(self, name: str, err: Exception) -> NoReturn: generic_match = re.match(r"(.+)\[(.+)\]", name) if generic_match: @@ -409,7 +481,7 @@ class _class_resolver: % (self.prop.parent, self.arg, name, self.cls) ) from err - def _resolve_name(self): + def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]: name = self.arg d = self._dict rval = None @@ -427,9 +499,11 @@ class _class_resolver: if isinstance(rval, _GetColumns): return rval.cls else: + if TYPE_CHECKING: + assert isinstance(rval, (type, Table, _ModNS)) return rval - def __call__(self): + def __call__(self) -> Any: try: x = eval(self.arg, globals(), self._dict) @@ -441,10 +515,15 @@ class _class_resolver: self._raise_for_name(n.args[0], n) -_fallback_dict = None +_fallback_dict: Mapping[str, Any] = None # type: ignore -def _resolver(cls, prop): +def _resolver( + cls: Type[Any], prop: Relationship[Any] +) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], +]: global _fallback_dict @@ -456,12 +535,14 @@ def _resolver(cls, prop): {"foreign": foreign, "remote": remote} ) - def resolve_arg(arg, favor_tables=False): + def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver: return _class_resolver( cls, prop, _fallback_dict, arg, favor_tables=favor_tables ) - def resolve_name(arg): + def resolve_name( + arg: str, + ) -> Callable[[], Union[Type[Any], Table, _ModNS]]: return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name return resolve_name, resolve_arg diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index da0da0fcfc..78fe89d05f 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -115,6 +115,7 @@ from typing import Collection from typing import Dict from typing import Iterable from typing import List +from typing import NoReturn from typing import Optional from typing import Set from typing import Tuple @@ -130,6 +131,7 @@ from ..util.compat import inspect_getfullargspec from ..util.typing import Protocol if typing.TYPE_CHECKING: + from .attributes import AttributeEventToken from .attributes import CollectionAttributeImpl from .mapped_collection import attribute_mapped_collection from .mapped_collection import column_mapped_collection @@ -500,7 +502,7 @@ class CollectionAdapter: self.invalidated = False self.empty = False - def _warn_invalidated(self): + def _warn_invalidated(self) -> None: util.warn("This collection has been invalidated.") @property @@ -509,7 +511,7 @@ class CollectionAdapter: return self._data() @property - def _referenced_by_owner(self): + def _referenced_by_owner(self) -> bool: """return True if the owner state still refers to this collection. This will return False within a bulk replace operation, @@ -521,7 +523,9 @@ class CollectionAdapter: def bulk_appender(self): return self._data()._sa_appender - def append_with_event(self, item, initiator=None): + def append_with_event( + self, item: Any, initiator: Optional[AttributeEventToken] = None + ) -> None: """Add an entity to the collection, firing mutation events.""" self._data()._sa_appender(item, _sa_initiator=initiator) @@ -533,7 +537,7 @@ class CollectionAdapter: self.empty = True self.owner_state._empty_collections[self._key] = user_data - def _reset_empty(self): + def _reset_empty(self) -> None: assert ( self.empty ), "This collection adapter is not in the 'empty' state" @@ -542,20 +546,20 @@ class CollectionAdapter: self._key ] = self.owner_state._empty_collections.pop(self._key) - def _refuse_empty(self): + def _refuse_empty(self) -> NoReturn: raise sa_exc.InvalidRequestError( "This is a special 'empty' collection which cannot accommodate " "internal mutation operations" ) - def append_without_event(self, item): + def append_without_event(self, item: Any) -> None: """Add or restore an entity to the collection, firing no events.""" if self.empty: self._refuse_empty() self._data()._sa_appender(item, _sa_initiator=False) - def append_multiple_without_event(self, items): + def append_multiple_without_event(self, items: Iterable[Any]) -> None: """Add or restore an entity to the collection, firing no events.""" if self.empty: self._refuse_empty() @@ -566,17 +570,21 @@ class CollectionAdapter: def bulk_remover(self): return self._data()._sa_remover - def remove_with_event(self, item, initiator=None): + def remove_with_event( + self, item: Any, initiator: Optional[AttributeEventToken] = None + ) -> None: """Remove an entity from the collection, firing mutation events.""" self._data()._sa_remover(item, _sa_initiator=initiator) - def remove_without_event(self, item): + def remove_without_event(self, item: Any) -> None: """Remove an entity from the collection, firing no events.""" if self.empty: self._refuse_empty() self._data()._sa_remover(item, _sa_initiator=False) - def clear_with_event(self, initiator=None): + def clear_with_event( + self, initiator: Optional[AttributeEventToken] = None + ) -> None: """Empty the collection, firing a mutation event for each entity.""" if self.empty: @@ -585,7 +593,7 @@ class CollectionAdapter: for item in list(self): remover(item, _sa_initiator=initiator) - def clear_without_event(self): + def clear_without_event(self) -> None: """Empty the collection, firing no events.""" if self.empty: diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 28fea2f9b3..58556bb580 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -12,6 +12,7 @@ import itertools from typing import Any from typing import cast from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Set @@ -43,6 +44,7 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql._typing import _TP from ..sql._typing import is_dml from ..sql._typing import is_insert_update from ..sql._typing import is_select_base @@ -55,22 +57,32 @@ from ..sql.base import Options from ..sql.dml import UpdateBase from ..sql.elements import GroupedElement from ..sql.elements import TextClause -from ..sql.selectable import ExecutableReturnsRows from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import Select from ..sql.selectable import SelectLabelStyle from ..sql.selectable import SelectState +from ..sql.selectable import TypedReturnsRows from ..sql.visitors import InternalTraversal if TYPE_CHECKING: from ._typing import _InternalEntityType + from .loading import PostLoad from .mapper import Mapper from .query import Query + from .session import _BindArguments + from .session import Session + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter + from ..sql._typing import _ColumnsClauseArgument + from ..sql.compiler import SQLCompiler from ..sql.dml import _DMLTableElement from ..sql.elements import ColumnElement + from ..sql.selectable import _JoinTargetElement from ..sql.selectable import _LabelConventionCallable + from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import ExecutableReturnsRows from ..sql.selectable import SelectBase from ..sql.type_api import TypeEngine @@ -80,7 +92,7 @@ _path_registry = PathRegistry.root _EMPTY_DICT = util.immutabledict() -LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM") +LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM class QueryContext: @@ -109,6 +121,10 @@ class QueryContext: "loaders_require_uniquing", ) + runid: int + post_load_paths: Dict[PathRegistry, PostLoad] + compile_state: ORMCompileState + class default_load_options(Options): _only_return_tuples = False _populate_existing = False @@ -123,13 +139,16 @@ class QueryContext: def __init__( self, - compile_state, - statement, - params, - session, - load_options, - execution_options=None, - bind_arguments=None, + compile_state: CompileState, + statement: Union[Select[Any], FromStatement[Any]], + params: _CoreSingleExecuteParams, + session: Session, + load_options: Union[ + Type[QueryContext.default_load_options], + QueryContext.default_load_options, + ], + execution_options: Optional[_ExecuteOptionsParameter] = None, + bind_arguments: Optional[_BindArguments] = None, ): self.load_options = load_options self.execution_options = execution_options or _EMPTY_DICT @@ -220,8 +239,8 @@ class ORMCompileState(CompileState): attributes: Dict[Any, Any] global_attributes: Dict[Any, Any] - statement: Union[Select, FromStatement] - select_statement: Union[Select, FromStatement] + statement: Union[Select[Any], FromStatement[Any]] + select_statement: Union[Select[Any], FromStatement[Any]] _entities: List[_QueryEntity] _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] compile_options: Union[ @@ -238,6 +257,7 @@ class ORMCompileState(CompileState): Tuple[Any, ...] ] current_path: PathRegistry = _path_registry + _has_mapper_entities = False def __init__(self, *arg, **kw): raise NotImplementedError() @@ -266,7 +286,12 @@ class ORMCompileState(CompileState): return SelectState._column_naming_convention(label_style) @classmethod - def create_for_statement(cls, statement_container, compiler, **kw): + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: """Create a context for a statement given a :class:`.Compiler`. This method is always invoked in the context of SQLCompiler.process(). @@ -443,7 +468,12 @@ class ORMFromStatementCompileState(ORMCompileState): eager_joins = _EMPTY_DICT @classmethod - def create_for_statement(cls, statement_container, compiler, **kw): + def create_for_statement( + cls, + statement_container: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: if compiler is not None: toplevel = not compiler.stack @@ -577,7 +607,7 @@ class ORMFromStatementCompileState(ORMCompileState): return None -class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): +class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): """Core construct that represents a load of ORM objects from various :class:`.ReturnsRows` and other classes including: @@ -595,7 +625,7 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): _for_update_arg = None - element: Union[SelectBase, TextClause, UpdateBase] + element: Union[ExecutableReturnsRows, TextClause] _traverse_internals = [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), @@ -606,7 +636,11 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): ("_compile_options", InternalTraversal.dp_has_cache_key) ] - def __init__(self, entities, element): + def __init__( + self, + entities: Iterable[_ColumnsClauseArgument[Any]], + element: Union[ExecutableReturnsRows, TextClause], + ): self._raw_columns = [ coercions.expect( roles.ColumnsClauseRole, @@ -701,7 +735,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState): _having_criteria = () @classmethod - def create_for_statement(cls, statement, compiler, **kw): + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) @@ -1073,9 +1112,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) @classmethod - @util.preload_module("sqlalchemy.orm.query") def from_statement(cls, statement, from_statement): - query = util.preloaded.orm_query from_statement = coercions.expect( roles.ReturnsRowsRole, @@ -1083,7 +1120,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): apply_propagate_attrs=statement, ) - stmt = query.FromStatement(statement._raw_columns, from_statement) + stmt = FromStatement(statement._raw_columns, from_statement) stmt.__dict__.update( _with_options=statement._with_options, @@ -2114,7 +2151,9 @@ def _column_descriptions( return d -def _legacy_filter_by_entity_zero(query_or_augmented_select): +def _legacy_filter_by_entity_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if self._setup_joins: _last_joined_entity = self._last_joined_entity @@ -2127,7 +2166,9 @@ def _legacy_filter_by_entity_zero(query_or_augmented_select): return _entity_from_pre_ent_zero(self) -def _entity_from_pre_ent_zero(query_or_augmented_select): +def _entity_from_pre_ent_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if not self._raw_columns: return None @@ -2144,13 +2185,19 @@ def _entity_from_pre_ent_zero(query_or_augmented_select): return ent -def _determine_last_joined_entity(setup_joins, entity_zero=None): +def _determine_last_joined_entity( + setup_joins: Tuple[_SetupJoinsElement, ...], + entity_zero: Optional[_InternalEntityType[Any]] = None, +) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: if not setup_joins: return None (target, onclause, from_, flags) = setup_joins[-1] - if isinstance(target, interfaces.PropComparator): + if isinstance( + target, + attributes.QueryableAttribute, + ): return target.entity else: return target @@ -2161,6 +2208,8 @@ class _QueryEntity: __slots__ = () + supports_single_entity: bool + _non_hashable_value = False _null_column_type = False use_id_for_hash = False @@ -2173,6 +2222,9 @@ class _QueryEntity: def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() + def row_processor(self, context, result): + raise NotImplementedError() + @classmethod def to_compile_state( cls, compile_state, entities, entities_collection, is_current_entities diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index fbe35f92ab..1c343b04ce 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php + """Public API functions and helpers for declarative.""" from __future__ import annotations @@ -14,16 +15,21 @@ import typing from typing import Any from typing import Callable from typing import ClassVar +from typing import Dict +from typing import FrozenSet +from typing import Iterator from typing import Mapping from typing import Optional +from typing import overload +from typing import Set from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union import weakref from . import attributes from . import clsregistry -from . import exc as orm_exc from . import instrumentation from . import interfaces from . import mapperlib @@ -38,24 +44,40 @@ from .decl_base import _del_attribute from .decl_base import _mapper from .descriptor_props import Synonym as _orm_synonym from .mapper import Mapper +from .state import InstanceState from .. import exc from .. import inspection from .. import util from ..sql.elements import SQLCoreOperations from ..sql.schema import MetaData from ..sql.selectable import FromClause -from ..sql.type_api import TypeEngine from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing +from ..util.typing import CallableReference +from ..util.typing import Literal +if TYPE_CHECKING: + from ._typing import _O + from ._typing import _RegistryType + from .descriptor_props import Synonym + from .instrumentation import ClassManager + from .interfaces import MapperProperty + from ..sql._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) -_TypeAnnotationMapType = Mapping[Type, Union[Type[TypeEngine], TypeEngine]] +# it's not clear how to have Annotated, Union objects etc. as keys here +# from a typing perspective so just leave it open ended for now +_TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"] +_MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"] + +_DeclaredAttrDecorated = Callable[ + ..., Union[Mapped[_T], SQLCoreOperations[_T]] +] -def has_inherited_table(cls): +def has_inherited_table(cls: Type[_O]) -> bool: """Given a class, return True if any of the classes it inherits from has a mapped table, otherwise return False. @@ -75,13 +97,13 @@ def has_inherited_table(cls): class _DynamicAttributesType(type): - def __setattr__(cls, key, value): + def __setattr__(cls, key: str, value: Any) -> None: if "__mapper__" in cls.__dict__: _add_attribute(cls, key, value) else: type.__setattr__(cls, key, value) - def __delattr__(cls, key): + def __delattr__(cls, key: str) -> None: if "__mapper__" in cls.__dict__: _del_attribute(cls, key) else: @@ -89,7 +111,7 @@ class _DynamicAttributesType(type): class DeclarativeAttributeIntercept( - _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] + _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] ): """Metaclass that may be used in conjunction with the :class:`_orm.DeclarativeBase` class to support addition of class @@ -99,10 +121,10 @@ class DeclarativeAttributeIntercept( class DeclarativeMeta( - _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] + _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] ): metadata: MetaData - registry: "RegistryType" + registry: RegistryType def __init__( cls, classname: Any, bases: Any, dict_: Any, **kw: Any @@ -130,7 +152,9 @@ class DeclarativeMeta( type.__init__(cls, classname, bases, dict_) -def synonym_for(name, map_column=False): +def synonym_for( + name: str, map_column: bool = False +) -> Callable[[Callable[..., Any]], Synonym[Any]]: """Decorator that produces an :func:`_orm.synonym` attribute in conjunction with a Python descriptor. @@ -164,7 +188,7 @@ def synonym_for(name, map_column=False): """ - def decorate(fn): + def decorate(fn: Callable[..., Any]) -> Synonym[Any]: return _orm_synonym(name, map_column=map_column, descriptor=fn) return decorate @@ -255,16 +279,16 @@ class declared_attr(interfaces._MappedAttribute[_T]): if typing.TYPE_CHECKING: - def __set__(self, instance, value): + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any): + def __delete__(self, instance: Any) -> None: ... def __init__( self, - fn: Callable[..., Union[Mapped[_T], SQLCoreOperations[_T]]], - cascading=False, + fn: _DeclaredAttrDecorated[_T], + cascading: bool = False, ): self.fget = fn self._cascading = cascading @@ -273,10 +297,28 @@ class declared_attr(interfaces._MappedAttribute[_T]): def _collect_return_annotation(self) -> Optional[Type[Any]]: return util.get_annotations(self.fget).get("return") - def __get__(self, instance, owner) -> InstrumentedAttribute[_T]: + # this is the Mapped[] API where at class descriptor get time we want + # the type checker to see InstrumentedAttribute[_T]. However the + # callable function prior to mapping in fact calls the given + # declarative function that does not return InstrumentedAttribute + @overload + def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T], _T]: # the declared_attr needs to make use of a cache that exists # for the span of the declarative scan_attributes() phase. # to achieve this we look at the class manager that's configured. + + # note this method should not be called outside of the declarative + # setup phase + cls = owner manager = attributes.opt_manager_of_class(cls) if manager is None: @@ -287,30 +329,33 @@ class declared_attr(interfaces._MappedAttribute[_T]): "Unmanaged access of declarative attribute %s from " "non-mapped class %s" % (self.fget.__name__, cls.__name__) ) - return self.fget(cls) + return self.fget(cls) # type: ignore elif manager.is_mapped: # the class is mapped, which means we're outside of the declarative # scan setup, just run the function. - return self.fget(cls) + return self.fget(cls) # type: ignore # here, we are inside of the declarative scan. use the registry # that is tracking the values of these attributes. declarative_scan = manager.declarative_scan() + + # assert that we are in fact in the declarative scan assert declarative_scan is not None + reg = declarative_scan.declared_attr_reg if self in reg: - return reg[self] + return reg[self] # type: ignore else: reg[self] = obj = self.fget(cls) - return obj + return obj # type: ignore @hybridmethod - def _stateful(cls, **kw): + def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]: return _stateful_declared_attr(**kw) @hybridproperty - def cascading(cls): + def cascading(cls) -> _stateful_declared_attr[_T]: """Mark a :class:`.declared_attr` as cascading. This is a special-use modifier which indicates that a column @@ -372,20 +417,23 @@ class declared_attr(interfaces._MappedAttribute[_T]): return cls._stateful(cascading=True) -class _stateful_declared_attr(declared_attr): - def __init__(self, **kw): +class _stateful_declared_attr(declared_attr[_T]): + kw: Dict[str, Any] + + def __init__(self, **kw: Any): self.kw = kw - def _stateful(self, **kw): + @hybridmethod + def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]: new_kw = self.kw.copy() new_kw.update(kw) return _stateful_declared_attr(**new_kw) - def __call__(self, fn): + def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]: return declared_attr(fn, **self.kw) -def declarative_mixin(cls): +def declarative_mixin(cls: Type[_T]) -> Type[_T]: """Mark a class as providing the feature of "declarative mixin". E.g.:: @@ -427,9 +475,9 @@ def declarative_mixin(cls): return cls -def _setup_declarative_base(cls): +def _setup_declarative_base(cls: Type[Any]) -> None: if "metadata" in cls.__dict__: - metadata = cls.metadata + metadata = cls.metadata # type: ignore else: metadata = None @@ -457,15 +505,15 @@ def _setup_declarative_base(cls): reg = registry( metadata=metadata, type_annotation_map=type_annotation_map ) - cls.registry = reg + cls.registry = reg # type: ignore - cls._sa_registry = reg + cls._sa_registry = reg # type: ignore if "metadata" not in cls.__dict__: - cls.metadata = cls.registry.metadata + cls.metadata = cls.registry.metadata # type: ignore -class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]): +class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]): """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass to intercept new attributes. @@ -477,10 +525,10 @@ class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]): """ - registry: ClassVar["registry"] - _sa_registry: ClassVar["registry"] + registry: ClassVar[_RegistryType] + _sa_registry: ClassVar[_RegistryType] metadata: ClassVar[MetaData] - __mapper__: ClassVar[Mapper] + __mapper__: ClassVar[Mapper[Any]] __table__: Optional[FromClause] if typing.TYPE_CHECKING: @@ -496,7 +544,7 @@ class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]): class DeclarativeBase( - inspection.Inspectable["InstanceState"], + inspection.Inspectable[InstanceState[Any]], metaclass=DeclarativeAttributeIntercept, ): """Base class used for declarative class definitions. @@ -557,10 +605,10 @@ class DeclarativeBase( """ - registry: ClassVar["registry"] - _sa_registry: ClassVar["registry"] + registry: ClassVar[_RegistryType] + _sa_registry: ClassVar[_RegistryType] metadata: ClassVar[MetaData] - __mapper__: ClassVar[Mapper] + __mapper__: ClassVar[Mapper[Any]] __table__: Optional[FromClause] if typing.TYPE_CHECKING: @@ -572,10 +620,12 @@ class DeclarativeBase( if DeclarativeBase in cls.__bases__: _setup_declarative_base(cls) else: - cls._sa_registry.map_declaratively(cls) + _as_declarative(cls._sa_registry, cls, cls.__dict__) -def add_mapped_attribute(target, key, attr): +def add_mapped_attribute( + target: Type[_O], key: str, attr: MapperProperty[Any] +) -> None: """Add a new mapped attribute to an ORM mapped class. E.g.:: @@ -593,14 +643,15 @@ def add_mapped_attribute(target, key, attr): def declarative_base( + *, metadata: Optional[MetaData] = None, - mapper=None, - cls=object, - name="Base", + mapper: Optional[Callable[..., Mapper[Any]]] = None, + cls: Type[Any] = object, + name: str = "Base", class_registry: Optional[clsregistry._ClsRegistryType] = None, type_annotation_map: Optional[_TypeAnnotationMapType] = None, constructor: Callable[..., None] = _declarative_constructor, - metaclass=DeclarativeMeta, + metaclass: Type[Any] = DeclarativeMeta, ) -> Any: r"""Construct a base class for declarative class definitions. @@ -736,8 +787,19 @@ class registry: """ + _class_registry: clsregistry._ClsRegistryType + _managers: weakref.WeakKeyDictionary[ClassManager[Any], Literal[True]] + _non_primary_mappers: weakref.WeakKeyDictionary[Mapper[Any], Literal[True]] + metadata: MetaData + constructor: CallableReference[Callable[..., None]] + type_annotation_map: _MutableTypeAnnotationMapType + _dependents: Set[_RegistryType] + _dependencies: Set[_RegistryType] + _new_mappers: bool + def __init__( self, + *, metadata: Optional[MetaData] = None, class_registry: Optional[clsregistry._ClsRegistryType] = None, type_annotation_map: Optional[_TypeAnnotationMapType] = None, @@ -799,9 +861,7 @@ class registry: def update_type_annotation_map( self, - type_annotation_map: Mapping[ - Type, Union[Type[TypeEngine], TypeEngine] - ], + type_annotation_map: _TypeAnnotationMapType, ) -> None: """update the :paramref:`_orm.registry.type_annotation_map` with new values.""" @@ -817,20 +877,20 @@ class registry: ) @property - def mappers(self): + def mappers(self) -> FrozenSet[Mapper[Any]]: """read only collection of all :class:`_orm.Mapper` objects.""" return frozenset(manager.mapper for manager in self._managers).union( self._non_primary_mappers ) - def _set_depends_on(self, registry): + def _set_depends_on(self, registry: RegistryType) -> None: if registry is self: return registry._dependents.add(self) self._dependencies.add(registry) - def _flag_new_mapper(self, mapper): + def _flag_new_mapper(self, mapper: Mapper[Any]) -> None: mapper._ready_for_configure = True if self._new_mappers: return @@ -839,7 +899,9 @@ class registry: reg._new_mappers = True @classmethod - def _recurse_with_dependents(cls, registries): + def _recurse_with_dependents( + cls, registries: Set[RegistryType] + ) -> Iterator[RegistryType]: todo = registries done = set() while todo: @@ -856,7 +918,9 @@ class registry: todo.update(reg._dependents.difference(done)) @classmethod - def _recurse_with_dependencies(cls, registries): + def _recurse_with_dependencies( + cls, registries: Set[RegistryType] + ) -> Iterator[RegistryType]: todo = registries done = set() while todo: @@ -873,7 +937,7 @@ class registry: # them before todo.update(reg._dependencies.difference(done)) - def _mappers_to_configure(self): + def _mappers_to_configure(self) -> Iterator[Mapper[Any]]: return itertools.chain( ( manager.mapper @@ -889,13 +953,13 @@ class registry: ), ) - def _add_non_primary_mapper(self, np_mapper): + def _add_non_primary_mapper(self, np_mapper: Mapper[Any]) -> None: self._non_primary_mappers[np_mapper] = True - def _dispose_cls(self, cls): + def _dispose_cls(self, cls: Type[_O]) -> None: clsregistry.remove_class(cls.__name__, cls, self._class_registry) - def _add_manager(self, manager): + def _add_manager(self, manager: ClassManager[Any]) -> None: self._managers[manager] = True if manager.is_mapped: raise exc.ArgumentError( @@ -905,7 +969,7 @@ class registry: assert manager.registry is None manager.registry = self - def configure(self, cascade=False): + def configure(self, cascade: bool = False) -> None: """Configure all as-yet unconfigured mappers in this :class:`_orm.registry`. @@ -946,7 +1010,7 @@ class registry: """ mapperlib._configure_registries({self}, cascade=cascade) - def dispose(self, cascade=False): + def dispose(self, cascade: bool = False) -> None: """Dispose of all mappers in this :class:`_orm.registry`. After invocation, all the classes that were mapped within this registry @@ -972,7 +1036,7 @@ class registry: mapperlib._dispose_registries({self}, cascade=cascade) - def _dispose_manager_and_mapper(self, manager): + def _dispose_manager_and_mapper(self, manager: ClassManager[Any]) -> None: if "mapper" in manager.__dict__: mapper = manager.mapper @@ -984,11 +1048,11 @@ class registry: def generate_base( self, - mapper=None, - cls=object, - name="Base", - metaclass=DeclarativeMeta, - ): + mapper: Optional[Callable[..., Mapper[Any]]] = None, + cls: Type[Any] = object, + name: str = "Base", + metaclass: Type[Any] = DeclarativeMeta, + ) -> Any: """Generate a declarative base class. Classes that inherit from the returned class object will be @@ -1070,7 +1134,7 @@ class registry: if hasattr(cls, "__class_getitem__"): - def __class_getitem__(cls, key): + def __class_getitem__(cls: Type[_T], key: str) -> Type[_T]: # allow generic classes in py3.9+ return cls @@ -1078,7 +1142,7 @@ class registry: return metaclass(name, bases, class_dict) - def mapped(self, cls): + def mapped(self, cls: Type[_O]) -> Type[_O]: """Class decorator that will apply the Declarative mapping process to a given class. @@ -1114,7 +1178,7 @@ class registry: _as_declarative(self, cls, cls.__dict__) return cls - def as_declarative_base(self, **kw): + def as_declarative_base(self, **kw: Any) -> Callable[[Type[_T]], Type[_T]]: """ Class decorator which will invoke :meth:`_orm.registry.generate_base` @@ -1142,14 +1206,14 @@ class registry: """ - def decorate(cls): + def decorate(cls: Type[_T]) -> Type[_T]: kw["cls"] = cls kw["name"] = cls.__name__ - return self.generate_base(**kw) + return self.generate_base(**kw) # type: ignore return decorate - def map_declaratively(self, cls): + def map_declaratively(self, cls: Type[_O]) -> Mapper[_O]: """Map a class declaratively. In this form of mapping, the class is scanned for mapping information, @@ -1194,9 +1258,15 @@ class registry: :meth:`_orm.registry.map_imperatively` """ - return _as_declarative(self, cls, cls.__dict__) + _as_declarative(self, cls, cls.__dict__) + return cls.__mapper__ # type: ignore - def map_imperatively(self, class_, local_table=None, **kw): + def map_imperatively( + self, + class_: Type[_O], + local_table: Optional[FromClause] = None, + **kw: Any, + ) -> Mapper[_O]: r"""Map a class imperatively. In this form of mapping, the class is not scanned for any mapping @@ -1251,7 +1321,7 @@ class registry: RegistryType = registry -def as_declarative(**kw): +def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: """ Class decorator which will adapt a given class into a :func:`_orm.declarative_base`. @@ -1292,14 +1362,9 @@ def as_declarative(**kw): @inspection._inspects( DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept ) -def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]: - mp: Mapper[Any] = _inspect_mapped_class(cls) +def _inspect_decl_meta(cls: Type[Any]) -> Optional[Mapper[Any]]: + mp: Optional[Mapper[Any]] = _inspect_mapped_class(cls) if mp is None: if _DeferredMapperConfig.has_cls(cls): _DeferredMapperConfig.raise_unmapped_for_cls(cls) - raise orm_exc.UnmappedClassError( - cls, - msg="Class %s has a deferred mapping on it. It is not yet " - "usable as a mapped class." % orm_exc._safe_cls_name(cls), - ) return mp diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index b1f81cb6b8..c3faac36cf 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -4,16 +4,26 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php + """Internal implementation for declarative.""" from __future__ import annotations import collections from typing import Any +from typing import Callable +from typing import cast from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from . import attributes @@ -21,6 +31,8 @@ from . import clsregistry from . import exc as orm_exc from . import instrumentation from . import mapperlib +from ._typing import _O +from ._typing import attr_is_internal_proxy from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class @@ -32,6 +44,7 @@ from .interfaces import _MappedAttribute from .interfaces import _MapsColumns from .interfaces import MapperProperty from .mapper import Mapper as mapper +from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn from .util import _is_mapped_annotation @@ -43,12 +56,41 @@ from ..sql import expression from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _ClassDict from ._typing import _RegistryType + from .decl_api import declared_attr + from .instrumentation import ClassManager + from ..sql.schema import MetaData + from ..sql.selectable import FromClause + +_T = TypeVar("_T", bound=Any) + +_MapperKwArgs = Mapping[str, Any] + +_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]] -def _declared_mapping_info(cls): +class _DeclMappedClassProtocol(Protocol[_O]): + metadata: MetaData + __mapper__: Mapper[_O] + __table__: Table + __tablename__: str + __mapper_args__: Mapping[str, Any] + __table_args__: Optional[_TableArgsType] + + def __declare_first__(self) -> None: + pass + + def __declare_last__(self) -> None: + pass + + +def _declared_mapping_info( + cls: Type[Any], +) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]: # deferred mapping if _DeferredMapperConfig.has_cls(cls): return _DeferredMapperConfig.config_for_cls(cls) @@ -59,13 +101,15 @@ def _declared_mapping_info(cls): return None -def _resolve_for_abstract_or_classical(cls): +def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]: if cls is object: return None + sup: Optional[Type[Any]] + if cls.__dict__.get("__abstract__", False): - for sup in cls.__bases__: - sup = _resolve_for_abstract_or_classical(sup) + for base_ in cls.__bases__: + sup = _resolve_for_abstract_or_classical(base_) if sup is not None: return sup else: @@ -79,7 +123,9 @@ def _resolve_for_abstract_or_classical(cls): return cls -def _get_immediate_cls_attr(cls, attrname, strict=False): +def _get_immediate_cls_attr( + cls: Type[Any], attrname: str, strict: bool = False +) -> Optional[Any]: """return an attribute of the class that is either present directly on the class, e.g. not on a superclass, or is from a superclass but this superclass is a non-mapped mixin, that is, not a descendant of @@ -102,7 +148,7 @@ def _get_immediate_cls_attr(cls, attrname, strict=False): return getattr(cls, attrname) for base in cls.__mro__[1:]: - _is_classicial_inherits = _dive_for_cls_manager(base) + _is_classicial_inherits = _dive_for_cls_manager(base) is not None if attrname in base.__dict__ and ( base is cls @@ -116,33 +162,37 @@ def _get_immediate_cls_attr(cls, attrname, strict=False): return None -def _dive_for_cls_manager(cls): +def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]: # because the class manager registration is pluggable, # we need to do the search for every class in the hierarchy, # rather than just a simple "cls._sa_class_manager" - # python 2 old style class - if not hasattr(cls, "__mro__"): - return None - for base in cls.__mro__: - manager = attributes.opt_manager_of_class(base) + manager: Optional[ClassManager[_O]] = attributes.opt_manager_of_class( + base + ) if manager: return manager return None -def _as_declarative(registry, cls, dict_): +def _as_declarative( + registry: _RegistryType, cls: Type[Any], dict_: _ClassDict +) -> Optional[_MapperConfig]: # declarative scans the class for attributes. no table or mapper # args passed separately. - return _MapperConfig.setup_mapping(registry, cls, dict_, None, {}) -def _mapper(registry, cls, table, mapper_kw): +def _mapper( + registry: _RegistryType, + cls: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, +) -> Mapper[_O]: _ImperativeMapperConfig(registry, cls, table, mapper_kw) - return cls.__mapper__ + return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__ @util.preload_module("sqlalchemy.orm.decl_api") @@ -152,7 +202,9 @@ def _is_declarative_props(obj: Any) -> bool: return isinstance(obj, (declared_attr, util.classproperty)) -def _check_declared_props_nocascade(obj, name, cls): +def _check_declared_props_nocascade( + obj: Any, name: str, cls: Type[_O] +) -> bool: if _is_declarative_props(obj): if getattr(obj, "_cascading", False): util.warn( @@ -174,8 +226,20 @@ class _MapperConfig: "__weakref__", ) + cls: Type[Any] + classname: str + properties: util.OrderedDict[str, MapperProperty[Any]] + declared_attr_reg: Dict[declared_attr[Any], Any] + @classmethod - def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw): + def setup_mapping( + cls, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ) -> Optional[_MapperConfig]: manager = attributes.opt_manager_of_class(cls) if manager and manager.class_ is cls_: raise exc.InvalidRequestError( @@ -183,24 +247,26 @@ class _MapperConfig: ) if cls_.__dict__.get("__abstract__", False): - return + return None defer_map = _get_immediate_cls_attr( cls_, "_sa_decl_prepare_nocascade", strict=True ) or hasattr(cls_, "_sa_decl_prepare") if defer_map: - cfg_cls = _DeferredMapperConfig + return _DeferredMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) else: - cfg_cls = _ClassScanMapperConfig - - return cfg_cls(registry, cls_, dict_, table, mapper_kw) + return _ClassScanMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) def __init__( self, registry: _RegistryType, cls_: Type[Any], - mapper_kw: Dict[str, Any], + mapper_kw: _MapperKwArgs, ): self.cls = util.assert_arg_type(cls_, type, "cls_") self.classname = cls_.__name__ @@ -224,13 +290,16 @@ class _MapperConfig: "Mapper." % self.cls ) - def set_cls_attribute(self, attrname, value): + def set_cls_attribute(self, attrname: str, value: _T) -> _T: manager = instrumentation.manager_of_class(self.cls) manager.install_member(attrname, value) return value - def _early_mapping(self, mapper_kw): + def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]: + raise NotImplementedError() + + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: self.map(mapper_kw) @@ -239,10 +308,10 @@ class _ImperativeMapperConfig(_MapperConfig): def __init__( self, - registry, - cls_, - table, - mapper_kw, + registry: _RegistryType, + cls_: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, ): super(_ImperativeMapperConfig, self).__init__( registry, cls_, mapper_kw @@ -260,7 +329,7 @@ class _ImperativeMapperConfig(_MapperConfig): self._early_mapping(mapper_kw) - def map(self, mapper_kw=util.EMPTY_DICT): + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: mapper_cls = mapper return self.set_cls_attribute( @@ -268,7 +337,7 @@ class _ImperativeMapperConfig(_MapperConfig): mapper_cls(self.cls, self.local_table, **mapper_kw), ) - def _setup_inheritance(self, mapper_kw): + def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: cls = self.cls inherits = mapper_kw.get("inherits", None) @@ -277,8 +346,8 @@ class _ImperativeMapperConfig(_MapperConfig): # since we search for classical mappings now, search for # multiple mapped bases as well and raise an error. inherits_search = [] - for c in cls.__bases__: - c = _resolve_for_abstract_or_classical(c) + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) if c is None: continue if _declared_mapping_info( @@ -318,13 +387,30 @@ class _ClassScanMapperConfig(_MapperConfig): "inherits", ) + registry: _RegistryType + clsdict_view: _ClassDict + collected_annotations: Dict[str, Tuple[Any, bool]] + collected_attributes: Dict[str, Any] + local_table: Optional[FromClause] + persist_selectable: Optional[FromClause] + declared_columns: util.OrderedSet[Column[Any]] + column_copies: Dict[ + Union[MappedColumn[Any], Column[Any]], + Union[MappedColumn[Any], Column[Any]], + ] + tablename: Optional[str] + mapper_args: Mapping[str, Any] + table_args: Optional[_TableArgsType] + mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] + inherits: Optional[Type[Any]] + def __init__( self, - registry, - cls_, - dict_, - table, - mapper_kw, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, ): # grab class dict before the instrumentation manager has been added. @@ -337,7 +423,7 @@ class _ClassScanMapperConfig(_MapperConfig): self.persist_selectable = None self.collected_attributes = {} - self.collected_annotations: Dict[str, Tuple[Any, bool]] = {} + self.collected_annotations = {} self.declared_columns = util.OrderedSet() self.column_copies = {} @@ -360,31 +446,37 @@ class _ClassScanMapperConfig(_MapperConfig): self._early_mapping(mapper_kw) - def _setup_declared_events(self): + def _setup_declared_events(self) -> None: if _get_immediate_cls_attr(self.cls, "__declare_last__"): @event.listens_for(mapper, "after_configured") - def after_configured(): - self.cls.__declare_last__() + def after_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_last__() if _get_immediate_cls_attr(self.cls, "__declare_first__"): @event.listens_for(mapper, "before_configured") - def before_configured(): - self.cls.__declare_first__() - - def _cls_attr_override_checker(self, cls): + def before_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_first__() + + def _cls_attr_override_checker( + self, cls: Type[_O] + ) -> Callable[[str, Any], bool]: """Produce a function that checks if a class has overridden an attribute, taking SQLAlchemy-enabled dataclass fields into account. """ sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__", None + cls, "__sa_dataclass_metadata_key__" ) if sa_dataclass_metadata_key is None: - def attribute_is_overridden(key, obj): + def attribute_is_overridden(key: str, obj: Any) -> bool: return getattr(cls, key) is not obj else: @@ -402,7 +494,7 @@ class _ClassScanMapperConfig(_MapperConfig): absent = object() - def attribute_is_overridden(key, obj): + def attribute_is_overridden(key: str, obj: Any) -> bool: if _is_declarative_props(obj): obj = obj.fget @@ -457,13 +549,15 @@ class _ClassScanMapperConfig(_MapperConfig): ] ) - def _cls_attr_resolver(self, cls): + def _cls_attr_resolver( + self, cls: Type[Any] + ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]: """produce a function to iterate the "attributes" of a class, adjusting for SQLAlchemy fields embedded in dataclass fields. """ - sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__", None + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" ) cls_annotations = util.get_annotations(cls) @@ -477,7 +571,9 @@ class _ClassScanMapperConfig(_MapperConfig): ) if sa_dataclass_metadata_key is None: - def local_attributes_for_class(): + def local_attributes_for_class() -> Iterable[ + Tuple[str, Any, Any, bool] + ]: return ( ( name, @@ -493,12 +589,16 @@ class _ClassScanMapperConfig(_MapperConfig): field.name: field for field in util.local_dataclass_fields(cls) } - def local_attributes_for_class(): + fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key + + def local_attributes_for_class() -> Iterable[ + Tuple[str, Any, Any, bool] + ]: for name in names: field = dataclass_fields.get(name, None) if field and sa_dataclass_metadata_key in field.metadata: yield field.name, _as_dc_declaredattr( - field.metadata, sa_dataclass_metadata_key + field.metadata, fixed_sa_dataclass_metadata_key ), cls_annotations.get(field.name), True else: yield name, cls_vars.get(name), cls_annotations.get( @@ -507,14 +607,17 @@ class _ClassScanMapperConfig(_MapperConfig): return local_attributes_for_class - def _scan_attributes(self): + def _scan_attributes(self) -> None: cls = self.cls + cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls) + clsdict_view = self.clsdict_view collected_attributes = self.collected_attributes column_copies = self.column_copies mapper_args_fn = None table_args = inherited_table_args = None + tablename = None fixed_table = "__table__" in clsdict_view @@ -555,21 +658,23 @@ class _ClassScanMapperConfig(_MapperConfig): # make a copy of it so a class-level dictionary # is not overwritten when we update column-based # arguments. - def mapper_args_fn(): - return dict(cls.__mapper_args__) + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn elif name == "__tablename__": check_decl = _check_declared_props_nocascade( obj, name, cls ) if not tablename and (not class_mapped or check_decl): - tablename = cls.__tablename__ + tablename = cls_as_Decl.__tablename__ elif name == "__table_args__": check_decl = _check_declared_props_nocascade( obj, name, cls ) if not table_args and (not class_mapped or check_decl): - table_args = cls.__table_args__ + table_args = cls_as_Decl.__table_args__ if not isinstance( table_args, (tuple, dict, type(None)) ): @@ -657,9 +762,10 @@ class _ClassScanMapperConfig(_MapperConfig): # or similar. note there is no known case that # produces nested proxies, so we are only # looking one level deep right now. + if ( isinstance(ret, InspectionAttr) - and ret._is_internal_proxy + and attr_is_internal_proxy(ret) and not isinstance( ret.original_property, MapperProperty ) @@ -669,6 +775,7 @@ class _ClassScanMapperConfig(_MapperConfig): collected_attributes[name] = column_copies[ obj ] = ret + if ( isinstance(ret, (Column, MapperProperty)) and ret.doc is None @@ -737,7 +844,9 @@ class _ClassScanMapperConfig(_MapperConfig): self.tablename = tablename self.mapper_args_fn = mapper_args_fn - def _warn_for_decl_attributes(self, cls, key, c): + def _warn_for_decl_attributes( + self, cls: Type[Any], key: str, c: Any + ) -> None: if isinstance(c, expression.ColumnClause): util.warn( f"Attribute '{key}' on class {cls} appears to " @@ -746,8 +855,12 @@ class _ClassScanMapperConfig(_MapperConfig): ) def _produce_column_copies( - self, attributes_for_class, attribute_is_overridden - ): + self, + attributes_for_class: Callable[ + [], Iterable[Tuple[str, Any, Any, bool]] + ], + attribute_is_overridden: Callable[[str, Any], bool], + ) -> None: cls = self.cls dict_ = self.clsdict_view collected_attributes = self.collected_attributes @@ -763,7 +876,8 @@ class _ClassScanMapperConfig(_MapperConfig): continue elif name not in dict_ and not ( "__table__" in dict_ - and (obj.name or name) in dict_["__table__"].c + and (getattr(obj, "name", None) or name) + in dict_["__table__"].c ): if obj.foreign_keys: for fk in obj.foreign_keys: @@ -786,7 +900,7 @@ class _ClassScanMapperConfig(_MapperConfig): setattr(cls, name, copy_) - def _extract_mappable_attributes(self): + def _extract_mappable_attributes(self) -> None: cls = self.cls collected_attributes = self.collected_attributes @@ -858,17 +972,19 @@ class _ClassScanMapperConfig(_MapperConfig): "declarative base class." ) elif isinstance(value, Column): - _undefer_column_name(k, self.column_copies.get(value, value)) + _undefer_column_name( + k, self.column_copies.get(value, value) # type: ignore + ) elif isinstance(value, _IntrospectsAnnotations): annotation, is_dataclass = self.collected_annotations.get( - k, (None, None) + k, (None, False) ) value.declarative_scan( self.registry, cls, k, annotation, is_dataclass ) our_stuff[k] = value - def _extract_declared_columns(self): + def _extract_declared_columns(self) -> None: our_stuff = self.properties # extract columns from the class dict @@ -914,8 +1030,10 @@ class _ClassScanMapperConfig(_MapperConfig): % (self.classname, name, (", ".join(sorted(keys)))) ) - def _setup_table(self, table=None): + def _setup_table(self, table: Optional[FromClause] = None) -> None: cls = self.cls + cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls) + tablename = self.tablename table_args = self.table_args clsdict_view = self.clsdict_view @@ -925,13 +1043,18 @@ class _ClassScanMapperConfig(_MapperConfig): if "__table__" not in clsdict_view and table is None: if hasattr(cls, "__table_cls__"): - table_cls = util.unbound_method_to_callable(cls.__table_cls__) + table_cls = cast( + Type[Table], + util.unbound_method_to_callable(cls.__table_cls__), # type: ignore # noqa: E501 + ) else: table_cls = Table if tablename is not None: - args, table_kw = (), {} + args: Tuple[Any, ...] = () + table_kw: Dict[str, Any] = {} + if table_args: if isinstance(table_args, dict): table_kw = table_args @@ -960,7 +1083,7 @@ class _ClassScanMapperConfig(_MapperConfig): ) else: if table is None: - table = cls.__table__ + table = cls_as_Decl.__table__ if declared_columns: for c in declared_columns: if not table.c.contains_column(c): @@ -968,15 +1091,16 @@ class _ClassScanMapperConfig(_MapperConfig): "Can't add additional column %r when " "specifying __table__" % c.key ) + self.local_table = table - def _metadata_for_cls(self, manager): + def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData: if hasattr(self.cls, "metadata"): - return self.cls.metadata + return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata else: return manager.registry.metadata - def _setup_inheritance(self, mapper_kw): + def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: table = self.local_table cls = self.cls table_args = self.table_args @@ -988,8 +1112,8 @@ class _ClassScanMapperConfig(_MapperConfig): # since we search for classical mappings now, search for # multiple mapped bases as well and raise an error. inherits_search = [] - for c in cls.__bases__: - c = _resolve_for_abstract_or_classical(c) + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) if c is None: continue if _declared_mapping_info( @@ -1024,9 +1148,12 @@ class _ClassScanMapperConfig(_MapperConfig): "table-mapped class." % cls ) elif self.inherits: - inherited_mapper = _declared_mapping_info(self.inherits) - inherited_table = inherited_mapper.local_table - inherited_persist_selectable = inherited_mapper.persist_selectable + inherited_mapper_or_config = _declared_mapping_info(self.inherits) + assert inherited_mapper_or_config is not None + inherited_table = inherited_mapper_or_config.local_table + inherited_persist_selectable = ( + inherited_mapper_or_config.persist_selectable + ) if table is None: # single table inheritance. @@ -1036,29 +1163,44 @@ class _ClassScanMapperConfig(_MapperConfig): "Can't place __table_args__ on an inherited class " "with no table." ) + # add any columns declared here to the inherited table. - for c in declared_columns: - if c.name in inherited_table.c: - if inherited_table.c[c.name] is c: + if declared_columns and not isinstance(inherited_table, Table): + raise exc.ArgumentError( + f"Can't declare columns on single-table-inherited " + f"subclass {self.cls}; superclass {self.inherits} " + "is not mapped to a Table" + ) + + for col in declared_columns: + assert inherited_table is not None + if col.name in inherited_table.c: + if inherited_table.c[col.name] is col: continue raise exc.ArgumentError( "Column '%s' on class %s conflicts with " "existing column '%s'" - % (c, cls, inherited_table.c[c.name]) + % (col, cls, inherited_table.c[col.name]) ) - if c.primary_key: + if col.primary_key: raise exc.ArgumentError( "Can't place primary key columns on an inherited " "class with no table." ) - inherited_table.append_column(c) + + if TYPE_CHECKING: + assert isinstance(inherited_table, Table) + + inherited_table.append_column(col) if ( inherited_persist_selectable is not None and inherited_persist_selectable is not inherited_table ): - inherited_persist_selectable._refresh_for_new_column(c) + inherited_persist_selectable._refresh_for_new_column( + col + ) - def _prepare_mapper_arguments(self, mapper_kw): + def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None: properties = self.properties if self.mapper_args_fn: @@ -1100,6 +1242,7 @@ class _ClassScanMapperConfig(_MapperConfig): # not mapped on the parent class, to avoid # mapping columns specific to sibling/nephew classes inherited_mapper = _declared_mapping_info(self.inherits) + assert isinstance(inherited_mapper, Mapper) inherited_table = inherited_mapper.local_table if "exclude_properties" not in mapper_args: @@ -1133,11 +1276,14 @@ class _ClassScanMapperConfig(_MapperConfig): result_mapper_args["properties"] = properties self.mapper_args = result_mapper_args - def map(self, mapper_kw=util.EMPTY_DICT): + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: self._prepare_mapper_arguments(mapper_kw) if hasattr(self.cls, "__mapper_cls__"): - mapper_cls = util.unbound_method_to_callable( - self.cls.__mapper_cls__ + mapper_cls = cast( + "Type[Mapper[Any]]", + util.unbound_method_to_callable( + self.cls.__mapper_cls__ # type: ignore + ), ) else: mapper_cls = mapper @@ -1149,7 +1295,9 @@ class _ClassScanMapperConfig(_MapperConfig): @util.preload_module("sqlalchemy.orm.decl_api") -def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key): +def _as_dc_declaredattr( + field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str +) -> Any: # wrap lambdas inside dataclass fields inside an ad-hoc declared_attr. # we can't write it because field.metadata is immutable :( so we have # to go through extra trouble to compare these @@ -1162,46 +1310,55 @@ def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key): class _DeferredMapperConfig(_ClassScanMapperConfig): - _configs = util.OrderedDict() + _cls: weakref.ref[Type[Any]] + + _configs: util.OrderedDict[ + weakref.ref[Type[Any]], _DeferredMapperConfig + ] = util.OrderedDict() - def _early_mapping(self, mapper_kw): + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: pass - @property - def cls(self): - return self._cls() + # mypy disallows plain property override of variable + @property # type: ignore + def cls(self) -> Type[Any]: # type: ignore + return self._cls() # type: ignore @cls.setter - def cls(self, class_): + def cls(self, class_: Type[Any]) -> None: self._cls = weakref.ref(class_, self._remove_config_cls) self._configs[self._cls] = self @classmethod - def _remove_config_cls(cls, ref): + def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None: cls._configs.pop(ref, None) @classmethod - def has_cls(cls, class_): + def has_cls(cls, class_: Type[Any]) -> bool: # 2.6 fails on weakref if class_ is an old style class return isinstance(class_, type) and weakref.ref(class_) in cls._configs @classmethod - def raise_unmapped_for_cls(cls, class_): + def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn: if hasattr(class_, "_sa_raise_deferred_config"): - class_._sa_raise_deferred_config() + class_._sa_raise_deferred_config() # type: ignore raise orm_exc.UnmappedClassError( class_, - msg="Class %s has a deferred mapping on it. It is not yet " - "usable as a mapped class." % orm_exc._safe_cls_name(class_), + msg=( + f"Class {orm_exc._safe_cls_name(class_)} has a deferred " + "mapping on it. It is not yet usable as a mapped class." + ), ) @classmethod - def config_for_cls(cls, class_): + def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig: return cls._configs[weakref.ref(class_)] @classmethod - def classes_for_base(cls, base_cls, sort=True): + def classes_for_base( + cls, base_cls: Type[Any], sort: bool = True + ) -> List[_DeferredMapperConfig]: classes_for_base = [ m for m, cls_ in [(m, m.cls) for m in cls._configs.values()] @@ -1213,7 +1370,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): all_m_by_cls = dict((m.cls, m) for m in classes_for_base) - tuples = [] + tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = [] for m_cls in all_m_by_cls: tuples.extend( (all_m_by_cls[base_cls], all_m_by_cls[m_cls]) @@ -1222,12 +1379,14 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): ) return list(topological.sort(tuples, classes_for_base)) - def map(self, mapper_kw=util.EMPTY_DICT): + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: self._configs.pop(self._cls, None) return super(_DeferredMapperConfig, self).map(mapper_kw) -def _add_attribute(cls, key, value): +def _add_attribute( + cls: Type[Any], key: str, value: MapperProperty[Any] +) -> None: """add an attribute to an existing declarative class. This runs through the logic to determine MapperProperty, @@ -1236,39 +1395,44 @@ def _add_attribute(cls, key, value): """ if "__mapper__" in cls.__dict__: + mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls) if isinstance(value, Column): _undefer_column_name(key, value) - cls.__table__.append_column(value, replace_existing=True) - cls.__mapper__.add_property(key, value) + # TODO: raise for this is not a Table + mapped_cls.__table__.append_column(value, replace_existing=True) + mapped_cls.__mapper__.add_property(key, value) elif isinstance(value, _MapsColumns): mp = value.mapper_property_to_assign for col in value.columns_to_assign: _undefer_column_name(key, col) - cls.__table__.append_column(col, replace_existing=True) + # TODO: raise for this is not a Table + mapped_cls.__table__.append_column(col, replace_existing=True) if not mp: - cls.__mapper__.add_property(key, col) + mapped_cls.__mapper__.add_property(key, col) if mp: - cls.__mapper__.add_property(key, mp) + mapped_cls.__mapper__.add_property(key, mp) elif isinstance(value, MapperProperty): - cls.__mapper__.add_property(key, value) + mapped_cls.__mapper__.add_property(key, value) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = Synonym(value.key) - cls.__mapper__.add_property(key, value) + mapped_cls.__mapper__.add_property(key, value) else: type.__setattr__(cls, key, value) - cls.__mapper__._expire_memoizations() + mapped_cls.__mapper__._expire_memoizations() else: type.__setattr__(cls, key, value) -def _del_attribute(cls, key): +def _del_attribute(cls: Type[Any], key: str) -> None: if ( "__mapper__" in cls.__dict__ and key in cls.__dict__ - and not cls.__mapper__._dispose_called + and not cast( + "_DeclMappedClassProtocol[Any]", cls + ).__mapper__._dispose_called ): value = cls.__dict__[key] if isinstance( @@ -1279,12 +1443,14 @@ def _del_attribute(cls, key): ) else: type.__delattr__(cls, key) - cls.__mapper__._expire_memoizations() + cast( + "_DeclMappedClassProtocol[Any]", cls + ).__mapper__._expire_memoizations() else: type.__delattr__(cls, key) -def _declarative_constructor(self, **kwargs): +def _declarative_constructor(self: Any, **kwargs: Any) -> None: """A simple constructor that allows initialization from kwargs. Sets attributes on the constructed instance using the names and @@ -1306,7 +1472,7 @@ def _declarative_constructor(self, **kwargs): _declarative_constructor.__name__ = "__init__" -def _undefer_column_name(key, column): +def _undefer_column_name(key: str, column: Column[Any]) -> None: if column.key is None: column.key = key if column.name is None: diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 5975c30db3..8c89f96aa9 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -20,15 +20,21 @@ import typing from typing import Any from typing import Callable from typing import List +from typing import NoReturn from typing import Optional +from typing import Sequence from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes from . import util as orm_util +from .base import LoaderCallableStatus from .base import Mapped +from .base import PassiveFlag +from .base import SQLORMOperations from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty @@ -41,20 +47,41 @@ from .. import schema from .. import sql from .. import util from ..sql import expression -from ..sql import operators +from ..sql.elements import BindParameter from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _InstanceDict + from ._typing import _RegistryType + from .attributes import History from .attributes import InstrumentedAttribute + from .attributes import QueryableAttribute + from .context import ORMCompileState + from .mapper import Mapper + from .properties import ColumnProperty from .properties import MappedColumn + from .state import InstanceState + from ..engine.base import Connection + from ..engine.row import Row + from ..sql._typing import _DMLColumnArgument from ..sql._typing import _InfoType + from ..sql.elements import ClauseList + from ..sql.elements import ColumnElement from ..sql.schema import Column + from ..sql.selectable import Select + from ..util.typing import _AnnotationScanType + from ..util.typing import CallableReference + from ..util.typing import DescriptorReference + from ..util.typing import RODescriptorReference _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) class _CompositeClassProto(Protocol): + def __init__(self, *args: Any): + ... + def __composite_values__(self) -> Tuple[Any, ...]: ... @@ -63,32 +90,43 @@ class DescriptorProperty(MapperProperty[_T]): """:class:`.MapperProperty` which proxies access to a user-defined descriptor.""" - doc = None + doc: Optional[str] = None uses_objects = False _links_to_entity = False - def instrument_class(self, mapper): + descriptor: DescriptorReference[Any] + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + raise NotImplementedError() + + def instrument_class(self, mapper: Mapper[Any]) -> None: prop = self - class _ProxyImpl: + class _ProxyImpl(attributes.AttributeImpl): accepts_scalar_loader = False load_on_unexpire = True collection = False @property - def uses_objects(self): + def uses_objects(self) -> bool: # type: ignore return prop.uses_objects - def __init__(self, key): + def __init__(self, key: str): self.key = key - if hasattr(prop, "get_history"): - - def get_history( - self, state, dict_, passive=attributes.PASSIVE_OFF - ): - return prop.get_history(state, dict_, passive) + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + return prop.get_history(state, dict_, passive) if self.descriptor is None: desc = getattr(mapper.class_, self.key, None) @@ -97,13 +135,13 @@ class DescriptorProperty(MapperProperty[_T]): if self.descriptor is None: - def fset(obj, value): + def fset(obj: Any, value: Any) -> None: setattr(obj, self.name, value) - def fdel(obj): + def fdel(obj: Any) -> None: delattr(obj, self.name) - def fget(obj): + def fget(obj: Any) -> Any: return getattr(obj, self.name) self.descriptor = property(fget=fget, fset=fset, fdel=fdel) @@ -129,8 +167,11 @@ _CompositeAttrType = Union[ ] +_CC = TypeVar("_CC", bound=_CompositeClassProto) + + class Composite( - _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T] + _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC] ): """Defines a "composite" mapped attribute, representing a collection of columns as one attribute. @@ -148,19 +189,25 @@ class Composite( """ - composite_class: Union[ - Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]] + composite_class: Union[Type[_CC], Callable[..., _CC]] + attrs: Tuple[_CompositeAttrType[Any], ...] + + _generated_composite_accessor: CallableReference[ + Optional[Callable[[_CC], Tuple[Any, ...]]] ] - attrs: Tuple[_CompositeAttrType, ...] + + comparator_factory: Type[Comparator[_CC]] def __init__( self, - class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None, - *attrs: _CompositeAttrType, + class_: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, + *attrs: _CompositeAttrType[Any], active_history: bool = False, deferred: bool = False, group: Optional[str] = None, - comparator_factory: Optional[Type[Comparator]] = None, + comparator_factory: Optional[Type[Comparator[_CC]]] = None, info: Optional[_InfoType] = None, ): super().__init__() @@ -170,7 +217,7 @@ class Composite( # will initialize within declarative_scan self.composite_class = None # type: ignore else: - self.composite_class = class_ + self.composite_class = class_ # type: ignore self.attrs = attrs self.active_history = active_history @@ -183,18 +230,16 @@ class Composite( ) self._generated_composite_accessor = None if info is not None: - self.info = info + self.info.update(info) util.set_creation_order(self) self._create_descriptor() - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: super().instrument_class(mapper) self._setup_event_handlers() - def _composite_values_from_instance( - self, value: _CompositeClassProto - ) -> Tuple[Any, ...]: + def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]: if self._generated_composite_accessor: return self._generated_composite_accessor(value) else: @@ -209,7 +254,7 @@ class Composite( else: return accessor() - def do_init(self): + def do_init(self) -> None: """Initialization which occurs after the :class:`.Composite` has been associated with its parent mapper. @@ -218,13 +263,13 @@ class Composite( _COMPOSITE_FGET = object() - def _create_descriptor(self): + def _create_descriptor(self) -> None: """Create the Python descriptor that will serve as the access point on instances of the mapped class. """ - def fget(instance): + def fget(instance: Any) -> Any: dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) @@ -251,11 +296,11 @@ class Composite( return dict_.get(self.key, None) - def fset(instance, value): + def fset(instance: Any, value: Any) -> None: dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) attr = state.manager[self.key] - previous = dict_.get(self.key, attributes.NO_VALUE) + previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE) for fn in attr.dispatch.set: value = fn(state, value, previous, attr.impl) dict_[self.key] = value @@ -269,10 +314,10 @@ class Composite( ): setattr(instance, key, value) - def fdel(instance): + def fdel(instance: Any) -> None: state = attributes.instance_state(instance) dict_ = attributes.instance_dict(instance) - previous = dict_.pop(self.key, attributes.NO_VALUE) + previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE) attr = state.manager[self.key] attr.dispatch.remove(state, previous, attr.impl) for key in self._attribute_keys: @@ -282,8 +327,13 @@ class Composite( @util.preload_module("sqlalchemy.orm.properties") def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: MappedColumn = util.preloaded.orm_properties.MappedColumn argument = _extract_mapped_subtype( @@ -310,7 +360,9 @@ class Composite( @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") - def _setup_for_dataclass(self, registry, cls, key): + def _setup_for_dataclass( + self, registry: _RegistryType, cls: Type[Any], key: str + ) -> None: MappedColumn = util.preloaded.orm_properties.MappedColumn decl_base = util.preloaded.orm_decl_base @@ -341,12 +393,12 @@ class Composite( self._generated_composite_accessor = getter @util.memoized_property - def _comparable_elements(self): + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: return [getattr(self.parent.class_, prop.key) for prop in self.props] @util.memoized_property @util.preload_module("orm.properties") - def props(self): + def props(self) -> Sequence[MapperProperty[Any]]: props = [] MappedColumn = util.preloaded.orm_properties.MappedColumn @@ -360,17 +412,20 @@ class Composite( elif isinstance(attr, attributes.InstrumentedAttribute): prop = attr.property else: + prop = None + + if not isinstance(prop, MapperProperty): raise sa_exc.ArgumentError( "Composite expects Column objects or mapped " - "attributes/attribute names as arguments, got: %r" - % (attr,) + f"attributes/attribute names as arguments, got: {attr!r}" ) + props.append(prop) return props - @property + @util.non_memoized_property @util.preload_module("orm.properties") - def columns(self): + def columns(self) -> Sequence[Column[Any]]: MappedColumn = util.preloaded.orm_properties.MappedColumn return [ a.column if isinstance(a, MappedColumn) else a @@ -379,32 +434,46 @@ class Composite( ] @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]: return self @property - def columns_to_assign(self) -> List[schema.Column]: + def columns_to_assign(self) -> List[schema.Column[Any]]: return [c for c in self.columns if c.table is None] - def _setup_arguments_on_columns(self): + @util.preload_module("orm.properties") + def _setup_arguments_on_columns(self) -> None: """Propagate configuration arguments made on this composite to the target columns, for those that apply. """ + ColumnProperty = util.preloaded.orm_properties.ColumnProperty + for prop in self.props: - prop.active_history = self.active_history + if not isinstance(prop, ColumnProperty): + continue + else: + cprop = prop + + cprop.active_history = self.active_history if self.deferred: - prop.deferred = self.deferred - prop.strategy_key = (("deferred", True), ("instrument", True)) - prop.group = self.group + cprop.deferred = self.deferred + cprop.strategy_key = (("deferred", True), ("instrument", True)) + cprop.group = self.group - def _setup_event_handlers(self): + def _setup_event_handlers(self) -> None: """Establish events that populate/expire the composite attribute.""" - def load_handler(state, context): + def load_handler( + state: InstanceState[Any], context: ORMCompileState + ) -> None: _load_refresh_handler(state, context, None, is_refresh=False) - def refresh_handler(state, context, to_load): + def refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + ) -> None: # note this corresponds to sqlalchemy.ext.mutable load_attrs() if not to_load or ( @@ -412,7 +481,12 @@ class Composite( ).intersection(to_load): _load_refresh_handler(state, context, to_load, is_refresh=True) - def _load_refresh_handler(state, context, to_load, is_refresh): + def _load_refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + is_refresh: bool, + ) -> None: dict_ = state.dict # if context indicates we are coming from the @@ -440,11 +514,17 @@ class Composite( *[state.dict[key] for key in self._attribute_keys] ) - def expire_handler(state, keys): + def expire_handler( + state: InstanceState[Any], keys: Optional[Sequence[str]] + ) -> None: if keys is None or set(self._attribute_keys).intersection(keys): state.dict.pop(self.key, None) - def insert_update_handler(mapper, connection, state): + def insert_update_handler( + mapper: Mapper[Any], + connection: Connection, + state: InstanceState[Any], + ) -> None: """After an insert or update, some columns may be expired due to server side defaults, or re-populated due to client side defaults. Pop out the composite value here so that it @@ -473,14 +553,19 @@ class Composite( # TODO: need a deserialize hook here @util.memoized_property - def _attribute_keys(self): + def _attribute_keys(self) -> Sequence[str]: return [prop.key for prop in self.props] - def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: """Provided for userland code that uses attributes.get_history().""" - added = [] - deleted = [] + added: List[Any] = [] + deleted: List[Any] = [] has_history = False for prop in self.props: @@ -508,16 +593,27 @@ class Composite( else: return attributes.History((), [self.composite_class(*added)], ()) - def _comparator_factory(self, mapper): + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Composite.Comparator[_CC]: return self.comparator_factory(self, mapper) - class CompositeBundle(orm_util.Bundle): - def __init__(self, property_, expr): + class CompositeBundle(orm_util.Bundle[_T]): + def __init__( + self, + property_: Composite[_T], + expr: ClauseList, + ): self.property = property_ super().__init__(property_.key, *expr) - def create_row_processor(self, query, procs, labels): - def proc(row): + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: + def proc(row: Row[Any]) -> Any: return self.property.composite_class( *[proc(row) for proc in procs] ) @@ -546,17 +642,19 @@ class Composite( # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore + prop: RODescriptorReference[Composite[_PT]] + @util.memoized_property - def clauses(self): + def clauses(self) -> ClauseList: return expression.ClauseList( group=False, *self._comparable_elements ) - def __clause_element__(self): + def __clause_element__(self) -> Composite.CompositeBundle[_PT]: return self.expression @util.memoized_property - def expression(self): + def expression(self) -> Composite.CompositeBundle[_PT]: clauses = self.clauses._annotate( { "parententity": self._parententity, @@ -566,13 +664,19 @@ class Composite( ) return Composite.CompositeBundle(self.prop, clauses) - def _bulk_update_tuples(self, value): - if isinstance(value, sql.elements.BindParameter): + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + if isinstance(value, BindParameter): value = value.value + values: Sequence[Any] + if value is None: values = [None for key in self.prop._attribute_keys] - elif isinstance(value, self.prop.composite_class): + elif isinstance(self.prop.composite_class, type) and isinstance( + value, self.prop.composite_class + ): values = self.prop._composite_values_from_instance(value) else: raise sa_exc.ArgumentError( @@ -580,10 +684,10 @@ class Composite( % (self.prop, value) ) - return zip(self._comparable_elements, values) + return list(zip(self._comparable_elements, values)) @util.memoized_property - def _comparable_elements(self): + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: if self._adapt_to_entity: return [ getattr(self._adapt_to_entity.entity, prop.key) @@ -592,7 +696,8 @@ class Composite( else: return self.prop._comparable_elements - def __eq__(self, other): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + values: Sequence[Any] if other is None: values = [None] * len(self.prop._comparable_elements) else: @@ -601,13 +706,14 @@ class Composite( a == b for a, b in zip(self.prop._comparable_elements, values) ] if self._adapt_to_entity: + assert self.adapter is not None comparisons = [self.adapter(x) for x in comparisons] return sql.and_(*comparisons) - def __ne__(self, other): + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 return sql.not_(self.__eq__(other)) - def __str__(self): + def __str__(self) -> str: return str(self.parent.class_.__name__) + "." + self.key @@ -628,20 +734,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): """ - def _comparator_factory(self, mapper): + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Type[PropComparator[_T]]: + comparator_callable = None for m in self.parent.iterate_to_root(): p = m._props[self.key] - if not isinstance(p, ConcreteInheritedProperty): + if getattr(p, "comparator_factory", None) is not None: comparator_callable = p.comparator_factory break - return comparator_callable + assert comparator_callable is not None + return comparator_callable(p, mapper) # type: ignore - def __init__(self): + def __init__(self) -> None: super().__init__() - def warn(): + def warn() -> NoReturn: raise AttributeError( "Concrete %s does not implement " "attribute %r at the instance level. Add " @@ -650,13 +760,13 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): ) class NoninheritedConcreteProp: - def __set__(s, obj, value): + def __set__(s: Any, obj: Any, value: Any) -> NoReturn: warn() - def __delete__(s, obj): + def __delete__(s: Any, obj: Any) -> NoReturn: warn() - def __get__(s, obj, owner): + def __get__(s: Any, obj: Any, owner: Any) -> Any: if obj is None: return self.descriptor warn() @@ -682,14 +792,16 @@ class Synonym(DescriptorProperty[_T]): """ + comparator_factory: Optional[Type[PropComparator[_T]]] + def __init__( self, - name, - map_column=None, - descriptor=None, - comparator_factory=None, - doc=None, - info=None, + name: str, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, ): super().__init__() @@ -697,21 +809,30 @@ class Synonym(DescriptorProperty[_T]): self.map_column = map_column self.descriptor = descriptor self.comparator_factory = comparator_factory - self.doc = doc or (descriptor and descriptor.__doc__) or None + if doc: + self.doc = doc + elif descriptor and descriptor.__doc__: + self.doc = descriptor.__doc__ + else: + self.doc = None if info: - self.info = info + self.info.update(info) util.set_creation_order(self) - @property - def uses_objects(self): - return getattr(self.parent.class_, self.name).impl.uses_objects + if not TYPE_CHECKING: + + @property + def uses_objects(self) -> bool: + return getattr(self.parent.class_, self.name).impl.uses_objects # TODO: when initialized, check _proxied_object, # emit a warning if its not a column-based property @util.memoized_property - def _proxied_object(self): + def _proxied_object( + self, + ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]: attr = getattr(self.parent.class_, self.name) if not hasattr(attr, "property") or not isinstance( attr.property, MapperProperty @@ -720,7 +841,8 @@ class Synonym(DescriptorProperty[_T]): # hybrid or association proxy if isinstance(attr, attributes.QueryableAttribute): return attr.comparator - elif isinstance(attr, operators.ColumnOperators): + elif isinstance(attr, SQLORMOperations): + # assocaition proxy comes here return attr raise sa_exc.InvalidRequestError( @@ -730,7 +852,7 @@ class Synonym(DescriptorProperty[_T]): ) return attr.property - def _comparator_factory(self, mapper): + def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]: prop = self._proxied_object if isinstance(prop, MapperProperty): @@ -742,12 +864,17 @@ class Synonym(DescriptorProperty[_T]): else: return prop - def get_history(self, *arg, **kw): - attr = getattr(self.parent.class_, self.name) - return attr.impl.get_history(*arg, **kw) + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name) + return attr.impl.get_history(state, dict_, passive=passive) @util.preload_module("sqlalchemy.orm.properties") - def set_parent(self, parent, init): + def set_parent(self, parent: Mapper[Any], init: bool) -> None: properties = util.preloaded.orm_properties if self.map_column: @@ -776,7 +903,7 @@ class Synonym(DescriptorProperty[_T]): "%r for column %r" % (self.key, self.name, self.name, self.key) ) - p = properties.ColumnProperty( + p: ColumnProperty[Any] = properties.ColumnProperty( parent.persist_selectable.c[self.key] ) parent._configure_property(self.name, p, init=init, setparent=True) diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 1b4f573b50..084ba969fb 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -16,6 +16,12 @@ basic add/delete mutation. from __future__ import annotations +from typing import Any +from typing import Optional +from typing import overload +from typing import TYPE_CHECKING +from typing import Union + from . import attributes from . import exc as orm_exc from . import interfaces @@ -23,17 +29,27 @@ from . import relationships from . import strategies from . import util as orm_util from .base import object_mapper +from .base import PassiveFlag from .query import Query from .session import object_session from .. import exc from .. import log from .. import util from ..engine import result +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import _InstanceDict + from .attributes import _AdaptedCollectionProtocol + from .attributes import AttributeEventToken + from .attributes import CollectionAdapter + from .base import LoaderCallableStatus + from .state import InstanceState @log.class_logger @relationships.Relationship.strategy_for(lazy="dynamic") -class DynaLoader(strategies.AbstractRelationshipLoader): +class DynaLoader(strategies.AbstractRelationshipLoader, log.Identified): def init_class_attribute(self, mapper): self.is_class_level = True if not self.uselist: @@ -106,13 +122,47 @@ class DynamicAttributeImpl( else: return self.query_class(self, state) + @overload def get_collection( self, - state, - dict_, - user_data=None, - passive=attributes.PASSIVE_NO_INITIALIZE, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Literal[None] = ..., + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = ..., + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: if not passive & attributes.SQL_OK: data = self._get_collection_history(state, passive).added_items else: @@ -170,15 +220,15 @@ class DynamicAttributeImpl( def set( self, - state, - dict_, - value, - initiator=None, - passive=attributes.PASSIVE_OFF, - check_old=None, - pop=False, - _adapt=True, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: if initiator and initiator.parent_token is self.parent_token: return diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 331c224eef..726ea79b5b 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors """ORM event interfaces. diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index f157919ab9..57e5fe8c6e 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """SQLAlchemy ORM exceptions.""" @@ -12,13 +11,22 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import Tuple from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar from .. import exc as sa_exc from .. import util from ..exc import MultipleResultsFound # noqa from ..exc import NoResultFound # noqa +if TYPE_CHECKING: + from .interfaces import LoaderStrategy + from .interfaces import MapperProperty + from .state import InstanceState + +_T = TypeVar("_T", bound=Any) NO_STATE = (AttributeError, KeyError) """Exception types that may be raised by instrumentation implementations.""" @@ -100,14 +108,14 @@ class UnmappedInstanceError(UnmappedError): ) UnmappedError.__init__(self, msg) - def __reduce__(self): + def __reduce__(self) -> Any: return self.__class__, (None, self.args[0]) class UnmappedClassError(UnmappedError): """An mapping operation was requested for an unknown class.""" - def __init__(self, cls: Type[object], msg: Optional[str] = None): + def __init__(self, cls: Type[_T], msg: Optional[str] = None): if not msg: msg = _default_unmapped(cls) UnmappedError.__init__(self, msg) @@ -137,7 +145,7 @@ class ObjectDeletedError(sa_exc.InvalidRequestError): """ @util.preload_module("sqlalchemy.orm.base") - def __init__(self, state, msg=None): + def __init__(self, state: InstanceState[Any], msg: Optional[str] = None): base = util.preloaded.orm_base if not msg: @@ -148,7 +156,7 @@ class ObjectDeletedError(sa_exc.InvalidRequestError): sa_exc.InvalidRequestError.__init__(self, msg) - def __reduce__(self): + def __reduce__(self) -> Any: return self.__class__, (None, self.args[0]) @@ -161,11 +169,11 @@ class LoaderStrategyException(sa_exc.InvalidRequestError): def __init__( self, - applied_to_property_type, - requesting_property, - applies_to, - actual_strategy_type, - strategy_key, + applied_to_property_type: Type[Any], + requesting_property: MapperProperty[Any], + applies_to: Optional[Type[MapperProperty[Any]]], + actual_strategy_type: Optional[Type[LoaderStrategy]], + strategy_key: Tuple[Any, ...], ): if actual_strategy_type is None: sa_exc.InvalidRequestError.__init__( @@ -174,6 +182,7 @@ class LoaderStrategyException(sa_exc.InvalidRequestError): % (strategy_key, requesting_property), ) else: + assert applies_to is not None sa_exc.InvalidRequestError.__init__( self, 'Can\'t apply "%s" strategy to property "%s", ' @@ -188,7 +197,8 @@ class LoaderStrategyException(sa_exc.InvalidRequestError): ) -def _safe_cls_name(cls): +def _safe_cls_name(cls: Type[Any]) -> str: + cls_name: Optional[str] try: cls_name = ".".join((cls.__module__, cls.__name__)) except AttributeError: @@ -199,7 +209,7 @@ def _safe_cls_name(cls): @util.preload_module("sqlalchemy.orm.base") -def _default_unmapped(cls) -> Optional[str]: +def _default_unmapped(cls: Type[Any]) -> Optional[str]: base = util.preloaded.orm_base try: diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index d13265c560..63b131a780 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import Any +from typing import cast from typing import Dict from typing import Iterable from typing import Iterator @@ -15,6 +16,7 @@ from typing import List from typing import NoReturn from typing import Optional from typing import Set +from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar import weakref @@ -66,7 +68,7 @@ class IdentityMap: ) -> Optional[_O]: raise NotImplementedError() - def keys(self): + def keys(self) -> Iterable[_IdentityKeyType[Any]]: return self._dict.keys() def values(self) -> Iterable[object]: @@ -117,10 +119,10 @@ class IdentityMap: class WeakInstanceDict(IdentityMap): - _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]] + _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]] def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: - state = self._dict[key] + state = cast("InstanceState[_O]", self._dict[key]) o = state.obj() if o is None: raise KeyError(key) @@ -140,6 +142,8 @@ class WeakInstanceDict(IdentityMap): def contains_state(self, state: InstanceState[Any]) -> bool: if state.key in self._dict: + if TYPE_CHECKING: + assert state.key is not None try: return self._dict[state.key] is state except KeyError: @@ -150,15 +154,16 @@ class WeakInstanceDict(IdentityMap): def replace( self, state: InstanceState[Any] ) -> Optional[InstanceState[Any]]: + assert state.key is not None if state.key in self._dict: try: - existing = self._dict[state.key] + existing = existing_non_none = self._dict[state.key] except KeyError: # catch gc removed the key after we just checked for it existing = None else: - if existing is not state: - self._manage_removed_state(existing) + if existing_non_none is not state: + self._manage_removed_state(existing_non_none) else: return None else: @@ -170,6 +175,7 @@ class WeakInstanceDict(IdentityMap): def add(self, state: InstanceState[Any]) -> bool: key = state.key + assert key is not None # inline of self.__contains__ if key in self._dict: try: @@ -206,7 +212,7 @@ class WeakInstanceDict(IdentityMap): if key not in self._dict: return default try: - state = self._dict[key] + state = cast("InstanceState[_O]", self._dict[key]) except KeyError: # catch gc removed the key after we just checked for it return default @@ -216,13 +222,15 @@ class WeakInstanceDict(IdentityMap): return default return o - def items(self) -> List[InstanceState[Any]]: + def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]: values = self.all_states() result = [] for state in values: value = state.obj() + key = state.key + assert key is not None if value is not None: - result.append((state.key, value)) + result.append((key, value)) return result def values(self) -> List[object]: @@ -244,28 +252,32 @@ class WeakInstanceDict(IdentityMap): def _fast_discard(self, state: InstanceState[Any]) -> None: # used by InstanceState for state being # GC'ed, inlines _managed_removed_state + key = state.key + assert key is not None try: - st = self._dict[state.key] + st = self._dict[key] except KeyError: # catch gc removed the key after we just checked for it pass else: if st is state: - self._dict.pop(state.key, None) + self._dict.pop(key, None) def discard(self, state: InstanceState[Any]) -> None: self.safe_discard(state) def safe_discard(self, state: InstanceState[Any]) -> None: - if state.key in self._dict: + key = state.key + if key in self._dict: + assert key is not None try: - st = self._dict[state.key] + st = self._dict[key] except KeyError: # catch gc removed the key after we just checked for it pass else: if st is state: - self._dict.pop(state.key, None) + self._dict.pop(key, None) self._manage_removed_state(state) diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 85b85215ea..4fa61b7cee 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -66,7 +66,7 @@ from ..util.typing import Protocol if TYPE_CHECKING: from ._typing import _RegistryType from .attributes import AttributeImpl - from .attributes import InstrumentedAttribute + from .attributes import QueryableAttribute from .collections import _AdaptedCollectionProtocol from .collections import _CollectionFactoryType from .decl_base import _MapperConfig @@ -96,7 +96,7 @@ class _ManagerFactory(Protocol): class ClassManager( HasMemoized, - Dict[str, "InstrumentedAttribute[Any]"], + Dict[str, "QueryableAttribute[Any]"], Generic[_O], EventTarget, ): @@ -117,7 +117,14 @@ class ClassManager( factory: Optional[_ManagerFactory] declarative_scan: Optional[weakref.ref[_MapperConfig]] = None - registry: Optional[_RegistryType] = None + + registry: _RegistryType + + if not TYPE_CHECKING: + # starts as None during setup + registry = None + + class_: Type[_O] _bases: List[ClassManager[Any]] @@ -312,7 +319,7 @@ class ClassManager( else: return default - def _attr_has_impl(self, key): + def _attr_has_impl(self, key: str) -> bool: """Return True if the given attribute is fully initialized. i.e. has an impl. @@ -366,7 +373,12 @@ class ClassManager( def dict_getter(self): return _default_dict_getter - def instrument_attribute(self, key, inst, propagated=False): + def instrument_attribute( + self, + key: str, + inst: QueryableAttribute[Any], + propagated: bool = False, + ) -> None: if propagated: if key in self.local_attrs: return # don't override local attr with inherited attr @@ -429,7 +441,7 @@ class ClassManager( delattr(self.class_, self.MANAGER_ATTR) def install_descriptor( - self, key: str, inst: InstrumentedAttribute[Any] + self, key: str, inst: QueryableAttribute[Any] ) -> None: if key in (self.STATE_ATTR, self.MANAGER_ATTR): raise KeyError( @@ -490,7 +502,11 @@ class ClassManager( # InstanceState management def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O: - instance = self.class_.__new__(self.class_) + # here, we would prefer _O to be bound to "object" + # so that mypy sees that __new__ is present. currently + # it's bound to Any as there were other problems not having + # it that way but these can be revisited + instance = self.class_.__new__(self.class_) # type: ignore if state is None: state = self._state_constructor(instance, self) self._state_setter(instance, state) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c9c54c1b08..b5569ce063 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls """ @@ -33,6 +32,7 @@ from typing import Sequence from typing import Set from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -48,6 +48,7 @@ from .base import MANYTOMANY as MANYTOMANY # noqa: F401 from .base import MANYTOONE as MANYTOONE # noqa: F401 from .base import NotExtension as NotExtension # noqa: F401 from .base import ONETOMANY as ONETOMANY # noqa: F401 +from .base import RelationshipDirection as RelationshipDirection # noqa: F401 from .base import SQLORMOperations from .. import ColumnElement from .. import inspection @@ -59,7 +60,7 @@ from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey from ..sql.schema import Column from ..sql.type_api import TypeEngine -from ..util.typing import DescriptorReference +from ..util.typing import RODescriptorReference from ..util.typing import TypedDict if typing.TYPE_CHECKING: @@ -75,13 +76,11 @@ if typing.TYPE_CHECKING: from .loading import _PopulatorDict from .mapper import Mapper from .path_registry import AbstractEntityRegistry - from .path_registry import PathRegistry from .query import Query from .session import Session from .state import InstanceState from .strategy_options import _LoadElement from .util import AliasedInsp - from .util import CascadeOptions from .util import ORMAdapter from ..engine.result import Result from ..sql._typing import _ColumnExpressionArgument @@ -89,8 +88,10 @@ if typing.TYPE_CHECKING: from ..sql._typing import _DMLColumnArgument from ..sql._typing import _InfoType from ..sql.operators import OperatorType - from ..sql.util import ColumnAdapter from ..sql.visitors import _TraverseInternalsType + from ..util.typing import _AnnotationScanType + +_StrategyKey = Tuple[Any, ...] _T = TypeVar("_T", bound=Any) @@ -104,7 +105,9 @@ class ORMStatementRole(roles.StatementRole): ) -class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]): +class ORMColumnsClauseRole( + roles.ColumnsClauseRole, roles.TypedColumnsClauseRole[_T] +): __slots__ = () _role_name = "ORM mapped entity, aliased entity, or Column expression" @@ -137,8 +140,8 @@ class _IntrospectsAnnotations: registry: RegistryType, cls: Type[Any], key: str, - annotation: Optional[Type[Any]], - is_dataclass_field: Optional[bool], + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, ) -> None: """Perform class-specific initializaton at early declarative scanning time. @@ -199,6 +202,7 @@ class MapperProperty( "parent", "key", "info", + "doc", ) _cache_key_traversal: _TraverseInternalsType = [ @@ -206,14 +210,8 @@ class MapperProperty( ("key", visitors.ExtendedInternalTraversal.dp_string), ] - cascade: Optional[CascadeOptions] = None - """The set of 'cascade' attribute names. - - This collection is checked before the 'cascade_iterator' method is called. - - The collection typically only applies to a Relationship. - - """ + if not TYPE_CHECKING: + cascade = None is_property = True """Part of the InspectionAttr interface; states this object is a @@ -240,6 +238,9 @@ class MapperProperty( """ + doc: Optional[str] + """optional documentation string""" + def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.InspectionAttr`. @@ -268,8 +269,8 @@ class MapperProperty( self, context: ORMCompileState, query_entity: _MapperEntity, - path: PathRegistry, - adapter: Optional[ColumnAdapter], + path: AbstractEntityRegistry, + adapter: Optional[ORMAdapter], **kwargs: Any, ) -> None: """Called by Query for the purposes of constructing a SQL statement. @@ -284,10 +285,10 @@ class MapperProperty( self, context: ORMCompileState, query_entity: _MapperEntity, - path: PathRegistry, + path: AbstractEntityRegistry, mapper: Mapper[Any], result: Result[Any], - adapter: Optional[ColumnAdapter], + adapter: Optional[ORMAdapter], populators: _PopulatorDict, ) -> None: """Produce row processing functions and append to the given @@ -421,7 +422,7 @@ class MapperProperty( dest_state: InstanceState[Any], dest_dict: _InstanceDict, load: bool, - _recursive: Set[InstanceState[Any]], + _recursive: Dict[Any, object], _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], ) -> None: """Merge the attribute represented by this ``MapperProperty`` @@ -526,7 +527,7 @@ class PropComparator(SQLORMOperations[_T]): _parententity: _InternalEntityType[Any] _adapt_to_entity: Optional[AliasedInsp[Any]] - prop: DescriptorReference[MapperProperty[_T]] + prop: RODescriptorReference[MapperProperty[_T]] def __init__( self, @@ -539,7 +540,7 @@ class PropComparator(SQLORMOperations[_T]): self._adapt_to_entity = adapt_to_entity @util.non_memoized_property - def property(self) -> Optional[MapperProperty[_T]]: + def property(self) -> MapperProperty[_T]: """Return the :class:`.MapperProperty` associated with this :class:`.PropComparator`. @@ -589,7 +590,7 @@ class PropComparator(SQLORMOperations[_T]): return self.prop.comparator._criterion_exists(criterion, **kwargs) @util.ro_non_memoized_property - def adapter(self) -> Optional[_ORMAdapterProto[_T]]: + def adapter(self) -> Optional[_ORMAdapterProto]: """Produce a callable that adapts column expressions to suit an aliased version of this comparator. @@ -597,7 +598,7 @@ class PropComparator(SQLORMOperations[_T]): if self._adapt_to_entity is None: return None else: - return self._adapt_to_entity._adapt_element + return self._adapt_to_entity._orm_adapt_element @util.ro_non_memoized_property def info(self) -> _InfoType: @@ -631,7 +632,7 @@ class PropComparator(SQLORMOperations[_T]): ) -> ColumnElement[Any]: ... - def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: + def of_type(self, class_: _EntityType[_T]) -> PropComparator[_T]: r"""Redefine this object in terms of a polymorphic subclass, :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` construct. @@ -763,9 +764,9 @@ class StrategizedProperty(MapperProperty[_T]): inherit_cache = True strategy_wildcard_key: ClassVar[str] - strategy_key: Tuple[Any, ...] + strategy_key: _StrategyKey - _strategies: Dict[Tuple[Any, ...], LoaderStrategy] + _strategies: Dict[_StrategyKey, LoaderStrategy] def _memoized_attr__wildcard_token(self) -> Tuple[str]: return ( @@ -808,7 +809,7 @@ class StrategizedProperty(MapperProperty[_T]): return load - def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy: + def _get_strategy(self, key: _StrategyKey) -> LoaderStrategy: try: return self._strategies[key] except KeyError: @@ -822,7 +823,14 @@ class StrategizedProperty(MapperProperty[_T]): self._strategies[key] = strategy = cls(self, key) return strategy - def setup(self, context, query_entity, path, adapter, **kwargs): + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: loader = self._get_context_loader(context, path) if loader and loader.strategy: strat = self._get_strategy(loader.strategy) @@ -833,8 +841,15 @@ class StrategizedProperty(MapperProperty[_T]): ) def create_row_processor( - self, context, query_entity, path, mapper, result, adapter, populators - ): + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + mapper: Mapper[Any], + result: Result[Any], + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: loader = self._get_context_loader(context, path) if loader and loader.strategy: strat = self._get_strategy(loader.strategy) @@ -851,11 +866,11 @@ class StrategizedProperty(MapperProperty[_T]): populators, ) - def do_init(self): + def do_init(self) -> None: self._strategies = {} self.strategy = self._get_strategy(self.strategy_key) - def post_instrument_class(self, mapper): + def post_instrument_class(self, mapper: Mapper[Any]) -> None: if ( not self.parent.non_primary and not mapper.class_manager._attr_has_impl(self.key) @@ -863,7 +878,7 @@ class StrategizedProperty(MapperProperty[_T]): self.strategy.init_class_attribute(mapper) _all_strategies: collections.defaultdict[ - Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]] + Type[MapperProperty[Any]], Dict[_StrategyKey, Type[LoaderStrategy]] ] = collections.defaultdict(dict) @classmethod @@ -888,6 +903,8 @@ class StrategizedProperty(MapperProperty[_T]): for prop_cls in cls.__mro__: if prop_cls in cls._all_strategies: + if TYPE_CHECKING: + assert issubclass(prop_cls, MapperProperty) strategies = cls._all_strategies[prop_cls] try: return strategies[key] @@ -976,8 +993,8 @@ class CompileStateOption(HasCacheKey, ORMOption): _is_compile_state = True - def process_compile_state(self, compile_state): - """Apply a modification to a given :class:`.CompileState`. + def process_compile_state(self, compile_state: ORMCompileState) -> None: + """Apply a modification to a given :class:`.ORMCompileState`. This method is part of the implementation of a particular :class:`.CompileStateOption` and is only invoked internally @@ -986,9 +1003,11 @@ class CompileStateOption(HasCacheKey, ORMOption): """ def process_compile_state_replaced_entities( - self, compile_state, mapper_entities - ): - """Apply a modification to a given :class:`.CompileState`, + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: + """Apply a modification to a given :class:`.ORMCompileState`, given entities that were replaced by with_only_columns() or with_entities(). @@ -1011,8 +1030,10 @@ class LoaderOption(CompileStateOption): __slots__ = () def process_compile_state_replaced_entities( - self, compile_state, mapper_entities - ): + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: self.process_compile_state(compile_state) @@ -1028,7 +1049,7 @@ class CriteriaOption(CompileStateOption): _is_criteria_option = True - def get_global_criteria(self, attributes): + def get_global_criteria(self, attributes: Dict[str, Any]) -> None: """update additional entity criteria options in the given attributes dictionary. @@ -1054,7 +1075,7 @@ class UserDefinedOption(ORMOption): """ - def __init__(self, payload=None): + def __init__(self, payload: Optional[Any] = None): self.payload = payload @@ -1132,10 +1153,10 @@ class LoaderStrategy: "strategy_opts", ) - _strategy_keys: ClassVar[List[Tuple[Any, ...]]] + _strategy_keys: ClassVar[List[_StrategyKey]] def __init__( - self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...] + self, parent: MapperProperty[Any], strategy_key: _StrategyKey ): self.parent_property = parent self.is_class_level = False @@ -1186,5 +1207,5 @@ class LoaderStrategy: """ - def __str__(self): + def __str__(self) -> str: return str(self.parent_property) diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 75887367e7..1a5ea5fe65 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -54,11 +54,15 @@ from ..sql.selectable import SelectState if TYPE_CHECKING: from ._typing import _IdentityKeyType from .base import LoaderCallableStatus + from .context import QueryContext from .interfaces import ORMOption from .mapper import Mapper + from .query import Query from .session import Session from .state import InstanceState + from ..engine.cursor import CursorResult from ..engine.interfaces import _ExecuteOptions + from ..engine.result import Result from ..sql import Select _T = TypeVar("_T", bound=Any) @@ -69,7 +73,7 @@ _new_runid = util.counter() _PopulatorDict = Dict[str, List[Tuple[str, Any]]] -def instances(cursor, context): +def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: """Return a :class:`.Result` given an ORM query context. :param cursor: a :class:`.CursorResult`, generated by a statement @@ -152,7 +156,7 @@ def instances(cursor, context): unique_filters = [ _no_unique if context.yield_per - else _not_hashable(ent.column.type) + else _not_hashable(ent.column.type) # type: ignore if (not ent.use_id_for_hash and ent._non_hashable_value) else id if ent.use_id_for_hash @@ -164,7 +168,7 @@ def instances(cursor, context): labels, extra, _unique_filters=unique_filters ) - def chunks(size): + def chunks(size): # type: ignore while True: yield_per = size @@ -302,7 +306,11 @@ def merge_frozen_result(session, statement, frozen_result, load=True): "is superseded by the :func:`_orm.merge_frozen_result` function.", ) @util.preload_module("sqlalchemy.orm.context") -def merge_result(query, iterator, load=True): +def merge_result( + query: Query[Any], + iterator: Union[FrozenResult, Iterable[Sequence[Any]], Iterable[object]], + load: bool = True, +) -> Union[FrozenResult, Iterable[Any]]: """Merge a result into the given :class:`.Query` object's Session. See :meth:`_orm.Query.merge_result` for top-level documentation on this @@ -375,7 +383,7 @@ def merge_result(query, iterator, load=True): result.append(keyed_tuple(newrow)) if frozen_result: - return frozen_result.with_data(result) + return frozen_result.with_new_rows(result) else: return iter(result) finally: diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py index 4324a000d1..d1057ca5f3 100644 --- a/lib/sqlalchemy/orm/mapped_collection.py +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors from __future__ import annotations diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 337a7178b0..2d3bceb928 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -80,6 +80,7 @@ from ..sql import roles from ..sql import util as sql_util from ..sql import visitors from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import KeyedColumnElement from ..sql.schema import Table from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import HasMemoized @@ -108,7 +109,6 @@ if TYPE_CHECKING: from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement - from ..sql.elements import KeyedColumnElement from ..sql.schema import Column from ..sql.selectable import FromClause from ..sql.util import ColumnAdapter @@ -182,6 +182,7 @@ class Mapper( dispatch: dispatcher[Mapper[_O]] _dispose_called = False + _configure_failed: Any = False _ready_for_configure = False @util.deprecated_params( @@ -710,8 +711,11 @@ class Mapper( self.batch = batch self.eager_defaults = eager_defaults self.column_prefix = column_prefix - self.polymorphic_on = ( - coercions.expect( + + # interim - polymorphic_on is further refined in + # _configure_polymorphic_setter + self.polymorphic_on = ( # type: ignore + coercions.expect( # type: ignore roles.ColumnArgumentOrKeyRole, polymorphic_on, argname="polymorphic_on", @@ -1832,12 +1836,22 @@ class Mapper( ) @util.preload_module("sqlalchemy.orm.descriptor_props") - def _configure_property(self, key, prop, init=True, setparent=True): + def _configure_property( + self, + key: str, + prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]], + init: bool = True, + setparent: bool = True, + ) -> MapperProperty[Any]: descriptor_props = util.preloaded.orm_descriptor_props - self._log("_configure_property(%s, %s)", key, prop.__class__.__name__) + self._log( + "_configure_property(%s, %s)", key, prop_arg.__class__.__name__ + ) - if not isinstance(prop, MapperProperty): - prop = self._property_from_column(key, prop) + if not isinstance(prop_arg, MapperProperty): + prop = self._property_from_column(key, prop_arg) + else: + prop = prop_arg if isinstance(prop, properties.ColumnProperty): col = self.persist_selectable.corresponding_column(prop.columns[0]) @@ -1950,18 +1964,23 @@ class Mapper( if self.configured: self._expire_memoizations() + return prop + @util.preload_module("sqlalchemy.orm.descriptor_props") - def _property_from_column(self, key, prop): + def _property_from_column( + self, + key: str, + prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]], + ) -> MapperProperty[Any]: """generate/update a :class:`.ColumnProperty` given a :class:`_schema.Column` object.""" descriptor_props = util.preloaded.orm_descriptor_props # we were passed a Column or a list of Columns; # generate a properties.ColumnProperty - columns = util.to_list(prop) + columns = util.to_list(prop_arg) column = columns[0] - assert isinstance(column, expression.ColumnElement) - prop = self._props.get(key, None) + prop = self._props.get(key) if isinstance(prop, properties.ColumnProperty): if ( @@ -2033,11 +2052,11 @@ class Mapper( "columns get mapped." % (key, self, column.key, prop) ) - def _check_configure(self): + def _check_configure(self) -> None: if self.registry._new_mappers: _configure_registries({self.registry}, cascade=True) - def _post_configure_properties(self): + def _post_configure_properties(self) -> None: """Call the ``init()`` method on all ``MapperProperties`` attached to this mapper. @@ -2068,7 +2087,9 @@ class Mapper( for key, value in dict_of_properties.items(): self.add_property(key, value) - def add_property(self, key, prop): + def add_property( + self, key: str, prop: Union[Column[Any], MapperProperty[Any]] + ) -> None: """Add an individual MapperProperty to this mapper. If the mapper has not been configured yet, just adds the @@ -2077,15 +2098,16 @@ class Mapper( the given MapperProperty is configured immediately. """ + prop = self._configure_property(key, prop, init=self.configured) + assert isinstance(prop, MapperProperty) self._init_properties[key] = prop - self._configure_property(key, prop, init=self.configured) - def _expire_memoizations(self): + def _expire_memoizations(self) -> None: for mapper in self.iterate_to_root(): mapper._reset_memoizations() @property - def _log_desc(self): + def _log_desc(self) -> str: return ( "(" + self.class_.__name__ @@ -2099,16 +2121,16 @@ class Mapper( + ")" ) - def _log(self, msg, *args): + def _log(self, msg: str, *args: Any) -> None: self.logger.info("%s " + msg, *((self._log_desc,) + args)) - def _log_debug(self, msg, *args): + def _log_debug(self, msg: str, *args: Any) -> None: self.logger.debug("%s " + msg, *((self._log_desc,) + args)) - def __repr__(self): + def __repr__(self) -> str: return "" % (id(self), self.class_.__name__) - def __str__(self): + def __str__(self) -> str: return "Mapper[%s%s(%s)]" % ( self.class_.__name__, self.non_primary and " (non-primary)" or "", @@ -2155,7 +2177,9 @@ class Mapper( "Mapper '%s' has no property '%s'" % (self, key) ) from err - def get_property_by_column(self, column): + def get_property_by_column( + self, column: ColumnElement[_T] + ) -> MapperProperty[_T]: """Given a :class:`_schema.Column` object, return the :class:`.MapperProperty` which maps this column.""" @@ -2795,7 +2819,7 @@ class Mapper( return result - def _is_userland_descriptor(self, assigned_name, obj): + def _is_userland_descriptor(self, assigned_name: str, obj: Any) -> bool: if isinstance( obj, ( @@ -3603,7 +3627,9 @@ def configure_mappers(): _configure_registries(_all_registries(), cascade=True) -def _configure_registries(registries, cascade): +def _configure_registries( + registries: Set[_RegistryType], cascade: bool +) -> None: for reg in registries: if reg._new_mappers: break @@ -3637,7 +3663,9 @@ def _configure_registries(registries, cascade): @util.preload_module("sqlalchemy.orm.decl_api") -def _do_configure_registries(registries, cascade): +def _do_configure_registries( + registries: Set[_RegistryType], cascade: bool +) -> None: registry = util.preloaded.orm_decl_api.registry @@ -3688,7 +3716,7 @@ def _do_configure_registries(registries, cascade): @util.preload_module("sqlalchemy.orm.decl_api") -def _dispose_registries(registries, cascade): +def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None: registry = util.preloaded.orm_decl_api.registry diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 361cea9757..36c14a6727 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from ..sql.cache_key import _CacheKeyTraversalType from ..sql.elements import BindParameter from ..sql.visitors import anon_map + from ..util.typing import _LiteralStar from ..util.typing import TypeGuard def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: @@ -80,7 +81,7 @@ def _unreduce_path(path: _SerializedPath) -> PathRegistry: return PathRegistry.deserialize(path) -_WILDCARD_TOKEN = "*" +_WILDCARD_TOKEN: _LiteralStar = "*" _DEFAULT_TOKEN = "_sa_default" @@ -115,6 +116,7 @@ class PathRegistry(HasCacheKey): is_token = False is_root = False has_entity = False + is_property = False is_entity = False path: _PathRepresentation @@ -175,7 +177,40 @@ class PathRegistry(HasCacheKey): def __hash__(self) -> int: return id(self) - def __getitem__(self, key: Any) -> PathRegistry: + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... + + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... + + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... + + @overload + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: + ... + + def __getitem__( + self, + entity: Union[ + str, int, slice, _InternalEntityType[Any], MapperProperty[Any] + ], + ) -> Union[ + TokenRegistry, + _PathElementType, + _PathRepresentation, + PropRegistry, + AbstractEntityRegistry, + ]: raise NotImplementedError() # TODO: what are we using this for? @@ -343,18 +378,8 @@ class RootRegistry(CreatesToken): is_root = True is_unnatural = False - @overload - def __getitem__(self, entity: str) -> TokenRegistry: - ... - - @overload - def __getitem__( - self, entity: _InternalEntityType[Any] - ) -> AbstractEntityRegistry: - ... - - def __getitem__( - self, entity: Union[str, _InternalEntityType[Any]] + def _getitem( + self, entity: Any ) -> Union[TokenRegistry, AbstractEntityRegistry]: if entity in PathToken._intern: if TYPE_CHECKING: @@ -368,6 +393,9 @@ class RootRegistry(CreatesToken): f"invalid argument for RootRegistry.__getitem__: {entity}" ) + if not TYPE_CHECKING: + __getitem__ = _getitem + PathRegistry.root = RootRegistry() @@ -441,12 +469,15 @@ class TokenRegistry(PathRegistry): else: yield self - def __getitem__(self, entity: Any) -> Any: + def _getitem(self, entity: Any) -> Any: try: return self.path[entity] except TypeError as err: raise IndexError(f"{entity}") from err + if not TYPE_CHECKING: + __getitem__ = _getitem + class PropRegistry(PathRegistry): __slots__ = ( @@ -463,6 +494,7 @@ class PropRegistry(PathRegistry): "is_unnatural", ) inherit_cache = True + is_property = True prop: MapperProperty[Any] mapper: Optional[Mapper[Any]] @@ -557,21 +589,7 @@ class PropRegistry(PathRegistry): assert self.entity is not None return self[self.entity] - @overload - def __getitem__(self, entity: slice) -> _PathRepresentation: - ... - - @overload - def __getitem__(self, entity: int) -> _PathElementType: - ... - - @overload - def __getitem__( - self, entity: _InternalEntityType[Any] - ) -> AbstractEntityRegistry: - ... - - def __getitem__( + def _getitem( self, entity: Union[int, slice, _InternalEntityType[Any]] ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]: if isinstance(entity, (int, slice)): @@ -579,6 +597,9 @@ class PropRegistry(PathRegistry): else: return SlotsEntityRegistry(self, entity) + if not TYPE_CHECKING: + __getitem__ = _getitem + class AbstractEntityRegistry(CreatesToken): __slots__ = ( @@ -642,6 +663,10 @@ class AbstractEntityRegistry(CreatesToken): # self.natural_path = parent.natural_path + (entity, ) self.natural_path = self.path + @property + def root_entity(self) -> _InternalEntityType[Any]: + return cast("_InternalEntityType[Any]", self.path[0]) + @property def entity_path(self) -> PathRegistry: return self @@ -653,23 +678,7 @@ class AbstractEntityRegistry(CreatesToken): def __bool__(self) -> bool: return True - @overload - def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: - ... - - @overload - def __getitem__(self, entity: str) -> TokenRegistry: - ... - - @overload - def __getitem__(self, entity: int) -> _PathElementType: - ... - - @overload - def __getitem__(self, entity: slice) -> _PathRepresentation: - ... - - def __getitem__( + def _getitem( self, entity: Any ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]: if isinstance(entity, (int, slice)): @@ -679,6 +688,9 @@ class AbstractEntityRegistry(CreatesToken): else: return PropRegistry(self, entity) + if not TYPE_CHECKING: + __getitem__ = _getitem + class SlotsEntityRegistry(AbstractEntityRegistry): # for aliased class, return lightweight, no-cycles created @@ -715,10 +727,28 @@ class CachingEntityRegistry(AbstractEntityRegistry): def pop(self, key: Any, default: Any) -> Any: return self._cache.pop(key, default) - def __getitem__(self, entity: Any) -> Any: + def _getitem(self, entity: Any) -> Any: if isinstance(entity, (int, slice)): return self.path[entity] elif isinstance(entity, PathToken): return TokenRegistry(self, entity) else: return self._cache[entity] + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +if TYPE_CHECKING: + + def path_is_entity( + path: PathRegistry, + ) -> TypeGuard[AbstractEntityRegistry]: + ... + + def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: + ... + +else: + path_is_entity = operator.attrgetter("is_entity") + path_is_property = operator.attrgetter("is_property") diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 0ca0559b45..911617d6d0 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -16,8 +16,10 @@ from __future__ import annotations from typing import Any from typing import cast +from typing import Dict from typing import List from typing import Optional +from typing import Sequence from typing import Set from typing import Type from typing import TYPE_CHECKING @@ -25,7 +27,6 @@ from typing import TypeVar from . import attributes from . import strategy_options -from .base import SQLCoreOperations from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty from .descriptor_props import Synonym @@ -44,20 +45,34 @@ from .. import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes +from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import NoneType +from ..util.typing import Self if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .mapper import Mapper + from .session import Session + from .state import _InstallLoaderCallableProto + from .state import InstanceState from ..sql._typing import _InfoType - from ..sql.elements import KeyedColumnElement + from ..sql.elements import ColumnElement + from ..sql.elements import NamedColumn + from ..sql.operators import OperatorType + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) +_NC = TypeVar("_NC", bound="NamedColumn[Any]") __all__ = [ "ColumnProperty", @@ -85,11 +100,15 @@ class ColumnProperty( inherit_cache = True _links_to_entity = False - columns: List[KeyedColumnElement[Any]] - _orig_columns: List[KeyedColumnElement[Any]] + columns: List[NamedColumn[Any]] + _orig_columns: List[NamedColumn[Any]] _is_polymorphic_discriminator: bool + _mapped_by_synonym: Optional[str] + + comparator_factory: Type[PropComparator[_T]] + __slots__ = ( "_orig_columns", "columns", @@ -100,7 +119,6 @@ class ColumnProperty( "descriptor", "active_history", "expire_on_flush", - "doc", "_creation_order", "_is_polymorphic_discriminator", "_mapped_by_synonym", @@ -117,7 +135,7 @@ class ColumnProperty( group: Optional[str] = None, deferred: bool = False, raiseload: bool = False, - comparator_factory: Optional[Type[PropComparator]] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, descriptor: Optional[Any] = None, active_history: bool = False, expire_on_flush: bool = True, @@ -150,7 +168,7 @@ class ColumnProperty( self.expire_on_flush = expire_on_flush if info is not None: - self.info = info + self.info.update(info) if doc is not None: self.doc = doc @@ -173,8 +191,13 @@ class ColumnProperty( self.strategy_key += (("raiseload", True),) def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: column = self.columns[0] if column.key is None: column.key = key @@ -186,20 +209,23 @@ class ColumnProperty( return self @property - def columns_to_assign(self) -> List[Column]: + def columns_to_assign(self) -> List[Column[Any]]: + # mypy doesn't care about the isinstance here return [ - c + c # type: ignore for c in self.columns if isinstance(c, Column) and c.table is None ] - def _memoized_attr__renders_in_subqueries(self): + def _memoized_attr__renders_in_subqueries(self) -> bool: return ("deferred", True) not in self.strategy_key or ( - self not in self.parent._readonly_props + self not in self.parent._readonly_props # type: ignore ) @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") - def _memoized_attr__deferred_column_loader(self): + def _memoized_attr__deferred_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: state = util.preloaded.orm_state strategies = util.preloaded.orm_strategies return state.InstanceState._instance_level_callable_processor( @@ -209,7 +235,9 @@ class ColumnProperty( ) @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") - def _memoized_attr__raise_column_loader(self): + def _memoized_attr__raise_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: state = util.preloaded.orm_state strategies = util.preloaded.orm_strategies return state.InstanceState._instance_level_callable_processor( @@ -218,7 +246,7 @@ class ColumnProperty( self.key, ) - def __clause_element__(self): + def __clause_element__(self) -> roles.ColumnsClauseRole: """Allow the ColumnProperty to work in expression before it is turned into an instrumented attribute. """ @@ -226,7 +254,7 @@ class ColumnProperty( return self.expression @property - def expression(self): + def expression(self) -> roles.ColumnsClauseRole: """Return the primary column or expression for this ColumnProperty. E.g.:: @@ -247,7 +275,7 @@ class ColumnProperty( """ return self.columns[0] - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: if not self.instrument: return @@ -259,7 +287,7 @@ class ColumnProperty( doc=self.doc, ) - def do_init(self): + def do_init(self) -> None: super().do_init() if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( @@ -275,32 +303,25 @@ class ColumnProperty( % (self.parent, self.columns[1], self.columns[0], self.key) ) - def copy(self): + def copy(self) -> ColumnProperty[_T]: return ColumnProperty( + *self.columns, deferred=self.deferred, group=self.group, active_history=self.active_history, - *self.columns, - ) - - def _getcommitted( - self, state, dict_, column, passive=attributes.PASSIVE_OFF - ): - return state.get_impl(self.key).get_committed_value( - state, dict_, passive=passive ) def merge( self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, - _recursive, - _resolve_conflict_map, - ): + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: if not self.instrument: return elif self.key in source_dict: @@ -335,9 +356,13 @@ class ColumnProperty( """ - __slots__ = "__clause_element__", "info", "expressions" + if not TYPE_CHECKING: + # prevent pylance from being clever about slots + __slots__ = "__clause_element__", "info", "expressions" + + prop: RODescriptorReference[ColumnProperty[_PT]] - def _orm_annotate_column(self, column): + def _orm_annotate_column(self, column: _NC) -> _NC: """annotate and possibly adapt a column to be returned as the mapped-attribute exposed version of the column. @@ -351,7 +376,7 @@ class ColumnProperty( """ pe = self._parententity - annotations = { + annotations: Dict[str, Any] = { "entity_namespace": pe, "parententity": pe, "parentmapper": pe, @@ -377,22 +402,29 @@ class ColumnProperty( {"compile_state_plugin": "orm", "plugin_subject": pe} ) - def _memoized_method___clause_element__(self): + if TYPE_CHECKING: + + def __clause_element__(self) -> NamedColumn[_PT]: + ... + + def _memoized_method___clause_element__( + self, + ) -> NamedColumn[_PT]: if self.adapter: return self.adapter(self.prop.columns[0], self.prop.key) else: return self._orm_annotate_column(self.prop.columns[0]) - def _memoized_attr_info(self): + def _memoized_attr_info(self) -> _InfoType: """The .info dictionary for this attribute.""" ce = self.__clause_element__() try: - return ce.info + return ce.info # type: ignore except AttributeError: return self.prop.info - def _memoized_attr_expressions(self): + def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]: """The full sequence of columns referenced by this attribute, adjusted for any aliasing in progress. @@ -409,21 +441,25 @@ class ColumnProperty( self._orm_annotate_column(col) for col in self.prop.columns ] - def _fallback_getattr(self, key): + def _fallback_getattr(self, key: str) -> Any: """proxy attribute access down to the mapped column. this allows user-defined comparison methods to be accessed. """ return getattr(self.__clause_element__(), key) - def operate(self, op, *other, **kwargs): - return op(self.__clause_element__(), *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 - def reverse_operate(self, op, other, **kwargs): + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: col = self.__clause_element__() - return op(col._bind_param(op, other), col, **kwargs) + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 - def __str__(self): + def __str__(self) -> str: if not self.parent or not self.key: return object.__repr__(self) return str(self.parent.class_.__name__) + "." + self.key @@ -460,7 +496,7 @@ class MappedColumn( column: Column[_T] foreign_keys: Optional[Set[ForeignKey]] - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any): self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys @@ -470,8 +506,8 @@ class MappedColumn( ) util.set_creation_order(self) - def _copy(self, **kw): - new = self.__class__.__new__(self.__class__) + def _copy(self: Self, **kw: Any) -> Self: + new = cast(Self, self.__class__.__new__(self.__class__)) new.column = self.column._copy(**kw) new.deferred = self.deferred new.foreign_keys = new.column.foreign_keys @@ -487,22 +523,31 @@ class MappedColumn( return None @property - def columns_to_assign(self) -> List[Column]: + def columns_to_assign(self) -> List[Column[Any]]: return [self.column] - def __clause_element__(self): + def __clause_element__(self) -> Column[_T]: return self.column - def operate(self, op, *other, **kwargs): - return op(self.__clause_element__(), *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 - def reverse_operate(self, op, other, **kwargs): + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: col = self.__clause_element__() - return op(col._bind_param(op, other), col, **kwargs) + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: column = self.column if column.key is None: column.key = key @@ -526,38 +571,48 @@ class MappedColumn( @util.preload_module("sqlalchemy.orm.decl_base") def declarative_scan_for_composite( - self, registry, cls, key, param_name, param_annotation - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + param_name: str, + param_annotation: _AnnotationScanType, + ) -> None: decl_base = util.preloaded.orm_decl_base decl_base._undefer_column_name(param_name, self.column) self._init_column_for_annotation(cls, registry, param_annotation) - def _init_column_for_annotation(self, cls, registry, argument): + def _init_column_for_annotation( + self, + cls: Type[Any], + registry: _RegistryType, + argument: _AnnotationScanType, + ) -> None: sqltype = self.column.type nullable = False if hasattr(argument, "__origin__"): - nullable = NoneType in argument.__args__ + nullable = NoneType in argument.__args__ # type: ignore if not self._has_nullable: self.column.nullable = nullable if sqltype._isnull and not self.column.foreign_keys: - sqltype = None + new_sqltype = None our_type = de_optionalize_union_types(argument) if is_fwd_ref(our_type): our_type = de_stringify_annotation(cls, our_type) if registry.type_annotation_map: - sqltype = registry.type_annotation_map.get(our_type) - if sqltype is None: - sqltype = sqltypes._type_map_get(our_type) + new_sqltype = registry.type_annotation_map.get(our_type) + if new_sqltype is None: + new_sqltype = sqltypes._type_map_get(our_type) # type: ignore - if sqltype is None: + if new_sqltype is None: raise sa_exc.ArgumentError( f"Could not locate SQLAlchemy Core " f"type for Python type: {our_type}" ) - self.column.type = sqltype + self.column.type = new_sqltype # type: ignore diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a60a167ac8..419891708c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -23,18 +23,23 @@ from __future__ import annotations import collections.abc as collections_abc import operator from typing import Any +from typing import Callable +from typing import cast +from typing import Dict from typing import Generic from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from . import exc as orm_exc +from . import attributes from . import interfaces from . import loading from . import util as orm_util @@ -44,7 +49,6 @@ from .context import _column_descriptions from .context import _determine_last_joined_entity from .context import _legacy_filter_by_entity_zero from .context import FromStatement -from .context import LABEL_STYLE_LEGACY_ORM from .context import ORMCompileState from .context import QueryContext from .interfaces import ORMColumnDescription @@ -60,6 +64,8 @@ from .. import sql from .. import util from ..engine import Result from ..engine import Row +from ..event import dispatcher +from ..event import EventTarget from ..sql import coercions from ..sql import expression from ..sql import roles @@ -71,8 +77,10 @@ from ..sql._typing import _TP from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative +from ..sql.base import _NoArg from ..sql.base import Executable from ..sql.base import Generative +from ..sql.elements import BooleanClauseList from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements @@ -81,17 +89,31 @@ from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectLabelStyle from ..util.typing import Literal +from ..util.typing import Self if TYPE_CHECKING: from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InternalEntityType + from .mapper import Mapper + from .path_registry import PathRegistry + from .session import _PKIdentityArgument from .session import Session + from .state import InstanceState + from ..engine.cursor import CursorResult + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.result import FrozenResult from ..engine.result import ScalarResult from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _JoinTargetArgument from ..sql._typing import _MAYBE_ENTITY from ..sql._typing import _no_kw from ..sql._typing import _NOT_ENTITY + from ..sql._typing import _OnClauseArgument from ..sql._typing import _PropagateAttrsType from ..sql._typing import _T0 from ..sql._typing import _T1 @@ -102,18 +124,25 @@ if TYPE_CHECKING: from ..sql._typing import _T6 from ..sql._typing import _T7 from ..sql._typing import _TypedColumnClauseArgument as _TCCA - from ..sql.roles import TypedColumnsClauseRole + from ..sql.base import CacheableOptions + from ..sql.base import ExecutableOption + from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.selectable import _JoinTargetElement from ..sql.selectable import _SetupJoinsElement from ..sql.selectable import Alias + from ..sql.selectable import CTE from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import FromClause from ..sql.selectable import ScalarSelect from ..sql.selectable import Subquery + __all__ = ["Query", "QueryContext"] _T = TypeVar("_T", bound=Any) -SelfQuery = TypeVar("SelfQuery", bound="Query") +SelfQuery = TypeVar("SelfQuery", bound="Query[Any]") @inspection._self_inspects @@ -124,6 +153,7 @@ class Query( HasPrefixes, HasSuffixes, HasHints, + EventTarget, log.Identified, Generative, Executable, @@ -150,40 +180,47 @@ class Query( """ # elements that are in Core and can be cached in the same way - _where_criteria = () - _having_criteria = () + _where_criteria: Tuple[ColumnElement[Any], ...] = () + _having_criteria: Tuple[ColumnElement[Any], ...] = () - _order_by_clauses = () - _group_by_clauses = () - _limit_clause = None - _offset_clause = None + _order_by_clauses: Tuple[ColumnElement[Any], ...] = () + _group_by_clauses: Tuple[ColumnElement[Any], ...] = () + _limit_clause: Optional[ColumnElement[Any]] = None + _offset_clause: Optional[ColumnElement[Any]] = None - _distinct = False - _distinct_on = () + _distinct: bool = False + _distinct_on: Tuple[ColumnElement[Any], ...] = () - _for_update_arg = None - _correlate = () - _auto_correlate = True - _from_obj = () + _for_update_arg: Optional[ForUpdateArg] = None + _correlate: Tuple[FromClause, ...] = () + _auto_correlate: bool = True + _from_obj: Tuple[FromClause, ...] = () _setup_joins: Tuple[_SetupJoinsElement, ...] = () - _label_style = LABEL_STYLE_LEGACY_ORM + _label_style: SelectLabelStyle = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM _memoized_select_entities = () - _compile_options = ORMCompileState.default_compile_options + _compile_options: Union[ + Type[CacheableOptions], CacheableOptions + ] = ORMCompileState.default_compile_options + _with_options: Tuple[ExecutableOption, ...] load_options = QueryContext.default_load_options + { "_legacy_uniquing": True } - _params = util.EMPTY_DICT + _params: util.immutabledict[str, Any] = util.EMPTY_DICT # local Query builder state, not needed for # compilation or execution _enable_assertions = True - _statement = None + _statement: Optional[ExecutableReturnsRows] = None + + session: Session + + dispatch: dispatcher[Query[_T]] # mirrors that of ClauseElement, used to propagate the "orm" # plugin as well as the "subject" of the plugin, e.g. the mapper @@ -224,14 +261,23 @@ class Query( """ - self.session = session + # session is usually present. There's one case in subqueryloader + # where it stores a Query without a Session and also there are tests + # for the query(Entity).with_session(session) API which is likely in + # some old recipes, however these are legacy as select() can now be + # used. + self.session = session # type: ignore self._set_entities(entities) - def _set_propagate_attrs(self, values): - self._propagate_attrs = util.immutabledict(values) + def _set_propagate_attrs( + self: SelfQuery, values: Mapping[str, Any] + ) -> SelfQuery: + self._propagate_attrs = util.immutabledict(values) # type: ignore return self - def _set_entities(self, entities): + def _set_entities( + self, entities: Iterable[_ColumnsClauseArgument[Any]] + ) -> None: self._raw_columns = [ coercions.expect( roles.ColumnsClauseRole, @@ -242,15 +288,7 @@ class Query( for ent in util.to_list(entities) ] - @overload - def tuples(self: Query[Row[_TP]]) -> Query[_TP]: - ... - - @overload def tuples(self: Query[_O]) -> Query[Tuple[_O]]: - ... - - def tuples(self) -> Query[Any]: """return a tuple-typed form of this :class:`.Query`. This method invokes the :meth:`.Query.only_return_tuples` @@ -270,29 +308,27 @@ class Query( .. versionadded:: 2.0 """ - return self.only_return_tuples(True) + return self.only_return_tuples(True) # type: ignore - def _entity_from_pre_ent_zero(self): + def _entity_from_pre_ent_zero(self) -> Optional[_InternalEntityType[Any]]: if not self._raw_columns: return None ent = self._raw_columns[0] if "parententity" in ent._annotations: - return ent._annotations["parententity"] - elif isinstance(ent, ORMColumnsClauseRole): - return ent.entity + return ent._annotations["parententity"] # type: ignore elif "bundle" in ent._annotations: - return ent._annotations["bundle"] + return ent._annotations["bundle"] # type: ignore else: # label, other SQL expression for element in visitors.iterate(ent): if "parententity" in element._annotations: - return element._annotations["parententity"] + return element._annotations["parententity"] # type: ignore # noqa: E501 else: return None - def _only_full_mapper_zero(self, methname): + def _only_full_mapper_zero(self, methname: str) -> Mapper[Any]: if ( len(self._raw_columns) != 1 or "parententity" not in self._raw_columns[0]._annotations @@ -303,9 +339,11 @@ class Query( "a single mapped class." % methname ) - return self._raw_columns[0]._annotations["parententity"] + return self._raw_columns[0]._annotations["parententity"] # type: ignore # noqa: E501 - def _set_select_from(self, obj, set_base_alias): + def _set_select_from( + self, obj: Iterable[_FromClauseArgument], set_base_alias: bool + ) -> None: fa = [ coercions.expect( roles.StrictFromClauseRole, @@ -320,19 +358,22 @@ class Query( self._from_obj = tuple(fa) @_generative - def _set_lazyload_from(self: SelfQuery, state) -> SelfQuery: + def _set_lazyload_from( + self: SelfQuery, state: InstanceState[Any] + ) -> SelfQuery: self.load_options += {"_lazy_loaded_from": state} return self - def _get_condition(self): - return self._no_criterion_condition( - "get", order_by=False, distinct=False - ) + def _get_condition(self) -> None: + """used by legacy BakedQuery""" + self._no_criterion_condition("get", order_by=False, distinct=False) - def _get_existing_condition(self): + def _get_existing_condition(self) -> None: self._no_criterion_assertion("get", order_by=False, distinct=False) - def _no_criterion_assertion(self, meth, order_by=True, distinct=True): + def _no_criterion_assertion( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: if not self._enable_assertions: return if ( @@ -351,7 +392,9 @@ class Query( "Query with existing criterion. " % meth ) - def _no_criterion_condition(self, meth, order_by=True, distinct=True): + def _no_criterion_condition( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: self._no_criterion_assertion(meth, order_by, distinct) self._from_obj = self._setup_joins = () @@ -362,7 +405,7 @@ class Query( self._order_by_clauses = self._group_by_clauses = () - def _no_clauseelement_condition(self, meth): + def _no_clauseelement_condition(self, meth: str) -> None: if not self._enable_assertions: return if self._order_by_clauses: @@ -372,7 +415,7 @@ class Query( ) self._no_criterion_condition(meth) - def _no_statement_condition(self, meth): + def _no_statement_condition(self, meth: str) -> None: if not self._enable_assertions: return if self._statement is not None: @@ -384,7 +427,7 @@ class Query( % meth ) - def _no_limit_offset(self, meth): + def _no_limit_offset(self, meth: str) -> None: if not self._enable_assertions: return if self._limit_clause is not None or self._offset_clause is not None: @@ -395,21 +438,21 @@ class Query( ) @property - def _has_row_limiting_clause(self): + def _has_row_limiting_clause(self) -> bool: return ( self._limit_clause is not None or self._offset_clause is not None ) def _get_options( - self, - populate_existing=None, - version_check=None, - only_load_props=None, - refresh_state=None, - identity_token=None, - ): - load_options = {} - compile_options = {} + self: SelfQuery, + populate_existing: Optional[bool] = None, + version_check: Optional[bool] = None, + only_load_props: Optional[Sequence[str]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + identity_token: Optional[Any] = None, + ) -> SelfQuery: + load_options: Dict[str, Any] = {} + compile_options: Dict[str, Any] = {} if version_check: load_options["_version_check"] = version_check @@ -430,11 +473,18 @@ class Query( return self - def _clone(self): - return self._generate() + def _clone(self: Self, **kw: Any) -> Self: + return self._generate() # type: ignore + + def _get_select_statement_only(self) -> Select[_T]: + if self._statement is not None: + raise sa_exc.InvalidRequestError( + "Can't call this method on a Query that uses from_statement()" + ) + return cast("Select[_T]", self.statement) @property - def statement(self): + def statement(self) -> Union[Select[_T], FromStatement[_T]]: """The full SELECT statement represented by this Query. The statement by default will not have disambiguating labels @@ -474,14 +524,15 @@ class Query( return stmt - def _final_statement(self, legacy_query_style=True): + def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]: """Return the 'final' SELECT statement for this :class:`.Query`. + This is used by the testing suite only and is fairly inefficient. + This is the Core-only select() that will be rendered by a complete compilation of this query, and is what .statement used to return in 1.3. - This method creates a complete compile state so is fairly expensive. """ @@ -489,9 +540,11 @@ class Query( return q._compile_state( use_legacy_query_style=legacy_query_style - ).statement + ).statement # type: ignore - def _statement_20(self, for_statement=False, use_legacy_query_style=True): + def _statement_20( + self, for_statement: bool = False, use_legacy_query_style: bool = True + ) -> Union[Select[_T], FromStatement[_T]]: # TODO: this event needs to be deprecated, as it currently applies # only to ORM query and occurs at this spot that is now more # or less an artificial spot @@ -500,7 +553,7 @@ class Query( new_query = fn(self) if new_query is not None and new_query is not self: self = new_query - if not fn._bake_ok: + if not fn._bake_ok: # type: ignore self._compile_options += {"_bake_ok": False} compile_options = self._compile_options @@ -509,6 +562,8 @@ class Query( "_use_legacy_query_style": use_legacy_query_style, } + stmt: Union[Select[_T], FromStatement[_T]] + if self._statement is not None: stmt = FromStatement(self._raw_columns, self._statement) stmt.__dict__.update( @@ -541,10 +596,10 @@ class Query( def subquery( self, - name=None, - with_labels=False, - reduce_columns=False, - ): + name: Optional[str] = None, + with_labels: bool = False, + reduce_columns: bool = False, + ) -> Subquery: """Return the full SELECT statement represented by this :class:`_query.Query`, embedded within an :class:`_expression.Alias`. @@ -571,13 +626,21 @@ class Query( if with_labels: q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) - q = q.statement + stmt = q._get_select_statement_only() + + if TYPE_CHECKING: + assert isinstance(stmt, Select) if reduce_columns: - q = q.reduce_columns() - return q.alias(name=name) + stmt = stmt.reduce_columns() + return stmt.subquery(name=name) - def cte(self, name=None, recursive=False, nesting=False): + def cte( + self, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + ) -> CTE: r"""Return the full SELECT statement represented by this :class:`_query.Query` represented as a common table expression (CTE). @@ -632,11 +695,13 @@ class Query( :meth:`_expression.HasCTE.cte` """ - return self.enable_eagerloads(False).statement.cte( - name=name, recursive=recursive, nesting=nesting + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .cte(name=name, recursive=recursive, nesting=nesting) ) - def label(self, name): + def label(self, name: Optional[str]) -> Label[Any]: """Return the full SELECT statement represented by this :class:`_query.Query`, converted to a scalar subquery with a label of the given name. @@ -645,7 +710,11 @@ class Query( """ - return self.enable_eagerloads(False).statement.label(name) + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .label(name) + ) @overload def as_scalar( @@ -704,10 +773,14 @@ class Query( """ - return self.enable_eagerloads(False).statement.scalar_subquery() + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .scalar_subquery() + ) @property - def selectable(self): + def selectable(self) -> Union[Select[_T], FromStatement[_T]]: """Return the :class:`_expression.Select` object emitted by this :class:`_query.Query`. @@ -718,7 +791,7 @@ class Query( """ return self.__clause_element__() - def __clause_element__(self): + def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: return ( self._with_compile_options( _enable_eagerloads=False, _render_for_subquery=True @@ -759,7 +832,7 @@ class Query( return self @property - def is_single_entity(self): + def is_single_entity(self) -> bool: """Indicates if this :class:`_query.Query` returns tuples or single entities. @@ -785,7 +858,7 @@ class Query( ) @_generative - def enable_eagerloads(self: SelfQuery, value) -> SelfQuery: + def enable_eagerloads(self: SelfQuery, value: bool) -> SelfQuery: """Control whether or not eager joins and subqueries are rendered. @@ -804,7 +877,7 @@ class Query( return self @_generative - def _with_compile_options(self: SelfQuery, **opt) -> SelfQuery: + def _with_compile_options(self: SelfQuery, **opt: Any) -> SelfQuery: self._compile_options += opt return self @@ -813,13 +886,15 @@ class Query( alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) " "instead.", ) - def with_labels(self): - return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + def with_labels(self: SelfQuery) -> SelfQuery: + return self.set_label_style( + SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL + ) apply_labels = with_labels @property - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: """ Retrieve the current label style. @@ -828,7 +903,7 @@ class Query( """ return self._label_style - def set_label_style(self, style): + def set_label_style(self: SelfQuery, style: SelectLabelStyle) -> SelfQuery: """Apply column labels to the return value of Query.statement. Indicates that this Query's `statement` accessor should return @@ -864,7 +939,7 @@ class Query( return self @_generative - def enable_assertions(self: SelfQuery, value) -> SelfQuery: + def enable_assertions(self: SelfQuery, value: bool) -> SelfQuery: """Control whether assertions are generated. When set to False, the returned Query will @@ -887,7 +962,7 @@ class Query( return self @property - def whereclause(self): + def whereclause(self) -> Optional[ColumnElement[bool]]: """A readonly attribute which returns the current WHERE criterion for this Query. @@ -895,12 +970,12 @@ class Query( criterion has been established. """ - return sql.elements.BooleanClauseList._construct_for_whereclause( + return BooleanClauseList._construct_for_whereclause( self._where_criteria ) @_generative - def _with_current_path(self: SelfQuery, path) -> SelfQuery: + def _with_current_path(self: SelfQuery, path: PathRegistry) -> SelfQuery: """indicate that this query applies to objects loaded within a certain path. @@ -913,7 +988,7 @@ class Query( return self @_generative - def yield_per(self: SelfQuery, count) -> SelfQuery: + def yield_per(self: SelfQuery, count: int) -> SelfQuery: r"""Yield only ``count`` rows at a time. The purpose of this method is when fetching very large result sets @@ -938,7 +1013,7 @@ class Query( ":meth:`_orm.Query.get`", alternative="The method is now available as :meth:`_orm.Session.get`", ) - def get(self, ident): + def get(self, ident: _PKIdentityArgument) -> Optional[Any]: """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -1022,7 +1097,12 @@ class Query( # it return self._get_impl(ident, loading.load_on_pk_identity) - def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl( + self, + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., Any], + identity_token: Optional[Any] = None, + ) -> Optional[Any]: mapper = self._only_full_mapper_zero("get") return self.session._get_impl( mapper, @@ -1036,7 +1116,7 @@ class Query( ) @property - def lazy_loaded_from(self): + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: """An :class:`.InstanceState` that is using this :class:`_query.Query` for a lazy load operation. @@ -1050,14 +1130,17 @@ class Query( :attr:`.ORMExecuteState.lazy_loaded_from` """ - return self.load_options._lazy_loaded_from + return self.load_options._lazy_loaded_from # type: ignore @property - def _current_path(self): - return self._compile_options._current_path + def _current_path(self) -> PathRegistry: + return self._compile_options._current_path # type: ignore @_generative - def correlate(self: SelfQuery, *fromclauses) -> SelfQuery: + def correlate( + self: SelfQuery, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfQuery: """Return a :class:`.Query` construct which will correlate the given FROM clauses to that of an enclosing :class:`.Query` or :func:`~.expression.select`. @@ -1082,13 +1165,13 @@ class Query( if fromclauses and fromclauses[0] in {None, False}: self._correlate = () else: - self._correlate = set(self._correlate).union( + self._correlate = self._correlate + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) return self @_generative - def autoflush(self: SelfQuery, setting) -> SelfQuery: + def autoflush(self: SelfQuery, setting: bool) -> SelfQuery: """Return a Query with a specific 'autoflush' setting. As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method @@ -1116,7 +1199,7 @@ class Query( return self @_generative - def _with_invoke_all_eagers(self: SelfQuery, value) -> SelfQuery: + def _with_invoke_all_eagers(self: SelfQuery, value: bool) -> SelfQuery: """Set the 'invoke all eagers' flag which causes joined- and subquery loaders to traverse into already-loaded related objects and collections. @@ -1132,7 +1215,14 @@ class Query( alternative="Use the :func:`_orm.with_parent` standalone construct.", ) @util.preload_module("sqlalchemy.orm.relationships") - def with_parent(self, instance, property=None, from_entity=None): # noqa + def with_parent( + self: SelfQuery, + instance: object, + property: Optional[ # noqa: A002 + attributes.QueryableAttribute[Any] + ] = None, + from_entity: Optional[_ExternalEntityType[Any]] = None, + ) -> SelfQuery: """Add filtering criterion that relates the given instance to a child object or collection, using its attribute state as well as an established :func:`_orm.relationship()` @@ -1150,7 +1240,7 @@ class Query( An instance which has some :func:`_orm.relationship`. :param property: - String property name, or class-bound attribute, which indicates + Class bound attribute which indicates what relationship from the instance should be used to reconcile the parent/child relationship. @@ -1172,21 +1262,27 @@ class Query( for prop in mapper.iterate_properties: if ( isinstance(prop, relationships.Relationship) - and prop.mapper is entity_zero.mapper + and prop.mapper is entity_zero.mapper # type: ignore ): - property = prop # noqa + property = prop # type: ignore # noqa: A001 break else: raise sa_exc.InvalidRequestError( "Could not locate a property which relates instances " "of class '%s' to instances of class '%s'" % ( - entity_zero.mapper.class_.__name__, + entity_zero.mapper.class_.__name__, # type: ignore instance.__class__.__name__, ) ) - return self.filter(with_parent(instance, property, entity_zero.entity)) + return self.filter( + with_parent( + instance, + property, # type: ignore + entity_zero.entity, # type: ignore + ) + ) @_generative def add_entity( @@ -1211,7 +1307,7 @@ class Query( return self @_generative - def with_session(self: SelfQuery, session) -> SelfQuery: + def with_session(self: SelfQuery, session: Session) -> SelfQuery: """Return a :class:`_query.Query` that will use the given :class:`.Session`. @@ -1237,7 +1333,9 @@ class Query( self.session = session return self - def _legacy_from_self(self, *entities): + def _legacy_from_self( + self: SelfQuery, *entities: _ColumnsClauseArgument[Any] + ) -> SelfQuery: # used for query.count() as well as for the same # function in BakedQuery, as well as some old tests in test_baked.py. @@ -1255,13 +1353,13 @@ class Query( return q @_generative - def _set_enable_single_crit(self: SelfQuery, val) -> SelfQuery: + def _set_enable_single_crit(self: SelfQuery, val: bool) -> SelfQuery: self._compile_options += {"_enable_single_crit": val} return self @_generative def _from_selectable( - self: SelfQuery, fromclause, set_entity_from=True + self: SelfQuery, fromclause: FromClause, set_entity_from: bool = True ) -> SelfQuery: for attr in ( "_where_criteria", @@ -1292,7 +1390,7 @@ class Query( "is deprecated and will be removed in a " "future release. Please use :meth:`_query.Query.with_entities`", ) - def values(self, *columns): + def values(self, *columns: _ColumnsClauseArgument[Any]) -> Iterable[Any]: """Return an iterator yielding result tuples corresponding to the given list of columns @@ -1304,7 +1402,7 @@ class Query( q._set_entities(columns) if not q.load_options._yield_per: q.load_options += {"_yield_per": 10} - return iter(q) + return iter(q) # type: ignore _values = values @@ -1315,25 +1413,24 @@ class Query( "future release. Please use :meth:`_query.Query.with_entities` " "in combination with :meth:`_query.Query.scalar`", ) - def value(self, column): + def value(self, column: _ColumnExpressionArgument[Any]) -> Any: """Return a scalar result corresponding to the given column expression. """ try: - return next(self.values(column))[0] + return next(self.values(column))[0] # type: ignore except StopIteration: return None @overload - def with_entities( - self, _entity: _EntityType[_O], **kwargs: Any - ) -> Query[_O]: + def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def with_entities( - self, _colexpr: TypedColumnsClauseRole[_T] + self, + _colexpr: roles.TypedColumnsClauseRole[_T], ) -> RowReturningQuery[Tuple[_T]]: ... @@ -1418,14 +1515,14 @@ class Query( @overload def with_entities( - self: SelfQuery, *entities: _ColumnsClauseArgument[Any] - ) -> SelfQuery: + self, *entities: _ColumnsClauseArgument[Any] + ) -> Query[Any]: ... @_generative def with_entities( - self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any - ) -> SelfQuery: + self, *entities: _ColumnsClauseArgument[Any], **__kw: Any + ) -> Query[Any]: r"""Return a new :class:`_query.Query` replacing the SELECT list with the given entities. @@ -1451,12 +1548,18 @@ class Query( """ if __kw: raise _no_kw() - _MemoizedSelectEntities._generate_for_statement(self) + + # Query has all the same fields as Select for this operation + # this could in theory be based on a protocol but not sure if it's + # worth it + _MemoizedSelectEntities._generate_for_statement(self) # type: ignore self._set_entities(entities) return self @_generative - def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]: + def add_columns( + self, *column: _ColumnExpressionArgument[Any] + ) -> Query[Any]: """Add one or more column expressions to the list of result columns to be returned.""" @@ -1479,7 +1582,7 @@ class Query( "is deprecated and will be removed in a " "future release. Please use :meth:`_query.Query.add_columns`", ) - def add_column(self, column) -> Query[Any]: + def add_column(self, column: _ColumnExpressionArgument[Any]) -> Query[Any]: """Add a column expression to the list of result columns to be returned. @@ -1487,7 +1590,7 @@ class Query( return self.add_columns(column) @_generative - def options(self: SelfQuery, *args) -> SelfQuery: + def options(self: SelfQuery, *args: ExecutableOption) -> SelfQuery: """Return a new :class:`_query.Query` object, applying the given list of mapper options. @@ -1505,18 +1608,21 @@ class Query( opts = tuple(util.flatten_iterator(args)) if self._compile_options._current_path: + # opting for lower method overhead for the checks for opt in opts: - if opt._is_legacy_option: - opt.process_query_conditionally(self) + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query_conditionally(self) # type: ignore else: for opt in opts: - if opt._is_legacy_option: - opt.process_query(self) + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query(self) # type: ignore self._with_options += opts return self - def with_transformation(self, fn): + def with_transformation( + self, fn: Callable[[Query[Any]], Query[Any]] + ) -> Query[Any]: """Return a new :class:`_query.Query` object transformed by the given function. @@ -1535,7 +1641,7 @@ class Query( """ return fn(self) - def get_execution_options(self): + def get_execution_options(self) -> _ImmutableExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded:: 1.3 @@ -1547,7 +1653,7 @@ class Query( return self._execution_options @_generative - def execution_options(self: SelfQuery, **kwargs) -> SelfQuery: + def execution_options(self: SelfQuery, **kwargs: Any) -> SelfQuery: """Set non-SQL options which take effect during execution. Options allowed here include all of those accepted by @@ -1596,11 +1702,17 @@ class Query( @_generative def with_for_update( self: SelfQuery, - read=False, - nowait=False, - of=None, - skip_locked=False, - key_share=False, + *, + nowait: bool = False, + read: bool = False, + of: Optional[ + Union[ + _ColumnExpressionArgument[Any], + Sequence[_ColumnExpressionArgument[Any]], + ] + ] = None, + skip_locked: bool = False, + key_share: bool = False, ) -> SelfQuery: """return a new :class:`_query.Query` with the specified options for the @@ -1659,7 +1771,9 @@ class Query( return self @_generative - def params(self: SelfQuery, *args, **kwargs) -> SelfQuery: + def params( + self: SelfQuery, __params: Optional[Dict[str, Any]] = None, **kw: Any + ) -> SelfQuery: r"""Add values for bind parameters which may have been specified in filter(). @@ -1669,17 +1783,14 @@ class Query( contain unicode keys in which case \**kwargs cannot be used. """ - if len(args) == 1: - kwargs.update(args[0]) - elif len(args) > 0: - raise sa_exc.ArgumentError( - "params() takes zero or one positional argument, " - "which is a dictionary." - ) - self._params = self._params.union(kwargs) + if __params: + kw.update(__params) + self._params = self._params.union(kw) return self - def where(self: SelfQuery, *criterion) -> SelfQuery: + def where( + self: SelfQuery, *criterion: _ColumnExpressionArgument[bool] + ) -> SelfQuery: """A synonym for :meth:`.Query.filter`. .. versionadded:: 1.4 @@ -1716,16 +1827,18 @@ class Query( :meth:`_query.Query.filter_by` - filter on keyword expressions. """ - for criterion in list(criterion): - criterion = coercions.expect( - roles.WhereHavingRole, criterion, apply_propagate_attrs=self + for crit in list(criterion): + crit = coercions.expect( + roles.WhereHavingRole, crit, apply_propagate_attrs=self ) - self._where_criteria += (criterion,) + self._where_criteria += (crit,) return self @util.memoized_property - def _last_joined_entity(self): + def _last_joined_entity( + self, + ) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: if self._setup_joins: return _determine_last_joined_entity( self._setup_joins, @@ -1733,7 +1846,7 @@ class Query( else: return None - def _filter_by_zero(self): + def _filter_by_zero(self) -> Any: """for the filter_by() method, return the target entity for which we will attempt to derive an expression from based on string name. @@ -1800,13 +1913,6 @@ class Query( """ from_entity = self._filter_by_zero() - if from_entity is None: - raise sa_exc.InvalidRequestError( - "Can't use filter_by when the first entity '%s' of a query " - "is not a mapped class. Please use the filter method instead, " - "or change the order of the entities in the query" - % self._query_entity_zero() - ) clauses = [ _entity_namespace_key(from_entity, key) == value @@ -1815,9 +1921,12 @@ class Query( return self.filter(*clauses) @_generative - @_assertions(_no_statement_condition, _no_limit_offset) def order_by( - self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + self: SelfQuery, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionArgument[Any], ) -> SelfQuery: """Apply one or more ORDER BY criteria to the query and return the newly resulting :class:`_query.Query`. @@ -1844,20 +1953,27 @@ class Query( """ - if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False): + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("order_by") + + if not clauses and (__first is None or __first is False): self._order_by_clauses = () - else: + elif __first is not _NoArg.NO_ARG: criterion = tuple( coercions.expect(roles.OrderByRole, clause) - for clause in clauses + for clause in (__first,) + clauses ) self._order_by_clauses += criterion + return self @_generative - @_assertions(_no_statement_condition, _no_limit_offset) def group_by( - self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + self: SelfQuery, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionArgument[Any], ) -> SelfQuery: """Apply one or more GROUP BY criterion to the query and return the newly resulting :class:`_query.Query`. @@ -1878,12 +1994,15 @@ class Query( """ - if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False): + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("group_by") + + if not clauses and (__first is None or __first is False): self._group_by_clauses = () - else: + elif __first is not _NoArg.NO_ARG: criterion = tuple( coercions.expect(roles.GroupByRole, clause) - for clause in clauses + for clause in (__first,) + clauses ) self._group_by_clauses += criterion return self @@ -1916,8 +2035,9 @@ class Query( self._having_criteria += (having_criteria,) return self - def _set_op(self, expr_fn, *q): - return self._from_selectable(expr_fn(*([self] + list(q))).subquery()) + def _set_op(self: SelfQuery, expr_fn: Any, *q: Query[Any]) -> SelfQuery: + list_of_queries = (self,) + q + return self._from_selectable(expr_fn(*(list_of_queries)).subquery()) def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce a UNION of this Query against one or more queries. @@ -2006,7 +2126,12 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) def join( - self: SelfQuery, target, onclause=None, *, isouter=False, full=False + self: SelfQuery, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + isouter: bool = False, + full: bool = False, ) -> SelfQuery: r"""Create a SQL JOIN against this :class:`_query.Query` object's criterion @@ -2193,20 +2318,23 @@ class Query( """ - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self, legacy=True, ) if onclause is not None: - onclause = coercions.expect( + onclause_element = coercions.expect( roles.OnClauseRole, onclause, legacy=True ) + else: + onclause_element = None + self._setup_joins += ( ( - target, - onclause, + join_target, + onclause_element, None, { "isouter": isouter, @@ -2218,7 +2346,13 @@ class Query( self.__dict__.pop("_last_joined_entity", None) return self - def outerjoin(self, target, onclause=None, *, full=False): + def outerjoin( + self: SelfQuery, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfQuery: """Create a left outer join against this ``Query`` object's criterion and apply generatively, returning the newly resulting ``Query``. @@ -2295,7 +2429,7 @@ class Query( self._set_select_from(from_obj, False) return self - def __getitem__(self, item): + def __getitem__(self, item: Any) -> Any: return orm_util._getitem( self, item, @@ -2303,7 +2437,11 @@ class Query( @_generative @_assertions(_no_statement_condition) - def slice(self: SelfQuery, start, stop) -> SelfQuery: + def slice( + self: SelfQuery, + start: int, + stop: int, + ) -> SelfQuery: """Computes the "slice" of the :class:`_query.Query` represented by the given indices and returns the resulting :class:`_query.Query`. @@ -2341,7 +2479,9 @@ class Query( @_generative @_assertions(_no_statement_condition) - def limit(self: SelfQuery, limit) -> SelfQuery: + def limit( + self: SelfQuery, limit: Union[int, _ColumnExpressionArgument[int]] + ) -> SelfQuery: """Apply a ``LIMIT`` to the query and return the newly resulting ``Query``. @@ -2351,7 +2491,9 @@ class Query( @_generative @_assertions(_no_statement_condition) - def offset(self: SelfQuery, offset) -> SelfQuery: + def offset( + self: SelfQuery, offset: Union[int, _ColumnExpressionArgument[int]] + ) -> SelfQuery: """Apply an ``OFFSET`` to the query and return the newly resulting ``Query``. @@ -2361,7 +2503,9 @@ class Query( @_generative @_assertions(_no_statement_condition) - def distinct(self: SelfQuery, *expr) -> SelfQuery: + def distinct( + self: SelfQuery, *expr: _ColumnExpressionArgument[Any] + ) -> SelfQuery: r"""Apply a ``DISTINCT`` to the query and return the newly resulting ``Query``. @@ -2415,7 +2559,7 @@ class Query( :ref:`faq_query_deduplicating` """ - return self._iter().all() + return self._iter().all() # type: ignore @_generative @_assertions(_no_clauseelement_condition) @@ -2462,9 +2606,9 @@ class Query( """ # replicates limit(1) behavior if self._statement is not None: - return self._iter().first() + return self._iter().first() # type: ignore else: - return self.limit(1)._iter().first() + return self.limit(1)._iter().first() # type: ignore def one_or_none(self) -> Optional[_T]: """Return at most one result or raise an exception. @@ -2490,7 +2634,7 @@ class Query( :meth:`_query.Query.one` """ - return self._iter().one_or_none() + return self._iter().one_or_none() # type: ignore def one(self) -> _T: """Return exactly one result or raise an exception. @@ -2537,18 +2681,18 @@ class Query( if not isinstance(ret, collections_abc.Sequence): return ret return ret[0] - except orm_exc.NoResultFound: + except sa_exc.NoResultFound: return None def __iter__(self) -> Iterable[_T]: - return self._iter().__iter__() + return self._iter().__iter__() # type: ignore def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: # new style execution. params = self._params statement = self._statement_20() - result = self.session.execute( + result: Union[ScalarResult[_T], Result[_T]] = self.session.execute( statement, params, execution_options={"_sa_orm_load_options": self.load_options}, @@ -2556,7 +2700,7 @@ class Query( # legacy: automatically set scalars, unique if result._attributes.get("is_single_entity", False): - result = result.scalars() + result = cast("Result[_T]", result).scalars() if ( result._attributes.get("filtered", False) @@ -2580,7 +2724,7 @@ class Query( return str(statement.compile(bind)) - def _get_bind_args(self, statement, fn, **kw): + def _get_bind_args(self, statement: Any, fn: Any, **kw: Any) -> Any: return fn(clause=statement, **kw) @property @@ -2634,7 +2778,11 @@ class Query( return _column_descriptions(self, legacy=True) - def instances(self, result_proxy: Result, context=None) -> Any: + def instances( + self, + result_proxy: CursorResult[Any], + context: Optional[QueryContext] = None, + ) -> Any: """Return an ORM result given a :class:`_engine.CursorResult` and :class:`.QueryContext`. @@ -2661,7 +2809,7 @@ class Query( # legacy: automatically set scalars, unique if result._attributes.get("is_single_entity", False): - result = result.scalars() + result = result.scalars() # type: ignore if result._attributes.get("filtered", False): result = result.unique() @@ -2675,7 +2823,13 @@ class Query( ":func:`_orm.merge_frozen_result` function.", enable_warnings=False, # warnings occur via loading.merge_result ) - def merge_result(self, iterator, load=True): + def merge_result( + self, + iterator: Union[ + FrozenResult[Any], Iterable[Sequence[Any]], Iterable[object] + ], + load: bool = True, + ) -> Union[FrozenResult[Any], Iterable[Any]]: """Merge a result into this :class:`_query.Query` object's Session. Given an iterator returned by a :class:`_query.Query` @@ -2743,7 +2897,8 @@ class Query( self.enable_eagerloads(False) .add_columns(sql.literal_column("1")) .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) - .statement.with_only_columns(1) + ._get_select_statement_only() + .with_only_columns(1) ) ezero = self._entity_from_pre_ent_zero() @@ -2752,7 +2907,7 @@ class Query( return sql.exists(inner) - def count(self): + def count(self) -> int: r"""Return a count of rows this the SQL formed by this :class:`Query` would return. @@ -2806,9 +2961,11 @@ class Query( """ col = sql.func.count(sql.literal_column("*")) - return self._legacy_from_self(col).enable_eagerloads(False).scalar() + return ( # type: ignore + self._legacy_from_self(col).enable_eagerloads(False).scalar() + ) - def delete(self, synchronize_session="evaluate"): + def delete(self, synchronize_session: str = "evaluate") -> int: r"""Perform a DELETE with an arbitrary WHERE clause. Deletes rows matched by this query from the database. @@ -2850,20 +3007,28 @@ class Query( self = bulk_del.query - delete_ = sql.delete(*self._raw_columns) + delete_ = sql.delete(*self._raw_columns) # type: ignore delete_._where_criteria = self._where_criteria - result = self.session.execute( - delete_, - self._params, - execution_options={"synchronize_session": synchronize_session}, + result: CursorResult[Any] = cast( + "CursorResult[Any]", + self.session.execute( + delete_, + self._params, + execution_options={"synchronize_session": synchronize_session}, + ), ) - bulk_del.result = result + bulk_del.result = result # type: ignore self.session.dispatch.after_bulk_delete(bulk_del) result.close() return result.rowcount - def update(self, values, synchronize_session="evaluate", update_args=None): + def update( + self, + values: Dict[_DMLColumnArgument, Any], + synchronize_session: str = "evaluate", + update_args: Optional[Dict[Any, Any]] = None, + ) -> int: r"""Perform an UPDATE with an arbitrary WHERE clause. Updates rows matched by this query in the database. @@ -2926,28 +3091,33 @@ class Query( bulk_ud.query = new_query self = bulk_ud.query - upd = sql.update(*self._raw_columns) + upd = sql.update(*self._raw_columns) # type: ignore ppo = update_args.pop("preserve_parameter_order", False) if ppo: - upd = upd.ordered_values(*values) + upd = upd.ordered_values(*values) # type: ignore else: upd = upd.values(values) if update_args: upd = upd.with_dialect_options(**update_args) upd._where_criteria = self._where_criteria - result = self.session.execute( - upd, - self._params, - execution_options={"synchronize_session": synchronize_session}, + result: CursorResult[Any] = cast( + "CursorResult[Any]", + self.session.execute( + upd, + self._params, + execution_options={"synchronize_session": synchronize_session}, + ), ) - bulk_ud.result = result + bulk_ud.result = result # type: ignore self.session.dispatch.after_bulk_update(bulk_ud) result.close() return result.rowcount - def _compile_state(self, for_statement=False, **kw): + def _compile_state( + self, for_statement: bool = False, **kw: Any + ) -> ORMCompileState: """Create an out-of-compiler ORMCompileState object. The ORMCompileState object is normally created directly as a result @@ -2971,13 +3141,14 @@ class Query( # ORMSelectCompileState. We could also base this on # query._statement is not None as we have the ORM Query here # however this is the more general path. - compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( - stmt, "orm" + compile_state_cls = cast( + ORMCompileState, + ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), ) return compile_state_cls.create_for_statement(stmt, None) - def _compile_context(self, for_statement=False): + def _compile_context(self, for_statement: bool = False) -> QueryContext: compile_state = self._compile_state(for_statement=for_statement) context = QueryContext( compile_state, @@ -3006,7 +3177,7 @@ class AliasOption(interfaces.LoaderOption): """ - def process_compile_state(self, compile_state: ORMCompileState): + def process_compile_state(self, compile_state: ORMCompileState) -> None: pass @@ -3017,12 +3188,12 @@ class BulkUD: """ - def __init__(self, query): + def __init__(self, query: Query[Any]): self.query = query.enable_eagerloads(False) self._validate_query_state() self.mapper = self.query._entity_from_pre_ent_zero() - def _validate_query_state(self): + def _validate_query_state(self) -> None: for attr, methname, notset, op in ( ("_limit_clause", "limit()", None, operator.is_), ("_offset_clause", "offset()", None, operator.is_), @@ -3049,14 +3220,19 @@ class BulkUD: ) @property - def session(self): + def session(self) -> Session: return self.query.session class BulkUpdate(BulkUD): """BulkUD which handles UPDATEs.""" - def __init__(self, query, values, update_kwargs): + def __init__( + self, + query: Query[Any], + values: Dict[_DMLColumnArgument, Any], + update_kwargs: Optional[Dict[Any, Any]], + ): super(BulkUpdate, self).__init__(query) self.values = values self.update_kwargs = update_kwargs @@ -3067,4 +3243,7 @@ class BulkDelete(BulkUD): class RowReturningQuery(Query[Row[_TP]]): - pass + if TYPE_CHECKING: + + def tuples(self) -> Query[_TP]: # type: ignore + ... diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 8273775ae1..1186f0f541 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -17,13 +17,23 @@ from __future__ import annotations import collections from collections import abc +import dataclasses import re import typing from typing import Any from typing import Callable +from typing import cast +from typing import Collection from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NamedTuple +from typing import NoReturn from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TypeVar @@ -32,14 +42,19 @@ import weakref from . import attributes from . import strategy_options +from ._typing import insp_is_aliased_class +from ._typing import is_has_collection_adapter from .base import _is_mapped_class from .base import class_mapper +from .base import LoaderCallableStatus +from .base import PassiveFlag from .base import state_str from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE from .interfaces import ONETOMANY from .interfaces import PropComparator +from .interfaces import RelationshipDirection from .interfaces import StrategizedProperty from .util import _extract_mapped_subtype from .util import _orm_annotate @@ -60,6 +75,7 @@ from ..sql import visitors from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _HasClauseElement from ..sql.elements import ColumnClause +from ..sql.elements import ColumnElement from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -71,15 +87,42 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict from ._typing import _InternalEntityType + from ._typing import _O + from ._typing import _RegistryType + from .clsregistry import _class_resolver + from .clsregistry import _ModNS + from .dependency import DependencyProcessor from .mapper import Mapper + from .query import Query + from .session import Session + from .state import InstanceState + from .strategies import LazyLoader from .util import AliasedClass from .util import AliasedInsp - from ..sql.elements import ColumnElement + from ..sql._typing import _CoreAdapterProto + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _InfoType + from ..sql.annotation import _AnnotationDict + from ..sql.elements import BinaryExpression + from ..sql.elements import BindParameter + from ..sql.elements import ClauseElement + from ..sql.schema import Table + from ..sql.selectable import FromClause + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference _T = TypeVar("_T", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) + _PT = TypeVar("_PT", bound=Any) +_PT2 = TypeVar("_PT2", bound=Any) + _RelationshipArgumentType = Union[ str, @@ -111,7 +154,10 @@ _RelationshipJoinConditionArgument = Union[ str, _ColumnExpressionArgument[bool] ] _ORMOrderByArgument = Union[ - Literal[False], str, _ColumnExpressionArgument[Any] + Literal[False], + str, + _ColumnExpressionArgument[Any], + Iterable[Union[str, _ColumnExpressionArgument[Any]]], ] _ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] _ORMColCollectionArgument = Union[ @@ -120,7 +166,19 @@ _ORMColCollectionArgument = Union[ ] -def remote(expr): +_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any]) + +_CE = TypeVar("_CE", bound="ColumnElement[Any]") + + +_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]] + + +def remote(expr: _CEA) -> _CEA: """Annotate a portion of a primaryjoin expression with a 'remote' annotation. @@ -134,12 +192,12 @@ def remote(expr): :func:`.foreign` """ - return _annotate_columns( + return _annotate_columns( # type: ignore coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True} ) -def foreign(expr): +def foreign(expr: _CEA) -> _CEA: """Annotate a portion of a primaryjoin expression with a 'foreign' annotation. @@ -154,11 +212,71 @@ def foreign(expr): """ - return _annotate_columns( + return _annotate_columns( # type: ignore coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True} ) +@dataclasses.dataclass +class _RelationshipArg(Generic[_T1, _T2]): + """stores a user-defined parameter value that must be resolved and + parsed later at mapper configuration time. + + """ + + __slots__ = "name", "argument", "resolved" + name: str + argument: _T1 + resolved: Optional[_T2] + + def _is_populated(self) -> bool: + return self.argument is not None + + def _resolve_against_registry( + self, clsregistry_resolver: Callable[[str, bool], _class_resolver] + ) -> None: + attr_value = self.argument + + if isinstance(attr_value, str): + self.resolved = clsregistry_resolver( + attr_value, self.name == "secondary" + )() + elif callable(attr_value) and not _is_mapped_class(attr_value): + self.resolved = attr_value() + else: + self.resolved = attr_value + + +class _RelationshipArgs(NamedTuple): + """stores user-passed parameters that are resolved at mapper configuration + time. + + """ + + secondary: _RelationshipArg[ + Optional[Union[FromClause, str]], + Optional[FromClause], + ] + primaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + secondaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + order_by: _RelationshipArg[ + _ORMOrderByArgument, + Union[Literal[None, False], Tuple[ColumnElement[Any], ...]], + ] + foreign_keys: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + remote_side: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + + @log.class_logger class Relationship( _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified @@ -184,6 +302,10 @@ class Relationship( _links_to_entity = True _is_relationship = True + _overlaps: Sequence[str] + + _lazy_strategy: LazyLoader + _persistence_only = dict( passive_deletes=False, passive_updates=True, @@ -192,56 +314,87 @@ class Relationship( cascade_backrefs=False, ) - _dependency_processor = None + _dependency_processor: Optional[DependencyProcessor] = None + + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + _join_condition: JoinCondition + order_by: Union[Literal[False], Tuple[ColumnElement[Any], ...]] + + _user_defined_foreign_keys: Set[ColumnElement[Any]] + _calculated_foreign_keys: Set[ColumnElement[Any]] + + remote_side: Set[ColumnElement[Any]] + local_columns: Set[ColumnElement[Any]] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: Optional[_ColumnPairs] + + local_remote_pairs: Optional[_ColumnPairs] + + direction: RelationshipDirection + + _init_args: _RelationshipArgs def __init__( self, argument: Optional[_RelationshipArgumentType[_T]] = None, - secondary=None, + secondary: Optional[Union[FromClause, str]] = None, *, - uselist=None, - collection_class=None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, - order_by=False, - backref=None, - cascade_backrefs=False, - overlaps=None, - post_update=False, - cascade="save-update, merge", - viewonly=False, + uselist: Optional[bool] = None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[_ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, lazy: _LazyLoadArgumentType = "select", - passive_deletes=False, - passive_updates=True, - active_history=False, - enable_typechecks=True, - foreign_keys=None, - remote_side=None, - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - load_on_pending=False, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - doc=None, - bake_queries=True, - _local_remote_pairs=None, - _legacy_inactive_history_style=False, + passive_deletes: Union[Literal["all"], bool] = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[ + Type[Relationship.Comparator[Any]] + ] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, + doc: Optional[str] = None, + bake_queries: Literal[True] = True, + cascade_backrefs: Literal[False] = False, + _local_remote_pairs: Optional[_ColumnPairs] = None, + _legacy_inactive_history_style: bool = False, ): super(Relationship, self).__init__() self.uselist = uselist self.argument = argument - self.secondary = secondary - self.primaryjoin = primaryjoin - self.secondaryjoin = secondaryjoin + + self._init_args = _RelationshipArgs( + _RelationshipArg("secondary", secondary, None), + _RelationshipArg("primaryjoin", primaryjoin, None), + _RelationshipArg("secondaryjoin", secondaryjoin, None), + _RelationshipArg("order_by", order_by, None), + _RelationshipArg("foreign_keys", foreign_keys, None), + _RelationshipArg("remote_side", remote_side, None), + ) + self.post_update = post_update - self.direction = None self.viewonly = viewonly if viewonly: self._warn_for_persistence_only_flags( @@ -258,7 +411,6 @@ class Relationship( self.sync_backref = sync_backref self.lazy = lazy self.single_parent = single_parent - self._user_defined_foreign_keys = foreign_keys self.collection_class = collection_class self.passive_deletes = passive_deletes @@ -269,7 +421,6 @@ class Relationship( ) self.passive_updates = passive_updates - self.remote_side = remote_side self.enable_typechecks = enable_typechecks self.query_class = query_class self.innerjoin = innerjoin @@ -292,23 +443,22 @@ class Relationship( self.local_remote_pairs = _local_remote_pairs self.load_on_pending = load_on_pending self.comparator_factory = comparator_factory or Relationship.Comparator - self.comparator = self.comparator_factory(self, None) util.set_creation_order(self) if info is not None: - self.info = info + self.info.update(info) self.strategy_key = (("lazy", self.lazy),) - self._reverse_property = set() + self._reverse_property: Set[Relationship[Any]] = set() + if overlaps: - self._overlaps = set(re.split(r"\s*,\s*", overlaps)) + self._overlaps = set(re.split(r"\s*,\s*", overlaps)) # type: ignore # noqa: E501 else: self._overlaps = () - self.cascade = cascade - - self.order_by = order_by + # mypy ignoring the @property setter + self.cascade = cascade # type: ignore self.back_populates = back_populates @@ -322,7 +472,7 @@ class Relationship( else: self.backref = backref - def _warn_for_persistence_only_flags(self, **kw): + def _warn_for_persistence_only_flags(self, **kw: Any) -> None: for k, v in kw.items(): if v != self._persistence_only[k]: # we are warning here rather than warn deprecated as this is a @@ -340,7 +490,7 @@ class Relationship( "in a future release." % (k,) ) - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: attributes.register_descriptor( mapper.class_, self.key, @@ -378,13 +528,16 @@ class Relationship( "_extra_criteria", ) + prop: RODescriptorReference[Relationship[_PT]] + _of_type: Optional[_EntityType[_PT]] + def __init__( self, - prop, - parentmapper, - adapt_to_entity=None, - of_type=None, - extra_criteria=(), + prop: Relationship[_PT], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, + of_type: Optional[_EntityType[_PT]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), ): """Construction of :class:`.Relationship.Comparator` is internal to the ORM's attribute mechanics. @@ -399,15 +552,17 @@ class Relationship( self._of_type = None self._extra_criteria = extra_criteria - def adapt_to_entity(self, adapt_to_entity): + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> Relationship.Comparator[Any]: return self.__class__( - self.property, + self.prop, self._parententity, adapt_to_entity=adapt_to_entity, of_type=self._of_type, ) - entity: _InternalEntityType + entity: _InternalEntityType[_PT] """The target entity referred to by this :class:`.Relationship.Comparator`. @@ -419,7 +574,7 @@ class Relationship( """ - mapper: Mapper[Any] + mapper: Mapper[_PT] """The target :class:`_orm.Mapper` referred to by this :class:`.Relationship.Comparator`. @@ -428,22 +583,22 @@ class Relationship( """ - def _memoized_attr_entity(self) -> _InternalEntityType: + def _memoized_attr_entity(self) -> _InternalEntityType[_PT]: if self._of_type: - return inspect(self._of_type) + return inspect(self._of_type) # type: ignore else: return self.prop.entity - def _memoized_attr_mapper(self) -> Mapper[Any]: + def _memoized_attr_mapper(self) -> Mapper[_PT]: return self.entity.mapper - def _source_selectable(self): + def _source_selectable(self) -> FromClause: if self._adapt_to_entity: return self._adapt_to_entity.selectable else: return self.property.parent._with_polymorphic_selectable - def __clause_element__(self): + def __clause_element__(self) -> ColumnElement[bool]: adapt_from = self._source_selectable() if self._of_type: of_type_entity = inspect(self._of_type) @@ -457,7 +612,7 @@ class Relationship( dest, secondary, target_adapter, - ) = self.property._create_joins( + ) = self.prop._create_joins( source_selectable=adapt_from, source_polymorphic=True, of_type_entity=of_type_entity, @@ -469,7 +624,7 @@ class Relationship( else: return pj - def of_type(self, cls): + def of_type(self, class_: _EntityType[_PT]) -> PropComparator[_PT]: r"""Redefine this object in terms of a polymorphic subclass. See :meth:`.PropComparator.of_type` for an example. @@ -477,16 +632,16 @@ class Relationship( """ return Relationship.Comparator( - self.property, + self.prop, self._parententity, adapt_to_entity=self._adapt_to_entity, - of_type=cls, + of_type=class_, extra_criteria=self._extra_criteria, ) def and_( self, *criteria: _ColumnExpressionArgument[bool] - ) -> PropComparator[bool]: + ) -> PropComparator[Any]: """Add AND criteria. See :meth:`.PropComparator.and_` for an example. @@ -500,14 +655,14 @@ class Relationship( ) return Relationship.Comparator( - self.property, + self.prop, self._parententity, adapt_to_entity=self._adapt_to_entity, of_type=self._of_type, extra_criteria=self._extra_criteria + exprs, ) - def in_(self, other): + def in_(self, other: Any) -> NoReturn: """Produce an IN clause - this is not implemented for :func:`_orm.relationship`-based attributes at this time. @@ -522,7 +677,7 @@ class Relationship( # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore - def __eq__(self, other): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``==`` operator. In a many-to-one context, such as:: @@ -559,7 +714,7 @@ class Relationship( or many-to-many context produce a NOT EXISTS clause. """ - if isinstance(other, (util.NoneType, expression.Null)): + if other is None or isinstance(other, expression.Null): if self.property.direction in [ONETOMANY, MANYTOMANY]: return ~self._criterion_exists() else: @@ -585,8 +740,18 @@ class Relationship( criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, ) -> Exists: + + where_criteria = ( + coercions.expect(roles.WhereHavingRole, criterion) + if criterion is not None + else None + ) + if getattr(self, "_of_type", None): - info = inspect(self._of_type) + info: Optional[_InternalEntityType[Any]] = inspect( + self._of_type + ) + assert info is not None target_mapper, to_selectable, is_aliased_class = ( info.mapper, info.selectable, @@ -597,10 +762,10 @@ class Relationship( single_crit = target_mapper._single_table_criterion if single_crit is not None: - if criterion is not None: - criterion = single_crit & criterion + if where_criteria is not None: + where_criteria = single_crit & where_criteria else: - criterion = single_crit + where_criteria = single_crit else: is_aliased_class = False to_selectable = None @@ -624,10 +789,10 @@ class Relationship( for k in kwargs: crit = getattr(self.property.mapper.class_, k) == kwargs[k] - if criterion is None: - criterion = crit + if where_criteria is None: + where_criteria = crit else: - criterion = criterion & crit + where_criteria = where_criteria & crit # annotate the *local* side of the join condition, in the case # of pj + sj this is the full primaryjoin, in the case of just @@ -638,24 +803,24 @@ class Relationship( j = _orm_annotate(pj, exclude=self.property.remote_side) if ( - criterion is not None + where_criteria is not None and target_adapter and not is_aliased_class ): # limit this adapter to annotated only? - criterion = target_adapter.traverse(criterion) + where_criteria = target_adapter.traverse(where_criteria) # only have the "joined left side" of what we # return be subject to Query adaption. The right # side of it is used for an exists() subquery and # should not correlate or otherwise reach out # to anything in the enclosing query. - if criterion is not None: - criterion = criterion._annotate( + if where_criteria is not None: + where_criteria = where_criteria._annotate( {"no_replacement_traverse": True} ) - crit = j & sql.True_._ifnone(criterion) + crit = j & sql.True_._ifnone(where_criteria) if secondary is not None: ex = ( @@ -673,7 +838,11 @@ class Relationship( ) return ex - def any(self, criterion=None, **kwargs): + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: """Produce an expression that tests a collection against particular criterion, using EXISTS. @@ -722,7 +891,11 @@ class Relationship( return self._criterion_exists(criterion, **kwargs) - def has(self, criterion=None, **kwargs): + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: """Produce an expression that tests a scalar reference against particular criterion, using EXISTS. @@ -756,7 +929,9 @@ class Relationship( ) return self._criterion_exists(criterion, **kwargs) - def contains(self, other, **kwargs): + def contains( + self, other: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> ColumnElement[bool]: """Return a simple expression that tests a collection for containment of a particular item. @@ -815,38 +990,45 @@ class Relationship( kwargs may be ignored by this operator but are required for API conformance. """ - if not self.property.uselist: + if not self.prop.uselist: raise sa_exc.InvalidRequestError( "'contains' not implemented for scalar " "attributes. Use ==" ) - clause = self.property._optimized_compare( + + clause = self.prop._optimized_compare( other, adapt_source=self.adapter ) - if self.property.secondaryjoin is not None: + if self.prop.secondaryjoin is not None: clause.negation_clause = self.__negated_contains_or_equals( other ) return clause - def __negated_contains_or_equals(self, other): - if self.property.direction == MANYTOONE: + def __negated_contains_or_equals( + self, other: Any + ) -> ColumnElement[bool]: + if self.prop.direction == MANYTOONE: state = attributes.instance_state(other) - def state_bindparam(local_col, state, remote_col): + def state_bindparam( + local_col: ColumnElement[Any], + state: InstanceState[Any], + remote_col: ColumnElement[Any], + ) -> BindParameter[Any]: dict_ = state.dict return sql.bindparam( local_col.key, type_=local_col.type, unique=True, - callable_=self.property._get_attr_w_warn_on_none( - self.property.mapper, state, dict_, remote_col + callable_=self.prop._get_attr_w_warn_on_none( + self.prop.mapper, state, dict_, remote_col ), ) - def adapt(col): + def adapt(col: _CE) -> _CE: if self.adapter: return self.adapter(col) else: @@ -876,7 +1058,7 @@ class Relationship( return ~self._criterion_exists(criterion) - def __ne__(self, other): + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``!=`` operator. In a many-to-one context, such as:: @@ -915,7 +1097,7 @@ class Relationship( or many-to-many context produce an EXISTS clause. """ - if isinstance(other, (util.NoneType, expression.Null)): + if other is None or isinstance(other, expression.Null): if self.property.direction == MANYTOONE: return _orm_annotate( ~self.property._optimized_compare( @@ -934,12 +1116,10 @@ class Relationship( else: return _orm_annotate(self.__negated_contains_or_equals(other)) - def _memoized_attr_property(self): + def _memoized_attr_property(self) -> Relationship[_PT]: self.prop.parent._check_configure() return self.prop - comparator: Comparator[_T] - def _with_parent( self, instance: object, @@ -947,10 +1127,11 @@ class Relationship( from_entity: Optional[_EntityType[Any]] = None, ) -> ColumnElement[bool]: assert instance is not None - adapt_source = None + adapt_source: Optional[_CoreAdapterProto] = None if from_entity is not None: - insp = inspect(from_entity) - if insp.is_aliased_class: + insp: Optional[_InternalEntityType[Any]] = inspect(from_entity) + assert insp is not None + if insp_is_aliased_class(insp): adapt_source = insp._adapter.adapt_clause return self._optimized_compare( instance, @@ -961,11 +1142,11 @@ class Relationship( def _optimized_compare( self, - state, - value_is_parent=False, - adapt_source=None, - alias_secondary=True, - ): + state: Any, + value_is_parent: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + alias_secondary: bool = True, + ) -> ColumnElement[bool]: if state is not None: try: state = inspect(state) @@ -1005,7 +1186,7 @@ class Relationship( dict_ = attributes.instance_dict(state.obj()) - def visit_bindparam(bindparam): + def visit_bindparam(bindparam: BindParameter[Any]) -> None: if bindparam._identifying_key in bind_to_col: bindparam.callable = self._get_attr_w_warn_on_none( mapper, @@ -1027,7 +1208,13 @@ class Relationship( criterion = adapt_source(criterion) return criterion - def _get_attr_w_warn_on_none(self, mapper, state, dict_, column): + def _get_attr_w_warn_on_none( + self, + mapper: Mapper[Any], + state: InstanceState[Any], + dict_: _InstanceDict, + column: ColumnElement[Any], + ) -> Callable[[], Any]: """Create the callable that is used in a many-to-one expression. E.g.:: @@ -1077,9 +1264,14 @@ class Relationship( # this feature was added explicitly for use in this method. state._track_last_known_value(prop.key) - def _go(): - last_known = to_return = state._last_known_values[prop.key] - existing_is_available = last_known is not attributes.NO_VALUE + lkv_fixed = state._last_known_values + + def _go() -> Any: + assert lkv_fixed is not None + last_known = to_return = lkv_fixed[prop.key] + existing_is_available = ( + last_known is not LoaderCallableStatus.NO_VALUE + ) # we support that the value may have changed. so here we # try to get the most recent value including re-fetching. @@ -1089,19 +1281,19 @@ class Relationship( state, dict_, column, - passive=attributes.PASSIVE_OFF + passive=PassiveFlag.PASSIVE_OFF if state.persistent - else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK, + else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK, ) - if current_value is attributes.NEVER_SET: + if current_value is LoaderCallableStatus.NEVER_SET: if not existing_is_available: raise sa_exc.InvalidRequestError( "Can't resolve value for column %s on object " "%s; no value has been set for this column" % (column, state_str(state)) ) - elif current_value is attributes.PASSIVE_NO_RESULT: + elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT: if not existing_is_available: raise sa_exc.InvalidRequestError( "Can't resolve value for column %s on object " @@ -1121,7 +1313,11 @@ class Relationship( return _go - def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): + def _lazy_none_clause( + self, + reverse_direction: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + ) -> ColumnElement[bool]: if not reverse_direction: criterion, bind_to_col = ( self._lazy_strategy._lazywhere, @@ -1139,20 +1335,20 @@ class Relationship( criterion = adapt_source(criterion) return criterion - def __str__(self): + def __str__(self) -> str: return str(self.parent.class_.__name__) + "." + self.key def merge( self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, - _recursive, - _resolve_conflict_map, - ): + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: if load: for r in self._reverse_property: @@ -1167,6 +1363,8 @@ class Relationship( if self.uselist: impl = source_state.get_impl(self.key) + + assert is_has_collection_adapter(impl) instances_iterable = impl.get_collection(source_state, source_dict) # if this is a CollectionAttributeImpl, then empty should @@ -1204,9 +1402,9 @@ class Relationship( for c in dest_list: coll.append_without_event(c) else: - dest_state.get_impl(self.key).set( - dest_state, dest_dict, dest_list, _adapt=False - ) + dest_impl = dest_state.get_impl(self.key) + assert is_has_collection_adapter(dest_impl) + dest_impl.set(dest_state, dest_dict, dest_list, _adapt=False) else: current = source_dict[self.key] if current is not None: @@ -1231,8 +1429,12 @@ class Relationship( ) def _value_as_iterable( - self, state, dict_, key, passive=attributes.PASSIVE_OFF - ): + self, + state: InstanceState[_O], + dict_: _InstanceDict, + key: str, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Sequence[Tuple[InstanceState[_O], _O]]: """Return a list of tuples (state, obj) for the given key. @@ -1241,9 +1443,9 @@ class Relationship( impl = state.manager[key].impl x = impl.get(state, dict_, passive=passive) - if x is attributes.PASSIVE_NO_RESULT or x is None: + if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None: return [] - elif hasattr(impl, "get_collection"): + elif is_has_collection_adapter(impl): return [ (attributes.instance_state(o), o) for o in impl.get_collection(state, dict_, x, passive=passive) @@ -1252,19 +1454,23 @@ class Relationship( return [(attributes.instance_state(x), x)] def cascade_iterator( - self, type_, state, dict_, visited_states, halt_on=None - ): + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]: # assert type_ in self._cascade # only actively lazy load on the 'delete' cascade if type_ != "delete" or self.passive_deletes: - passive = attributes.PASSIVE_NO_INITIALIZE + passive = PassiveFlag.PASSIVE_NO_INITIALIZE else: - passive = attributes.PASSIVE_OFF + passive = PassiveFlag.PASSIVE_OFF if type_ == "save-update": tuples = state.manager[self.key].impl.get_all_pending(state, dict_) - else: tuples = self._value_as_iterable( state, dict_, self.key, passive=passive @@ -1285,6 +1491,7 @@ class Relationship( # see [ticket:2229] continue + assert instance_state is not None instance_dict = attributes.instance_dict(c) if halt_on and halt_on(instance_state): @@ -1308,14 +1515,16 @@ class Relationship( yield c, instance_mapper, instance_state, instance_dict @property - def _effective_sync_backref(self): + def _effective_sync_backref(self) -> bool: if self.viewonly: return False else: return self.sync_backref is not False @staticmethod - def _check_sync_backref(rel_a, rel_b): + def _check_sync_backref( + rel_a: Relationship[Any], rel_b: Relationship[Any] + ) -> None: if rel_a.viewonly and rel_b.sync_backref: raise sa_exc.InvalidRequestError( "Relationship %s cannot specify sync_backref=True since %s " @@ -1328,7 +1537,7 @@ class Relationship( ): rel_b.sync_backref = False - def _add_reverse_property(self, key): + def _add_reverse_property(self, key: str) -> None: other = self.mapper.get_property(key, _configure_mappers=False) if not isinstance(other, Relationship): raise sa_exc.InvalidRequestError( @@ -1361,7 +1570,8 @@ class Relationship( ) if ( - self.direction in (ONETOMANY, MANYTOONE) + other._configure_started + and self.direction in (ONETOMANY, MANYTOONE) and self.direction == other.direction ): raise sa_exc.ArgumentError( @@ -1372,7 +1582,7 @@ class Relationship( ) @util.memoized_property - def entity(self) -> Union["Mapper", "AliasedInsp"]: + def entity(self) -> _InternalEntityType[_T]: """Return the target mapped entity, which is an inspect() of the class or aliased class that is referred towards. @@ -1388,7 +1598,7 @@ class Relationship( """ return self.entity.mapper - def do_init(self): + def do_init(self) -> None: self._check_conflicts() self._process_dependent_arguments() self._setup_entity() @@ -1399,14 +1609,16 @@ class Relationship( self._generate_backref() self._join_condition._warn_for_conflicting_sync_targets() super(Relationship, self).do_init() - self._lazy_strategy = self._get_strategy((("lazy", "select"),)) + self._lazy_strategy = cast( + "LazyLoader", self._get_strategy((("lazy", "select"),)) + ) - def _setup_registry_dependencies(self): + def _setup_registry_dependencies(self) -> None: self.parent.mapper.registry._set_depends_on( self.entity.mapper.registry ) - def _process_dependent_arguments(self): + def _process_dependent_arguments(self) -> None: """Convert incoming configuration arguments to their proper form. @@ -1417,78 +1629,80 @@ class Relationship( # accept callables for other attributes which may require # deferred initialization. This technique is used # by declarative "string configs" and some recipes. + init_args = self._init_args + for attr in ( "order_by", "primaryjoin", "secondaryjoin", "secondary", - "_user_defined_foreign_keys", + "foreign_keys", "remote_side", ): - attr_value = getattr(self, attr) - - if isinstance(attr_value, str): - setattr( - self, - attr, - self._clsregistry_resolve_arg( - attr_value, favor_tables=attr == "secondary" - )(), - ) - elif callable(attr_value) and not _is_mapped_class(attr_value): - setattr(self, attr, attr_value()) + + rel_arg = getattr(init_args, attr) + + rel_arg._resolve_against_registry(self._clsregistry_resolvers[1]) # remove "annotations" which are present if mapped class # descriptors are used to create the join expression. for attr in "primaryjoin", "secondaryjoin": - val = getattr(self, attr) + rel_arg = getattr(init_args, attr) + val = rel_arg.resolved if val is not None: - setattr( - self, - attr, - _orm_deannotate( - coercions.expect( - roles.ColumnArgumentRole, val, argname=attr - ) - ), + rel_arg.resolved = _orm_deannotate( + coercions.expect( + roles.ColumnArgumentRole, val, argname=attr + ) ) - if self.secondary is not None and _is_mapped_class(self.secondary): + secondary = init_args.secondary.resolved + if secondary is not None and _is_mapped_class(secondary): raise sa_exc.ArgumentError( "secondary argument %s passed to to relationship() %s must " "be a Table object or other FROM clause; can't send a mapped " "class directly as rows in 'secondary' are persisted " "independently of a class that is mapped " - "to that same table." % (self.secondary, self) + "to that same table." % (secondary, self) ) # ensure expressions in self.order_by, foreign_keys, # remote_side are all columns, not strings. - if self.order_by is not False and self.order_by is not None: + if ( + init_args.order_by.resolved is not False + and init_args.order_by.resolved is not None + ): self.order_by = tuple( coercions.expect( roles.ColumnArgumentRole, x, argname="order_by" ) - for x in util.to_list(self.order_by) + for x in util.to_list(init_args.order_by.resolved) ) + else: + self.order_by = False self._user_defined_foreign_keys = util.column_set( coercions.expect( roles.ColumnArgumentRole, x, argname="foreign_keys" ) - for x in util.to_column_set(self._user_defined_foreign_keys) + for x in util.to_column_set(init_args.foreign_keys.resolved) ) self.remote_side = util.column_set( coercions.expect( roles.ColumnArgumentRole, x, argname="remote_side" ) - for x in util.to_column_set(self.remote_side) + for x in util.to_column_set(init_args.remote_side.resolved) ) def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: argument = _extract_mapped_subtype( annotation, cls, @@ -1502,17 +1716,19 @@ class Relationship( if hasattr(argument, "__origin__"): - collection_class = argument.__origin__ + collection_class = argument.__origin__ # type: ignore if issubclass(collection_class, abc.Collection): if self.collection_class is None: self.collection_class = collection_class else: self.uselist = False - if argument.__args__: - if issubclass(argument.__origin__, typing.Mapping): - type_arg = argument.__args__[1] + if argument.__args__: # type: ignore + if issubclass( + argument.__origin__, typing.Mapping # type: ignore + ): + type_arg = argument.__args__[1] # type: ignore else: - type_arg = argument.__args__[0] + type_arg = argument.__args__[0] # type: ignore if hasattr(type_arg, "__forward_arg__"): str_argument = type_arg.__forward_arg__ argument = str_argument @@ -1523,12 +1739,12 @@ class Relationship( f"Generic alias {argument} requires an argument" ) elif hasattr(argument, "__forward_arg__"): - argument = argument.__forward_arg__ + argument = argument.__forward_arg__ # type: ignore self.argument = argument @util.preload_module("sqlalchemy.orm.mapper") - def _setup_entity(self, __argument=None): + def _setup_entity(self, __argument: Any = None) -> None: if "entity" in self.__dict__: return @@ -1539,42 +1755,51 @@ class Relationship( else: argument = self.argument + resolved_argument: _ExternalEntityType[Any] + if isinstance(argument, str): - argument = self._clsregistry_resolve_name(argument)() + # we might want to cleanup clsregistry API to make this + # more straightforward + resolved_argument = cast( + "_ExternalEntityType[Any]", + self._clsregistry_resolve_name(argument)(), + ) elif callable(argument) and not isinstance( argument, (type, mapperlib.Mapper) ): - argument = argument() + resolved_argument = argument() else: - argument = argument + resolved_argument = argument - if isinstance(argument, type): - entity = class_mapper(argument, configure=False) + entity: _InternalEntityType[Any] + + if isinstance(resolved_argument, type): + entity = class_mapper(resolved_argument, configure=False) else: try: - entity = inspect(argument) + entity = inspect(resolved_argument) except sa_exc.NoInspectionAvailable: - entity = None + entity = None # type: ignore if not hasattr(entity, "mapper"): raise sa_exc.ArgumentError( "relationship '%s' expects " "a class or a mapper argument (received: %s)" - % (self.key, type(argument)) + % (self.key, type(resolved_argument)) ) self.entity = entity # type: ignore self.target = self.entity.persist_selectable - def _setup_join_conditions(self): + def _setup_join_conditions(self) -> None: self._join_condition = jc = JoinCondition( parent_persist_selectable=self.parent.persist_selectable, child_persist_selectable=self.entity.persist_selectable, parent_local_selectable=self.parent.local_table, child_local_selectable=self.entity.local_table, - primaryjoin=self.primaryjoin, - secondary=self.secondary, - secondaryjoin=self.secondaryjoin, + primaryjoin=self._init_args.primaryjoin.resolved, + secondary=self._init_args.secondary.resolved, + secondaryjoin=self._init_args.secondaryjoin.resolved, parent_equivalents=self.parent._equivalent_columns, child_equivalents=self.mapper._equivalent_columns, consider_as_foreign_keys=self._user_defined_foreign_keys, @@ -1587,6 +1812,7 @@ class Relationship( ) self.primaryjoin = jc.primaryjoin self.secondaryjoin = jc.secondaryjoin + self.secondary = jc.secondary self.direction = jc.direction self.local_remote_pairs = jc.local_remote_pairs self.remote_side = jc.remote_columns @@ -1596,21 +1822,30 @@ class Relationship( self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs @property - def _clsregistry_resolve_arg(self): + def _clsregistry_resolve_arg( + self, + ) -> Callable[[str, bool], _class_resolver]: return self._clsregistry_resolvers[1] @property - def _clsregistry_resolve_name(self): + def _clsregistry_resolve_name( + self, + ) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]: return self._clsregistry_resolvers[0] @util.memoized_property @util.preload_module("sqlalchemy.orm.clsregistry") - def _clsregistry_resolvers(self): + def _clsregistry_resolvers( + self, + ) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], + ]: _resolver = util.preloaded.orm_clsregistry._resolver return _resolver(self.parent.class_, self) - def _check_conflicts(self): + def _check_conflicts(self) -> None: """Test that this relationship is legal, warn about inheritance conflicts.""" if self.parent.non_primary and not class_mapper( @@ -1637,10 +1872,10 @@ class Relationship( return self._cascade @cascade.setter - def cascade(self, cascade: Union[str, CascadeOptions]): + def cascade(self, cascade: Union[str, CascadeOptions]) -> None: self._set_cascade(cascade) - def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]): + def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None: cascade = CascadeOptions(cascade_arg) if self.viewonly: @@ -1655,7 +1890,7 @@ class Relationship( if self._dependency_processor: self._dependency_processor.cascade = cascade - def _check_cascade_settings(self, cascade): + def _check_cascade_settings(self, cascade: CascadeOptions) -> None: if ( cascade.delete_orphan and not self.single_parent @@ -1699,7 +1934,7 @@ class Relationship( (self.key, self.parent.class_) ) - def _persists_for(self, mapper): + def _persists_for(self, mapper: Mapper[Any]) -> bool: """Return True if this property will persist values on behalf of the given mapper. @@ -1710,16 +1945,15 @@ class Relationship( and mapper.relationships[self.key] is self ) - def _columns_are_mapped(self, *cols): + def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool: """Return True if all columns in the given collection are mapped by the tables referenced by this :class:`.Relationship`. """ + + secondary = self._init_args.secondary.resolved for c in cols: - if ( - self.secondary is not None - and self.secondary.c.contains_column(c) - ): + if secondary is not None and secondary.c.contains_column(c): continue if not self.parent.persist_selectable.c.contains_column( c @@ -1727,13 +1961,14 @@ class Relationship( return False return True - def _generate_backref(self): + def _generate_backref(self) -> None: """Interpret the 'backref' instruction to create a :func:`_orm.relationship` complementary to this one.""" if self.parent.non_primary: return if self.backref is not None and not self.back_populates: + kwargs: Dict[str, Any] if isinstance(self.backref, str): backref_key, kwargs = self.backref, {} else: @@ -1805,7 +2040,7 @@ class Relationship( self._add_reverse_property(self.back_populates) @util.preload_module("sqlalchemy.orm.dependency") - def _post_init(self): + def _post_init(self) -> None: dependency = util.preloaded.orm_dependency if self.uselist is None: @@ -1816,7 +2051,7 @@ class Relationship( )(self) @util.memoized_property - def _use_get(self): + def _use_get(self) -> bool: """memoize the 'use_get' attribute of this RelationshipLoader's lazyloader.""" @@ -1824,18 +2059,25 @@ class Relationship( return strategy.use_get @util.memoized_property - def _is_self_referential(self): + def _is_self_referential(self) -> bool: return self.mapper.common_parent(self.parent) def _create_joins( self, - source_polymorphic=False, - source_selectable=None, - dest_selectable=None, - of_type_entity=None, - alias_secondary=False, - extra_criteria=(), - ): + source_polymorphic: bool = False, + source_selectable: Optional[FromClause] = None, + dest_selectable: Optional[FromClause] = None, + of_type_entity: Optional[_InternalEntityType[Any]] = None, + alias_secondary: bool = False, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + FromClause, + FromClause, + Optional[FromClause], + Optional[ClauseAdapter], + ]: aliased = False @@ -1905,38 +2147,56 @@ class Relationship( ) -def _annotate_columns(element, annotations): - def clone(elem): +def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE: + def clone(elem: _CE) -> _CE: if isinstance(elem, expression.ColumnClause): - elem = elem._annotate(annotations.copy()) + elem = elem._annotate(annotations.copy()) # type: ignore elem._copy_internals(clone=clone) return elem if element is not None: element = clone(element) - clone = None # remove gc cycles + clone = None # type: ignore # remove gc cycles return element class JoinCondition: + + primaryjoin_initial: Optional[ColumnElement[bool]] + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + prop: Relationship[Any] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: _ColumnPairs + direction: RelationshipDirection + + parent_persist_selectable: FromClause + child_persist_selectable: FromClause + parent_local_selectable: FromClause + child_local_selectable: FromClause + + _local_remote_pairs: Optional[_ColumnPairs] + def __init__( self, - parent_persist_selectable, - child_persist_selectable, - parent_local_selectable, - child_local_selectable, - primaryjoin=None, - secondary=None, - secondaryjoin=None, - parent_equivalents=None, - child_equivalents=None, - consider_as_foreign_keys=None, - local_remote_pairs=None, - remote_side=None, - self_referential=False, - prop=None, - support_sync=True, - can_be_synced_fn=lambda *c: True, + parent_persist_selectable: FromClause, + child_persist_selectable: FromClause, + parent_local_selectable: FromClause, + child_local_selectable: FromClause, + primaryjoin: Optional[ColumnElement[bool]] = None, + secondary: Optional[FromClause] = None, + secondaryjoin: Optional[ColumnElement[bool]] = None, + parent_equivalents: Optional[_EquivalentColumnMap] = None, + child_equivalents: Optional[_EquivalentColumnMap] = None, + consider_as_foreign_keys: Any = None, + local_remote_pairs: Optional[_ColumnPairs] = None, + remote_side: Any = None, + self_referential: Any = False, + prop: Optional[Relationship[Any]] = None, + support_sync: bool = True, + can_be_synced_fn: Callable[..., bool] = lambda *c: True, ): self.parent_persist_selectable = parent_persist_selectable self.parent_local_selectable = parent_local_selectable @@ -1944,7 +2204,7 @@ class JoinCondition: self.child_local_selectable = child_local_selectable self.parent_equivalents = parent_equivalents self.child_equivalents = child_equivalents - self.primaryjoin = primaryjoin + self.primaryjoin_initial = primaryjoin self.secondaryjoin = secondaryjoin self.secondary = secondary self.consider_as_foreign_keys = consider_as_foreign_keys @@ -1954,7 +2214,10 @@ class JoinCondition: self.self_referential = self_referential self.support_sync = support_sync self.can_be_synced_fn = can_be_synced_fn + self._determine_joins() + assert self.primaryjoin is not None + self._sanitize_joins() self._annotate_fks() self._annotate_remote() @@ -1968,7 +2231,7 @@ class JoinCondition: self._check_remote_side() self._log_joins() - def _log_joins(self): + def _log_joins(self) -> None: if self.prop is None: return log = self.prop.logger @@ -2008,7 +2271,7 @@ class JoinCondition: ) log.info("%s relationship direction %s", self.prop, self.direction) - def _sanitize_joins(self): + def _sanitize_joins(self) -> None: """remove the parententity annotation from our join conditions which can leak in here based on some declarative patterns and maybe others. @@ -2026,7 +2289,7 @@ class JoinCondition: self.secondaryjoin, values=("parententity", "proxy_key") ) - def _determine_joins(self): + def _determine_joins(self) -> None: """Determine the 'primaryjoin' and 'secondaryjoin' attributes, if not passed to the constructor already. @@ -2056,21 +2319,25 @@ class JoinCondition: a_subset=self.child_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) - if self.primaryjoin is None: + if self.primaryjoin_initial is None: self.primaryjoin = join_condition( self.parent_persist_selectable, self.secondary, a_subset=self.parent_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) + else: + self.primaryjoin = self.primaryjoin_initial else: - if self.primaryjoin is None: + if self.primaryjoin_initial is None: self.primaryjoin = join_condition( self.parent_persist_selectable, self.child_persist_selectable, a_subset=self.parent_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) + else: + self.primaryjoin = self.primaryjoin_initial except sa_exc.NoForeignKeysError as nfe: if self.secondary is not None: raise sa_exc.NoForeignKeysError( @@ -2118,15 +2385,16 @@ class JoinCondition: ) from afe @property - def primaryjoin_minus_local(self): + def primaryjoin_minus_local(self) -> ColumnElement[bool]: return _deep_deannotate(self.primaryjoin, values=("local", "remote")) @property - def secondaryjoin_minus_local(self): + def secondaryjoin_minus_local(self) -> ColumnElement[bool]: + assert self.secondaryjoin is not None return _deep_deannotate(self.secondaryjoin, values=("local", "remote")) @util.memoized_property - def primaryjoin_reverse_remote(self): + def primaryjoin_reverse_remote(self) -> ColumnElement[bool]: """Return the primaryjoin condition suitable for the "reverse" direction. @@ -2138,7 +2406,7 @@ class JoinCondition: """ if self._has_remote_annotations: - def replace(element): + def replace(element: _CE, **kw: Any) -> Optional[_CE]: if "remote" in element._annotations: v = dict(element._annotations) del v["remote"] @@ -2150,6 +2418,8 @@ class JoinCondition: v["remote"] = True return element._with_annotations(v) + return None + return visitors.replacement_traverse(self.primaryjoin, {}, replace) else: if self._has_foreign_annotations: @@ -2160,7 +2430,7 @@ class JoinCondition: else: return _deep_deannotate(self.primaryjoin) - def _has_annotation(self, clause, annotation): + def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool: for col in visitors.iterate(clause, {}): if annotation in col._annotations: return True @@ -2168,14 +2438,14 @@ class JoinCondition: return False @util.memoized_property - def _has_foreign_annotations(self): + def _has_foreign_annotations(self) -> bool: return self._has_annotation(self.primaryjoin, "foreign") @util.memoized_property - def _has_remote_annotations(self): + def _has_remote_annotations(self) -> bool: return self._has_annotation(self.primaryjoin, "remote") - def _annotate_fks(self): + def _annotate_fks(self) -> None: """Annotate the primaryjoin and secondaryjoin structures with 'foreign' annotations marking columns considered as foreign. @@ -2189,10 +2459,11 @@ class JoinCondition: else: self._annotate_present_fks() - def _annotate_from_fk_list(self): - def check_fk(col): - if col in self.consider_as_foreign_keys: - return col._annotate({"foreign": True}) + def _annotate_from_fk_list(self) -> None: + def check_fk(element: _CE, **kw: Any) -> Optional[_CE]: + if element in self.consider_as_foreign_keys: + return element._annotate({"foreign": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, check_fk @@ -2202,13 +2473,15 @@ class JoinCondition: self.secondaryjoin, {}, check_fk ) - def _annotate_present_fks(self): + def _annotate_present_fks(self) -> None: if self.secondary is not None: secondarycols = util.column_set(self.secondary.c) else: secondarycols = set() - def is_foreign(a, b): + def is_foreign( + a: ColumnElement[Any], b: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: if isinstance(a, schema.Column) and isinstance(b, schema.Column): if a.references(b): return a @@ -2221,7 +2494,9 @@ class JoinCondition: elif b in secondarycols and a not in secondarycols: return b - def visit_binary(binary): + return None + + def visit_binary(binary: BinaryExpression[Any]) -> None: if not isinstance( binary.left, sql.ColumnElement ) or not isinstance(binary.right, sql.ColumnElement): @@ -2248,16 +2523,17 @@ class JoinCondition: self.secondaryjoin, {}, {"binary": visit_binary} ) - def _refers_to_parent_table(self): + def _refers_to_parent_table(self) -> bool: """Return True if the join condition contains column comparisons where both columns are in both tables. """ pt = self.parent_persist_selectable mt = self.child_persist_selectable - result = [False] + result = False - def visit_binary(binary): + def visit_binary(binary: BinaryExpression[Any]) -> None: + nonlocal result c, f = binary.left, binary.right if ( isinstance(c, expression.ColumnClause) @@ -2267,19 +2543,19 @@ class JoinCondition: and mt.is_derived_from(c.table) and mt.is_derived_from(f.table) ): - result[0] = True + result = True visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary}) - return result[0] + return result - def _tables_overlap(self): + def _tables_overlap(self) -> bool: """Return True if parent/child tables have some overlap.""" return selectables_overlap( self.parent_persist_selectable, self.child_persist_selectable ) - def _annotate_remote(self): + def _annotate_remote(self) -> None: """Annotate the primaryjoin and secondaryjoin structures with 'remote' annotations marking columns considered as part of the 'remote' side. @@ -2301,30 +2577,38 @@ class JoinCondition: else: self._annotate_remote_distinct_selectables() - def _annotate_remote_secondary(self): + def _annotate_remote_secondary(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when 'secondary' is present. """ - def repl(element): - if self.secondary.c.contains_column(element): + assert self.secondary is not None + fixed_secondary = self.secondary + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + if fixed_secondary.c.contains_column(element): return element._annotate({"remote": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl ) + + assert self.secondaryjoin is not None self.secondaryjoin = visitors.replacement_traverse( self.secondaryjoin, {}, repl ) - def _annotate_selfref(self, fn, remote_side_given): + def _annotate_selfref( + self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool + ) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the relationship is detected as self-referential. """ - def visit_binary(binary): + def visit_binary(binary: BinaryExpression[Any]) -> None: equated = binary.left.compare(binary.right) if isinstance(binary.left, expression.ColumnClause) and isinstance( binary.right, expression.ColumnClause @@ -2341,7 +2625,7 @@ class JoinCondition: self.primaryjoin, {}, {"binary": visit_binary} ) - def _annotate_remote_from_args(self): + def _annotate_remote_from_args(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the 'remote_side' or '_local_remote_pairs' arguments are used. @@ -2363,17 +2647,18 @@ class JoinCondition: self._annotate_selfref(lambda col: col in remote_side, True) else: - def repl(element): + def repl(element: _CE, **kw: Any) -> Optional[_CE]: # use set() to avoid generating ``__eq__()`` expressions # against each element if element in set(remote_side): return element._annotate({"remote": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl ) - def _annotate_remote_with_overlap(self): + def _annotate_remote_with_overlap(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the parent/child tables have some set of tables in common, though is not a fully self-referential @@ -2381,7 +2666,7 @@ class JoinCondition: """ - def visit_binary(binary): + def visit_binary(binary: BinaryExpression[Any]) -> None: binary.left, binary.right = proc_left_right( binary.left, binary.right ) @@ -2393,7 +2678,9 @@ class JoinCondition: self.prop is not None and self.prop.mapper is not self.prop.parent ) - def proc_left_right(left, right): + def proc_left_right( + left: ColumnElement[Any], right: ColumnElement[Any] + ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]: if isinstance(left, expression.ColumnClause) and isinstance( right, expression.ColumnClause ): @@ -2420,32 +2707,33 @@ class JoinCondition: self.primaryjoin, {}, {"binary": visit_binary} ) - def _annotate_remote_distinct_selectables(self): + def _annotate_remote_distinct_selectables(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the parent/child tables are entirely separate. """ - def repl(element): + def repl(element: _CE, **kw: Any) -> Optional[_CE]: if self.child_persist_selectable.c.contains_column(element) and ( not self.parent_local_selectable.c.contains_column(element) or self.child_local_selectable.c.contains_column(element) ): return element._annotate({"remote": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl ) - def _warn_non_column_elements(self): + def _warn_non_column_elements(self) -> None: util.warn( "Non-simple column elements in primary " "join condition for property %s - consider using " "remote() annotations to mark the remote side." % self.prop ) - def _annotate_local(self): + def _annotate_local(self) -> None: """Annotate the primaryjoin and secondaryjoin structures with 'local' annotations. @@ -2466,29 +2754,31 @@ class JoinCondition: else: local_side = util.column_set(self.parent_persist_selectable.c) - def locals_(elem): - if "remote" not in elem._annotations and elem in local_side: - return elem._annotate({"local": True}) + def locals_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" not in element._annotations and element in local_side: + return element._annotate({"local": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, locals_ ) - def _annotate_parentmapper(self): + def _annotate_parentmapper(self) -> None: if self.prop is None: return - def parentmappers_(elem): - if "remote" in elem._annotations: - return elem._annotate({"parentmapper": self.prop.mapper}) - elif "local" in elem._annotations: - return elem._annotate({"parentmapper": self.prop.parent}) + def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" in element._annotations: + return element._annotate({"parentmapper": self.prop.mapper}) + elif "local" in element._annotations: + return element._annotate({"parentmapper": self.prop.parent}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, parentmappers_ ) - def _check_remote_side(self): + def _check_remote_side(self) -> None: if not self.local_remote_pairs: raise sa_exc.ArgumentError( "Relationship %s could " @@ -2501,7 +2791,9 @@ class JoinCondition: "the relationship." % (self.prop,) ) - def _check_foreign_cols(self, join_condition, primary): + def _check_foreign_cols( + self, join_condition: ColumnElement[bool], primary: bool + ) -> None: """Check the foreign key columns collected and emit error messages.""" @@ -2567,7 +2859,7 @@ class JoinCondition: ) raise sa_exc.ArgumentError(err) - def _determine_direction(self): + def _determine_direction(self) -> None: """Determine if this relationship is one to many, many to one, many to many. @@ -2651,7 +2943,9 @@ class JoinCondition: "nor the child's mapped tables" % self.prop ) - def _deannotate_pairs(self, collection): + def _deannotate_pairs( + self, collection: _ColumnPairIterable + ) -> _MutableColumnPairs: """provide deannotation for the various lists of pairs, so that using them in hashes doesn't incur high-overhead __eq__() comparisons against @@ -2660,13 +2954,22 @@ class JoinCondition: """ return [(x._deannotate(), y._deannotate()) for x, y in collection] - def _setup_pairs(self): - sync_pairs = [] - lrp = util.OrderedSet([]) - secondary_sync_pairs = [] - - def go(joincond, collection): - def visit_binary(binary, left, right): + def _setup_pairs(self) -> None: + sync_pairs: _MutableColumnPairs = [] + lrp: util.OrderedSet[ + Tuple[ColumnElement[Any], ColumnElement[Any]] + ] = util.OrderedSet([]) + secondary_sync_pairs: _MutableColumnPairs = [] + + def go( + joincond: ColumnElement[bool], + collection: _MutableColumnPairs, + ) -> None: + def visit_binary( + binary: BinaryExpression[Any], + left: ColumnElement[Any], + right: ColumnElement[Any], + ) -> None: if ( "remote" in right._annotations and "remote" not in left._annotations @@ -2703,9 +3006,12 @@ class JoinCondition: secondary_sync_pairs ) - _track_overlapping_sync_targets = weakref.WeakKeyDictionary() + _track_overlapping_sync_targets: weakref.WeakKeyDictionary[ + ColumnElement[Any], + weakref.WeakKeyDictionary[Relationship[Any], ColumnElement[Any]], + ] = weakref.WeakKeyDictionary() - def _warn_for_conflicting_sync_targets(self): + def _warn_for_conflicting_sync_targets(self) -> None: if not self.support_sync: return @@ -2793,18 +3099,20 @@ class JoinCondition: self._track_overlapping_sync_targets[to_][self.prop] = from_ @util.memoized_property - def remote_columns(self): + def remote_columns(self) -> Set[ColumnElement[Any]]: return self._gather_join_annotations("remote") @util.memoized_property - def local_columns(self): + def local_columns(self) -> Set[ColumnElement[Any]]: return self._gather_join_annotations("local") @util.memoized_property - def foreign_key_columns(self): + def foreign_key_columns(self) -> Set[ColumnElement[Any]]: return self._gather_join_annotations("foreign") - def _gather_join_annotations(self, annotation): + def _gather_join_annotations( + self, annotation: str + ) -> Set[ColumnElement[Any]]: s = set( self._gather_columns_with_annotation(self.primaryjoin, annotation) ) @@ -2816,24 +3124,32 @@ class JoinCondition: ) return {x._deannotate() for x in s} - def _gather_columns_with_annotation(self, clause, *annotation): - annotation = set(annotation) + def _gather_columns_with_annotation( + self, clause: ColumnElement[Any], *annotation: Iterable[str] + ) -> Set[ColumnElement[Any]]: + annotation_set = set(annotation) return set( [ - col + cast(ColumnElement[Any], col) for col in visitors.iterate(clause, {}) - if annotation.issubset(col._annotations) + if annotation_set.issubset(col._annotations) ] ) def join_targets( self, - source_selectable, - dest_selectable, - aliased, - single_crit=None, - extra_criteria=(), - ): + source_selectable: Optional[FromClause], + dest_selectable: FromClause, + aliased: bool, + single_crit: Optional[ColumnElement[bool]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + Optional[FromClause], + Optional[ClauseAdapter], + FromClause, + ]: """Given a source and destination selectable, create a join between them. @@ -2923,9 +3239,15 @@ class JoinCondition: dest_selectable, ) - def create_lazy_clause(self, reverse_direction=False): - binds = util.column_dict() - equated_columns = util.column_dict() + def create_lazy_clause( + self, reverse_direction: bool = False + ) -> Tuple[ + ColumnElement[bool], + Dict[str, ColumnElement[Any]], + Dict[ColumnElement[Any], ColumnElement[Any]], + ]: + binds: Dict[ColumnElement[Any], BindParameter[Any]] = {} + equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {} has_secondary = self.secondaryjoin is not None @@ -2941,21 +3263,23 @@ class JoinCondition: for l, r in self.local_remote_pairs: equated_columns[l] = r - def col_to_bind(col): + def col_to_bind( + element: ColumnElement[Any], **kw: Any + ) -> Optional[BindParameter[Any]]: if ( - (not reverse_direction and "local" in col._annotations) + (not reverse_direction and "local" in element._annotations) or reverse_direction and ( - (has_secondary and col in lookup) - or (not has_secondary and "remote" in col._annotations) + (has_secondary and element in lookup) + or (not has_secondary and "remote" in element._annotations) ) ): - if col not in binds: - binds[col] = sql.bindparam( - None, None, type_=col.type, unique=True + if element not in binds: + binds[element] = sql.bindparam( + None, None, type_=element.type, unique=True ) - return binds[col] + return binds[element] return None lazywhere = self.primaryjoin @@ -2982,8 +3306,8 @@ class _ColInAnnotations: __slots__ = ("name",) - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __call__(self, c): + def __call__(self, c: ClauseElement) -> bool: return self.name in c._annotations diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index b5491248b8..d72e78c9e6 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -118,6 +118,7 @@ if typing.TYPE_CHECKING: from ..sql._typing import _T7 from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable + from ..sql.base import ExecutableOption from ..sql.elements import ClauseElement from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import TypedReturnsRows @@ -765,7 +766,7 @@ class SessionTransaction(_StateChange, TransactionalContext): self.session.dispatch.after_transaction_create(self.session, self) def _raise_for_prerequisite_state( - self, operation_name: str, state: SessionTransactionState + self, operation_name: str, state: _StateChangeState ) -> NoReturn: if state is SessionTransactionState.DEACTIVE: if self._rollback_exception: @@ -3183,7 +3184,7 @@ class Session(_SessionClassMethods, EventTarget): primary_key_identity: _PKIdentityArgument, db_load_fn: Callable[..., _O], *, - options: Optional[Sequence[ORMOption]] = None, + options: Optional[Sequence[ExecutableOption]] = None, populate_existing: bool = False, with_for_update: Optional[ForUpdateArg] = None, identity_token: Optional[Any] = None, @@ -3377,7 +3378,7 @@ class Session(_SessionClassMethods, EventTarget): *, options: Optional[Sequence[ORMOption]] = None, load: bool, - _recursive: Dict[InstanceState[Any], object], + _recursive: Dict[Any, object], _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], ) -> _O: mapper: Mapper[_O] = _state_mapper(state) diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index cb8b1f4aad..af9f487066 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -82,6 +82,22 @@ class _InstanceDictProto(Protocol): ... +class _InstallLoaderCallableProto(Protocol[_O]): + """used at result loading time to install a _LoaderCallable callable + upon a specific InstanceState, which will be used to populate an + attribute when that attribute is accessed. + + Concrete examples are per-instance deferred column loaders and + relationship lazy loaders. + + """ + + def __call__( + self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] + ) -> None: + ... + + @inspection._self_inspects class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): """tracks state information at the instance level. @@ -658,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): @classmethod def _instance_level_callable_processor( cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any - ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]: + ) -> _InstallLoaderCallableProto[_O]: impl = manager[key].impl if is_collection_impl(impl): fixed_impl = impl diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py index b7bf965585..764b5dfa6b 100644 --- a/lib/sqlalchemy/orm/state_changes.py +++ b/lib/sqlalchemy/orm/state_changes.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php + """State tracking utilities used by :class:`_orm.Session`. """ @@ -14,6 +15,9 @@ import contextlib from enum import Enum from typing import Any from typing import Callable +from typing import cast +from typing import Iterator +from typing import NoReturn from typing import Optional from typing import Tuple from typing import TypeVar @@ -48,9 +52,11 @@ class _StateChange: _next_state: _StateChangeState = _StateChangeStates.ANY _state: _StateChangeState = _StateChangeStates.NO_CHANGE - _current_fn: Optional[Callable] = None + _current_fn: Optional[Callable[..., Any]] = None - def _raise_for_prerequisite_state(self, operation_name, state): + def _raise_for_prerequisite_state( + self, operation_name: str, state: _StateChangeState + ) -> NoReturn: raise sa_exc.IllegalStateChangeError( f"Can't run operation '{operation_name}()' when Session " f"is in state {state!r}" @@ -80,16 +86,19 @@ class _StateChange: prerequisite_states is not _StateChangeStates.ANY ) + prerequisite_state_collection = cast( + "Tuple[_StateChangeState, ...]", prerequisite_states + ) expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE @util.decorator - def _go(fn, self, *arg, **kw): + def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any: current_state = self._state if ( has_prerequisite_states - and current_state not in prerequisite_states + and current_state not in prerequisite_state_collection ): self._raise_for_prerequisite_state(fn.__name__, current_state) @@ -159,7 +168,7 @@ class _StateChange: return _go @contextlib.contextmanager - def _expect_state(self, expected: _StateChangeState): + def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]: """called within a method that changes states. method must also use the ``@declare_states()`` decorator. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 0ba22e7a7c..5dc80e4f28 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -14,6 +14,10 @@ from __future__ import annotations import collections import itertools +from typing import Any +from typing import Dict +from typing import Tuple +from typing import TYPE_CHECKING from . import attributes from . import exc as orm_exc @@ -28,7 +32,9 @@ from . import util as orm_util from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED +from .base import LoaderCallableStatus from .base import PASSIVE_OFF +from .base import PassiveFlag from .context import _column_descriptions from .context import ORMCompileState from .context import ORMSelectCompileState @@ -50,6 +56,10 @@ from ..sql import visitors from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import Select +if TYPE_CHECKING: + from .relationships import Relationship + from ..sql.elements import ColumnElement + def _register_attribute( prop, @@ -486,10 +496,10 @@ class DeferredColumnLoader(LoaderStrategy): def _load_for_state(self, state, passive): if not state.key: - return attributes.ATTR_EMPTY + return LoaderCallableStatus.ATTR_EMPTY - if not passive & attributes.SQL_OK: - return attributes.PASSIVE_NO_RESULT + if not passive & PassiveFlag.SQL_OK: + return LoaderCallableStatus.PASSIVE_NO_RESULT localparent = state.manager.mapper @@ -522,7 +532,7 @@ class DeferredColumnLoader(LoaderStrategy): state.mapper, state, set(group), PASSIVE_OFF ) - return attributes.ATTR_WAS_SET + return LoaderCallableStatus.ATTR_WAS_SET def _invoke_raise_load(self, state, passive, lazy): raise sa_exc.InvalidRequestError( @@ -626,7 +636,9 @@ class NoLoader(AbstractRelationshipLoader): @relationships.Relationship.strategy_for(lazy="raise") @relationships.Relationship.strategy_for(lazy="raise_on_sql") @relationships.Relationship.strategy_for(lazy="baked_select") -class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): +class LazyLoader( + AbstractRelationshipLoader, util.MemoizedSlots, log.Identified +): """Provide loading behavior for a :class:`.Relationship` with "lazy=True", that is loads when first accessed. @@ -648,7 +660,16 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): "_raise_on_sql", ) - def __init__(self, parent, strategy_key): + _lazywhere: ColumnElement[bool] + _bind_to_col: Dict[str, ColumnElement[Any]] + _rev_lazywhere: ColumnElement[bool] + _rev_bind_to_col: Dict[str, ColumnElement[Any]] + + parent_property: Relationship[Any] + + def __init__( + self, parent: Relationship[Any], strategy_key: Tuple[Any, ...] + ): super(LazyLoader, self).__init__(parent, strategy_key) self._raise_always = self.strategy_opts["lazy"] == "raise" self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql" @@ -786,13 +807,13 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): o = state.obj() # strong ref dict_ = attributes.instance_dict(o) - if passive & attributes.INIT_OK: - passive ^= attributes.INIT_OK + if passive & PassiveFlag.INIT_OK: + passive ^= PassiveFlag.INIT_OK params = {} for key, ident, value in param_keys: if ident is not None: - if passive and passive & attributes.LOAD_AGAINST_COMMITTED: + if passive and passive & PassiveFlag.LOAD_AGAINST_COMMITTED: value = mapper._get_committed_state_attr_by_column( state, dict_, ident, passive ) @@ -818,23 +839,23 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) or not state.session_id ): - return attributes.ATTR_EMPTY + return LoaderCallableStatus.ATTR_EMPTY pending = not state.key primary_key_identity = None use_get = self.use_get and (not loadopt or not loadopt._extra_criteria) - if (not passive & attributes.SQL_OK and not use_get) or ( + if (not passive & PassiveFlag.SQL_OK and not use_get) or ( not passive & attributes.NON_PERSISTENT_OK and pending ): - return attributes.PASSIVE_NO_RESULT + return LoaderCallableStatus.PASSIVE_NO_RESULT if ( # we were given lazy="raise" self._raise_always # the no_raise history-related flag was not passed - and not passive & attributes.NO_RAISE + and not passive & PassiveFlag.NO_RAISE and ( # if we are use_get and related_object_ok is disabled, # which means we are at most looking in the identity map @@ -842,7 +863,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): # PASSIVE_NO_RESULT, don't raise. This is also a # history-related flag not use_get - or passive & attributes.RELATED_OBJECT_OK + or passive & PassiveFlag.RELATED_OBJECT_OK ) ): @@ -850,8 +871,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): session = _state_session(state) if not session: - if passive & attributes.NO_RAISE: - return attributes.PASSIVE_NO_RESULT + if passive & PassiveFlag.NO_RAISE: + return LoaderCallableStatus.PASSIVE_NO_RESULT raise orm_exc.DetachedInstanceError( "Parent instance %s is not bound to a Session; " @@ -865,19 +886,19 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): primary_key_identity = self._get_ident_for_use_get( session, state, passive ) - if attributes.PASSIVE_NO_RESULT in primary_key_identity: - return attributes.PASSIVE_NO_RESULT - elif attributes.NEVER_SET in primary_key_identity: - return attributes.NEVER_SET + if LoaderCallableStatus.PASSIVE_NO_RESULT in primary_key_identity: + return LoaderCallableStatus.PASSIVE_NO_RESULT + elif LoaderCallableStatus.NEVER_SET in primary_key_identity: + return LoaderCallableStatus.NEVER_SET if _none_set.issuperset(primary_key_identity): return None if ( self.key in state.dict - and not passive & attributes.DEFERRED_HISTORY_LOAD + and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD ): - return attributes.ATTR_WAS_SET + return LoaderCallableStatus.ATTR_WAS_SET # look for this identity in the identity map. Delegate to the # Query class in use, as it may have special rules for how it @@ -892,15 +913,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) if instance is not None: - if instance is attributes.PASSIVE_CLASS_MISMATCH: + if instance is LoaderCallableStatus.PASSIVE_CLASS_MISMATCH: return None else: return instance elif ( - not passive & attributes.SQL_OK - or not passive & attributes.RELATED_OBJECT_OK + not passive & PassiveFlag.SQL_OK + or not passive & PassiveFlag.RELATED_OBJECT_OK ): - return attributes.PASSIVE_NO_RESULT + return LoaderCallableStatus.PASSIVE_NO_RESULT return self._emit_lazyload( session, @@ -914,7 +935,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): def _get_ident_for_use_get(self, session, state, passive): instance_mapper = state.manager.mapper - if passive & attributes.LOAD_AGAINST_COMMITTED: + if passive & PassiveFlag.LOAD_AGAINST_COMMITTED: get_attr = instance_mapper._get_committed_state_attr_by_column else: get_attr = instance_mapper._get_state_attr_by_column @@ -985,7 +1006,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): stmt._compile_options += {"_current_path": effective_path} if use_get: - if self._raise_on_sql and not passive & attributes.NO_RAISE: + if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE: self._invoke_raise_load(state, passive, "raise_on_sql") return loading.load_on_pk_identity( @@ -1022,9 +1043,9 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): if ( self.key in state.dict - and not passive & attributes.DEFERRED_HISTORY_LOAD + and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD ): - return attributes.ATTR_WAS_SET + return LoaderCallableStatus.ATTR_WAS_SET if pending: if util.has_intersection(orm_util._none_set, params.values()): @@ -1033,7 +1054,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): elif util.has_intersection(orm_util._never_set, params.values()): return None - if self._raise_on_sql and not passive & attributes.NO_RAISE: + if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE: self._invoke_raise_load(state, passive, "raise_on_sql") stmt._where_criteria = (lazy_clause,) @@ -1246,9 +1267,9 @@ class ImmediateLoader(PostLoader): # "use get" load. the "_RELATED" part means it may return # instance even if its expired, since this is a mutually-recursive # load operation. - flags = attributes.PASSIVE_NO_FETCH_RELATED | attributes.NO_RAISE + flags = attributes.PASSIVE_NO_FETCH_RELATED | PassiveFlag.NO_RAISE else: - flags = attributes.PASSIVE_OFF | attributes.NO_RAISE + flags = attributes.PASSIVE_OFF | PassiveFlag.NO_RAISE populators["delayed"].append((self.key, load_immediate)) @@ -2840,7 +2861,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): # if the loaded parent objects do not have the foreign key # to the related item loaded, then degrade into the joined # version of selectinload - if attributes.PASSIVE_NO_RESULT in related_ident: + if LoaderCallableStatus.PASSIVE_NO_RESULT in related_ident: query_info = self._fallback_query_info break diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 63679dd275..7aed6dd7bb 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -3,6 +3,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls """ @@ -12,18 +13,30 @@ from __future__ import annotations import typing from typing import Any +from typing import Callable from typing import cast -from typing import Mapping -from typing import NoReturn +from typing import Dict +from typing import Iterable from typing import Optional +from typing import overload +from typing import Sequence from typing import Tuple +from typing import Type +from typing import TypeVar from typing import Union from . import util as orm_util +from ._typing import insp_is_aliased_class +from ._typing import insp_is_attribute +from ._typing import insp_is_mapper +from ._typing import insp_is_mapper_property +from .attributes import QueryableAttribute from .base import InspectionAttr from .interfaces import LoaderOption from .path_registry import _DEFAULT_TOKEN from .path_registry import _WILDCARD_TOKEN +from .path_registry import AbstractEntityRegistry +from .path_registry import path_is_property from .path_registry import PathRegistry from .path_registry import TokenRegistry from .util import _orm_full_deannotate @@ -38,14 +51,37 @@ from ..sql import roles from ..sql import traversals from ..sql import visitors from ..sql.base import _generative +from ..util.typing import Final +from ..util.typing import Literal -_RELATIONSHIP_TOKEN = "relationship" -_COLUMN_TOKEN = "column" +_RELATIONSHIP_TOKEN: Final[Literal["relationship"]] = "relationship" +_COLUMN_TOKEN: Final[Literal["column"]] = "column" + +_FN = TypeVar("_FN", bound="Callable[..., Any]") if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _InternalEntityType + from .context import _MapperEntity + from .context import ORMCompileState + from .context import QueryContext + from .interfaces import _StrategyKey + from .interfaces import MapperProperty from .mapper import Mapper + from .path_registry import _PathRepresentation + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _FromClauseArgument + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.cache_key import CacheKey + +Self_AbstractLoad = TypeVar("Self_AbstractLoad", bound="_AbstractLoad") + +_AttrType = Union[str, "QueryableAttribute[Any]"] -Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad") +_WildcardKeyType = Literal["relationship", "column"] +_StrategySpec = Dict[str, Any] +_OptsType = Dict[str, Any] +_AttrGroupType = Tuple[_AttrType, ...] class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @@ -54,7 +90,12 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): _is_strategy_option = True propagate_to_loaders: bool - def contains_eager(self, attr, alias=None, _is_chain=False): + def contains_eager( + self: Self_AbstractLoad, + attr: _AttrType, + alias: Optional[_FromClauseArgument] = None, + _is_chain: bool = False, + ) -> Self_AbstractLoad: r"""Indicate that the given attribute should be eagerly loaded from columns stated manually in the query. @@ -94,9 +135,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ if alias is not None: if not isinstance(alias, str): - info = inspect(alias) - alias = info.selectable - + coerced_alias = coercions.expect(roles.FromClauseRole, alias) else: util.warn_deprecated( "Passing a string name for the 'alias' argument to " @@ -105,21 +144,28 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): "sqlalchemy.orm.aliased() construct.", version="1.4", ) + coerced_alias = alias elif getattr(attr, "_of_type", None): - ot = inspect(attr._of_type) - alias = ot.selectable + assert isinstance(attr, QueryableAttribute) + ot: Optional[_InternalEntityType[Any]] = inspect(attr._of_type) + assert ot is not None + coerced_alias = ot.selectable + else: + coerced_alias = None cloned = self._set_relationship_strategy( attr, {"lazy": "joined"}, propagate_to_loaders=False, - opts={"eager_from_alias": alias}, + opts={"eager_from_alias": coerced_alias}, _reconcile_to_other=True if _is_chain else None, ) return cloned - def load_only(self, *attrs): + def load_only( + self: Self_AbstractLoad, *attrs: _AttrType + ) -> Self_AbstractLoad: """Indicate that for a particular entity, only the given list of column-based attribute names should be loaded; all others will be deferred. @@ -159,11 +205,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): {"deferred": False, "instrument": True}, ) cloned = cloned._set_column_strategy( - "*", {"deferred": True, "instrument": True}, {"undefer_pks": True} + ("*",), + {"deferred": True, "instrument": True}, + {"undefer_pks": True}, ) return cloned - def joinedload(self, attr, innerjoin=None): + def joinedload( + self: Self_AbstractLoad, + attr: _AttrType, + innerjoin: Optional[bool] = None, + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using joined eager loading. @@ -258,7 +310,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): ) return loader - def subqueryload(self, attr): + def subqueryload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using subquery eager loading. @@ -289,7 +343,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, {"lazy": "subquery"}) - def selectinload(self, attr): + def selectinload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using SELECT IN eager loading. @@ -321,7 +377,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, {"lazy": "selectin"}) - def lazyload(self, attr): + def lazyload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using "lazy" loading. @@ -337,7 +395,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, {"lazy": "select"}) - def immediateload(self, attr): + def immediateload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using an immediate load with a per-attribute SELECT statement. @@ -361,7 +421,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): loader = self._set_relationship_strategy(attr, {"lazy": "immediate"}) return loader - def noload(self, attr): + def noload(self: Self_AbstractLoad, attr: _AttrType) -> Self_AbstractLoad: """Indicate that the given relationship attribute should remain unloaded. @@ -387,7 +447,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return self._set_relationship_strategy(attr, {"lazy": "noload"}) - def raiseload(self, attr, sql_only=False): + def raiseload( + self: Self_AbstractLoad, attr: _AttrType, sql_only: bool = False + ) -> Self_AbstractLoad: """Indicate that the given attribute should raise an error if accessed. A relationship attribute configured with :func:`_orm.raiseload` will @@ -428,7 +490,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): attr, {"lazy": "raise_on_sql" if sql_only else "raise"} ) - def defaultload(self, attr): + def defaultload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate an attribute should load using its default loader style. This method is used to link to other loader options further into @@ -463,7 +527,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, None) - def defer(self, key, raiseload=False): + def defer( + self: Self_AbstractLoad, key: _AttrType, raiseload: bool = False + ) -> Self_AbstractLoad: r"""Indicate that the given column-oriented attribute should be deferred, e.g. not loaded until accessed. @@ -524,7 +590,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): strategy["raiseload"] = True return self._set_column_strategy((key,), strategy) - def undefer(self, key): + def undefer(self: Self_AbstractLoad, key: _AttrType) -> Self_AbstractLoad: r"""Indicate that the given column-oriented attribute should be undeferred, e.g. specified within the SELECT statement of the entity as a whole. @@ -538,7 +604,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): Examples:: # undefer two columns - session.query(MyClass).options(undefer("col1"), undefer("col2")) + session.query(MyClass).options( + undefer(MyClass.col1), undefer(MyClass.col2) + ) # undefer all columns specific to a single class using Load + * session.query(MyClass, MyOtherClass).options( @@ -546,7 +614,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): # undefer a column on a related object session.query(MyClass).options( - defaultload(MyClass.items).undefer('text')) + defaultload(MyClass.items).undefer(MyClass.text)) :param key: Attribute to be undeferred. @@ -563,7 +631,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): (key,), {"deferred": False, "instrument": True} ) - def undefer_group(self, name): + def undefer_group(self: Self_AbstractLoad, name: str) -> Self_AbstractLoad: """Indicate that columns within the given deferred group name should be undeferred. @@ -591,10 +659,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_column_strategy( - _WILDCARD_TOKEN, None, {f"undefer_group_{name}": True} + (_WILDCARD_TOKEN,), None, {f"undefer_group_{name}": True} ) - def with_expression(self, key, expression): + def with_expression( + self: Self_AbstractLoad, + key: _AttrType, + expression: _ColumnExpressionArgument[Any], + ) -> Self_AbstractLoad: r"""Apply an ad-hoc SQL expression to a "deferred expression" attribute. @@ -626,15 +698,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ - expression = coercions.expect( - roles.LabeledColumnExprRole, _orm_full_deannotate(expression) + expression = _orm_full_deannotate( + coercions.expect(roles.LabeledColumnExprRole, expression) ) return self._set_column_strategy( (key,), {"query_expression": True}, opts={"expression": expression} ) - def selectin_polymorphic(self, classes): + def selectin_polymorphic( + self: Self_AbstractLoad, classes: Iterable[Type[Any]] + ) -> Self_AbstractLoad: """Indicate an eager load should take place for all attributes specific to a subclass. @@ -659,25 +733,37 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): ) return self - def _coerce_strat(self, strategy): + @overload + def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: + ... + + @overload + def _coerce_strat(self, strategy: Literal[None]) -> None: + ... + + def _coerce_strat( + self, strategy: Optional[_StrategySpec] + ) -> Optional[_StrategyKey]: if strategy is not None: - strategy = tuple(sorted(strategy.items())) - return strategy + strategy_key = tuple(sorted(strategy.items())) + else: + strategy_key = None + return strategy_key @_generative def _set_relationship_strategy( self: Self_AbstractLoad, - attr, - strategy, - propagate_to_loaders=True, - opts=None, - _reconcile_to_other=None, + attr: _AttrType, + strategy: Optional[_StrategySpec], + propagate_to_loaders: bool = True, + opts: Optional[_OptsType] = None, + _reconcile_to_other: Optional[bool] = None, ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) self._clone_for_bind_strategy( (attr,), - strategy, + strategy_key, _RELATIONSHIP_TOKEN, opts=opts, propagate_to_loaders=propagate_to_loaders, @@ -687,13 +773,16 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @_generative def _set_column_strategy( - self: Self_AbstractLoad, attrs, strategy, opts=None + self: Self_AbstractLoad, + attrs: Tuple[_AttrType, ...], + strategy: Optional[_StrategySpec], + opts: Optional[_OptsType] = None, ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) self._clone_for_bind_strategy( attrs, - strategy, + strategy_key, _COLUMN_TOKEN, opts=opts, attr_group=attrs, @@ -702,12 +791,15 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @_generative def _set_generic_strategy( - self: Self_AbstractLoad, attrs, strategy, _reconcile_to_other=None + self: Self_AbstractLoad, + attrs: Tuple[_AttrType, ...], + strategy: _StrategySpec, + _reconcile_to_other: Optional[bool] = None, ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) self._clone_for_bind_strategy( attrs, - strategy, + strategy_key, None, propagate_to_loaders=True, reconcile_to_other=_reconcile_to_other, @@ -716,14 +808,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @_generative def _set_class_strategy( - self: Self_AbstractLoad, strategy, opts + self: Self_AbstractLoad, strategy: _StrategySpec, opts: _OptsType ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) - self._clone_for_bind_strategy(None, strategy, None, opts=opts) + self._clone_for_bind_strategy(None, strategy_key, None, opts=opts) return self - def _apply_to_parent(self, parent): + def _apply_to_parent(self, parent: Load) -> None: """apply this :class:`_orm._AbstractLoad` object as a sub-option o a :class:`_orm.Load` object. @@ -732,7 +824,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ raise NotImplementedError() - def options(self: Self_AbstractLoad, *opts) -> NoReturn: + def options( + self: Self_AbstractLoad, *opts: _AbstractLoad + ) -> Self_AbstractLoad: r"""Apply a series of options as sub-options to this :class:`_orm._AbstractLoad` object. @@ -742,20 +836,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): raise NotImplementedError() def _clone_for_bind_strategy( - self, - attrs, - strategy, - wildcard_key, - opts=None, - attr_group=None, - propagate_to_loaders=True, - reconcile_to_other=None, - ): + self: Self_AbstractLoad, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + ) -> Self_AbstractLoad: raise NotImplementedError() def process_compile_state_replaced_entities( - self, compile_state, mapper_entities - ): + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: if not compile_state.compile_options._enable_eagerloads: return @@ -768,7 +864,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): not bool(compile_state.current_path), ) - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState) -> None: if not compile_state.compile_options._enable_eagerloads: return @@ -779,12 +875,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): and not compile_state.compile_options._for_refresh_state, ) - def _process(self, compile_state, mapper_entities, raiseerr): + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: """implemented by subclasses""" raise NotImplementedError() @classmethod - def _chop_path(cls, to_chop, path, debug=False): + def _chop_path( + cls, + to_chop: _PathRepresentation, + path: PathRegistry, + debug: bool = False, + ) -> Optional[_PathRepresentation]: i = -1 for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)): @@ -793,7 +899,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return to_chop elif ( c_token != f"{_RELATIONSHIP_TOKEN}:{_WILDCARD_TOKEN}" - and c_token != p_token.key + and c_token != p_token.key # type: ignore ): return None @@ -801,9 +907,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): continue elif ( isinstance(c_token, InspectionAttr) - and c_token.is_mapper + and insp_is_mapper(c_token) and ( - (p_token.is_mapper and c_token.isa(p_token)) + (insp_is_mapper(p_token) and c_token.isa(p_token)) or ( # a too-liberal check here to allow a path like # A->A.bs->B->B.cs->C->C.ds, natural path, to chop @@ -827,10 +933,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): # test_of_type.py->test_all_subq_query # i >= 2 - and p_token.is_aliased_class + and insp_is_aliased_class(p_token) and p_token._is_with_polymorphic and c_token in p_token.with_polymorphic_mappers - # and (breakpoint() or True) ) ) ): @@ -841,7 +946,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return to_chop[i + 1 :] -SelfLoad = typing.TypeVar("SelfLoad", bound="Load") +SelfLoad = TypeVar("SelfLoad", bound="Load") class Load(_AbstractLoad): @@ -903,28 +1008,28 @@ class Load(_AbstractLoad): _cache_key_traversal = None path: PathRegistry - context: Tuple["_LoadElement", ...] + context: Tuple[_LoadElement, ...] - def __init__(self, entity): - insp = cast(Union["Mapper", AliasedInsp], inspect(entity)) + def __init__(self, entity: _EntityType[Any]): + insp = cast("Union[Mapper[Any], AliasedInsp[Any]]", inspect(entity)) insp._post_inspect self.path = insp._path_registry self.context = () self.propagate_to_loaders = False - def __str__(self): + def __str__(self) -> str: return f"Load({self.path[0]})" @classmethod - def _construct_for_existing_path(cls, path): + def _construct_for_existing_path(cls, path: PathRegistry) -> Load: load = cls.__new__(cls) load.path = path load.context = () load.propagate_to_loaders = False return load - def _adjust_for_extra_criteria(self, context): + def _adjust_for_extra_criteria(self, context: QueryContext) -> Load: """Apply the current bound parameters in a QueryContext to all occurrences "extra_criteria" stored within this ``Load`` object, returning a new instance of this ``Load`` object. @@ -932,10 +1037,10 @@ class Load(_AbstractLoad): """ orig_query = context.compile_state.select_statement - orig_cache_key = None - replacement_cache_key = None + orig_cache_key: Optional[CacheKey] = None + replacement_cache_key: Optional[CacheKey] = None - def process(opt): + def process(opt: _LoadElement) -> _LoadElement: if not opt._extra_criteria: return opt @@ -948,6 +1053,9 @@ class Load(_AbstractLoad): orig_cache_key = orig_query._generate_cache_key() replacement_cache_key = context.query._generate_cache_key() + assert orig_cache_key is not None + assert replacement_cache_key is not None + opt._extra_criteria = tuple( replacement_cache_key._apply_params_to_element( orig_cache_key, crit @@ -975,12 +1083,22 @@ class Load(_AbstractLoad): ezero = None for ent in mapper_entities: ezero = ent.entity_zero - if ezero and orm_util._entity_corresponds_to(ezero, path[0]): + if ezero and orm_util._entity_corresponds_to( + # technically this can be a token also, but this is + # safe to pass to _entity_corresponds_to() + ezero, + cast("_InternalEntityType[Any]", path[0]), + ): return ezero return None - def _process(self, compile_state, mapper_entities, raiseerr): + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: reconciled_lead_entity = self._reconcile_query_entities_with_us( mapper_entities, raiseerr @@ -995,7 +1113,7 @@ class Load(_AbstractLoad): raiseerr, ) - def _apply_to_parent(self, parent): + def _apply_to_parent(self, parent: Load) -> None: """apply this :class:`_orm.Load` object as a sub-option of another :class:`_orm.Load` object. @@ -1007,7 +1125,8 @@ class Load(_AbstractLoad): assert cloned.propagate_to_loaders == self.propagate_to_loaders if not orm_util._entity_corresponds_to_use_path_impl( - parent.path[-1], cloned.path[0] + cast("_InternalEntityType[Any]", parent.path[-1]), + cast("_InternalEntityType[Any]", cloned.path[0]), ): raise sa_exc.ArgumentError( f'Attribute "{cloned.path[1]}" does not link ' @@ -1025,7 +1144,7 @@ class Load(_AbstractLoad): parent.context += cloned.context @_generative - def options(self: SelfLoad, *opts) -> SelfLoad: + def options(self: SelfLoad, *opts: _AbstractLoad) -> SelfLoad: r"""Apply a series of options as sub-options to this :class:`_orm.Load` object. @@ -1062,38 +1181,36 @@ class Load(_AbstractLoad): return self def _clone_for_bind_strategy( - self, - attrs, - strategy, - wildcard_key, - opts=None, - attr_group=None, - propagate_to_loaders=True, - reconcile_to_other=None, - ) -> None: + self: SelfLoad, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + ) -> SelfLoad: # for individual strategy that needs to propagate, set the whole # Load container to also propagate, so that it shows up in # InstanceState.load_options if propagate_to_loaders: self.propagate_to_loaders = True - if not self.path.has_entity: - if self.path.is_token: + if self.path.is_token: + raise sa_exc.ArgumentError( + "Wildcard token cannot be followed by another entity" + ) + + elif path_is_property(self.path): + # re-use the lookup which will raise a nicely formatted + # LoaderStrategyException + if strategy: + self.path.prop._strategy_lookup(self.path.prop, strategy[0]) + else: raise sa_exc.ArgumentError( - "Wildcard token cannot be followed by another entity" + f"Mapped attribute '{self.path.prop}' does not " + "refer to a mapped entity" ) - else: - # re-use the lookup which will raise a nicely formatted - # LoaderStrategyException - if strategy: - self.path.prop._strategy_lookup( - self.path.prop, strategy[0] - ) - else: - raise sa_exc.ArgumentError( - f"Mapped attribute '{self.path.prop}' does not " - "refer to a mapped entity" - ) if attrs is None: load_element = _ClassStrategyLoad.create( @@ -1140,6 +1257,7 @@ class Load(_AbstractLoad): if wildcard_key is _RELATIONSHIP_TOKEN: self.path = load_element.path self.context += (load_element,) + return self def __getstate__(self): d = self._shallow_to_dict() @@ -1151,7 +1269,7 @@ class Load(_AbstractLoad): self._shallow_from_dict(state) -SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad") +SelfWildcardLoad = TypeVar("SelfWildcardLoad", bound="_WildcardLoad") class _WildcardLoad(_AbstractLoad): @@ -1167,14 +1285,14 @@ class _WildcardLoad(_AbstractLoad): visitors.ExtendedInternalTraversal.dp_string_multi_dict, ), ] - cache_key_traversal = None + cache_key_traversal: _CacheKeyTraversalType = None strategy: Optional[Tuple[Any, ...]] - local_opts: Mapping[str, Any] + local_opts: _OptsType path: Tuple[str, ...] propagate_to_loaders = False - def __init__(self): + def __init__(self) -> None: self.path = () self.strategy = None self.local_opts = util.EMPTY_DICT @@ -1189,6 +1307,7 @@ class _WildcardLoad(_AbstractLoad): propagate_to_loaders=True, reconcile_to_other=None, ): + assert attrs is not None attr = attrs[0] assert ( wildcard_key @@ -1203,10 +1322,12 @@ class _WildcardLoad(_AbstractLoad): if opts: self.local_opts = util.immutabledict(opts) - def options(self: SelfWildcardLoad, *opts) -> SelfWildcardLoad: + def options( + self: SelfWildcardLoad, *opts: _AbstractLoad + ) -> SelfWildcardLoad: raise NotImplementedError("Star option does not support sub-options") - def _apply_to_parent(self, parent): + def _apply_to_parent(self, parent: Load) -> None: """apply this :class:`_orm._WildcardLoad` object as a sub-option of a :class:`_orm.Load` object. @@ -1215,12 +1336,11 @@ class _WildcardLoad(_AbstractLoad): it may be used as the sub-option of a :class:`_orm.Load` object. """ - attr = self.path[0] if attr.endswith(_DEFAULT_TOKEN): attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}" - effective_path = parent.path.token(attr) + effective_path = cast(AbstractEntityRegistry, parent.path).token(attr) assert effective_path.is_token @@ -1244,20 +1364,21 @@ class _WildcardLoad(_AbstractLoad): entities = [ent.entity_zero for ent in mapper_entities] current_path = compile_state.current_path - start_path = self.path + start_path: _PathRepresentation = self.path # TODO: chop_path already occurs in loader.process_compile_state() # so we will seek to simplify this if current_path: - start_path = self._chop_path(start_path, current_path) - if not start_path: + new_path = self._chop_path(start_path, current_path) + if not new_path: return + start_path = new_path # start_path is a single-token tuple assert start_path and len(start_path) == 1 token = start_path[0] - + assert isinstance(token, str) entity = self._find_entity_basestring(entities, token, raiseerr) if not entity: @@ -1270,6 +1391,7 @@ class _WildcardLoad(_AbstractLoad): # we just located, then go through the rest of our path # tokens and populate into the Load(). + assert isinstance(token, str) loader = _TokenStrategyLoad.create( path_element._path_registry, token, @@ -1291,7 +1413,12 @@ class _WildcardLoad(_AbstractLoad): return loader - def _find_entity_basestring(self, entities, token, raiseerr): + def _find_entity_basestring( + self, + entities: Iterable[_InternalEntityType[Any]], + token: str, + raiseerr: bool, + ) -> Optional[_InternalEntityType[Any]]: if token.endswith(f":{_WILDCARD_TOKEN}"): if len(list(entities)) != 1: if raiseerr: @@ -1324,11 +1451,11 @@ class _WildcardLoad(_AbstractLoad): else: return None - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: d = self._shallow_to_dict() return d - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self._shallow_from_dict(state) @@ -1372,38 +1499,38 @@ class _LoadElement( _extra_criteria: Tuple[Any, ...] _reconcile_to_other: Optional[bool] - strategy: Tuple[Any, ...] + strategy: Optional[_StrategyKey] path: PathRegistry propagate_to_loaders: bool - local_opts: Mapping[str, Any] + local_opts: util.immutabledict[str, Any] is_token_strategy: bool is_class_strategy: bool - def __hash__(self): + def __hash__(self) -> int: return id(self) def __eq__(self, other): return traversals.compare(self, other) @property - def is_opts_only(self): + def is_opts_only(self) -> bool: return bool(self.local_opts and self.strategy is None) - def _clone(self): + def _clone(self, **kw: Any) -> _LoadElement: cls = self.__class__ s = cls.__new__(cls) self._shallow_copy_to(s) return s - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: d = self._shallow_to_dict() d["path"] = self.path.serialize() return d - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: state["path"] = PathRegistry.deserialize(state["path"]) self._shallow_from_dict(state) @@ -1437,8 +1564,8 @@ class _LoadElement( ) def _adjust_effective_path_for_current_path( - self, effective_path, current_path - ): + self, effective_path: PathRegistry, current_path: PathRegistry + ) -> Optional[PathRegistry]: """receives the 'current_path' entry from an :class:`.ORMCompileState` instance, which is set during lazy loads and secondary loader strategy loads, and adjusts the given path to be relative to the @@ -1456,7 +1583,7 @@ class _LoadElement( """ - chopped_start_path = Load._chop_path(effective_path, current_path) + chopped_start_path = Load._chop_path(effective_path.path, current_path) if not chopped_start_path: return None @@ -1523,16 +1650,16 @@ class _LoadElement( @classmethod def create( cls, - path, - attr, - strategy, - wildcard_key, - local_opts, - propagate_to_loaders, - raiseerr=True, - attr_group=None, - reconcile_to_other=None, - ): + path: PathRegistry, + attr: Optional[_AttrType], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + local_opts: Optional[_OptsType], + propagate_to_loaders: bool, + raiseerr: bool = True, + attr_group: Optional[_AttrGroupType] = None, + reconcile_to_other: Optional[bool] = None, + ) -> _LoadElement: """Create a new :class:`._LoadElement` object.""" opt = cls.__new__(cls) @@ -1554,14 +1681,14 @@ class _LoadElement( path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr) if not path: - return None + return None # type: ignore assert opt.is_token_strategy == path.is_token opt.path = path return opt - def __init__(self, path, strategy, local_opts, propagate_to_loaders): + def __init__(self) -> None: raise NotImplementedError() def _prepend_path_from(self, parent): @@ -1580,7 +1707,8 @@ class _LoadElement( assert cloned.is_class_strategy == self.is_class_strategy if not orm_util._entity_corresponds_to_use_path_impl( - parent.path[-1], cloned.path[0] + cast("_InternalEntityType[Any]", parent.path[-1]), + cast("_InternalEntityType[Any]", cloned.path[0]), ): raise sa_exc.ArgumentError( f'Attribute "{cloned.path[1]}" does not link ' @@ -1592,7 +1720,9 @@ class _LoadElement( return cloned @staticmethod - def _reconcile(replacement, existing): + def _reconcile( + replacement: _LoadElement, existing: _LoadElement + ) -> _LoadElement: """define behavior for when two Load objects are to be put into the context.attributes under the same key. @@ -1670,7 +1800,7 @@ class _AttributeStrategyLoad(_LoadElement): ), ] - _of_type: Union["Mapper", AliasedInsp, None] + _of_type: Union["Mapper[Any]", "AliasedInsp[Any]", None] _path_with_polymorphic_path: Optional[PathRegistry] is_class_strategy = False @@ -1812,7 +1942,7 @@ class _AttributeStrategyLoad(_LoadElement): pwpi = inspect( orm_util.AliasedInsp._with_polymorphic_factory( pwpi.mapper.base_mapper, - pwpi.mapper, + (pwpi.mapper,), aliased=True, _use_mapper_path=True, ) @@ -1820,11 +1950,12 @@ class _AttributeStrategyLoad(_LoadElement): start_path = self._path_with_polymorphic_path if current_path: - start_path = self._adjust_effective_path_for_current_path( + new_path = self._adjust_effective_path_for_current_path( start_path, current_path ) - if start_path is None: + if new_path is None: return + start_path = new_path key = ("path_with_polymorphic", start_path.natural_path) if key in context: @@ -1872,6 +2003,7 @@ class _AttributeStrategyLoad(_LoadElement): effective_path = self.path if current_path: + assert effective_path is not None effective_path = self._adjust_effective_path_for_current_path( effective_path, current_path ) @@ -1985,11 +2117,12 @@ class _TokenStrategyLoad(_LoadElement): ) if current_path: - effective_path = self._adjust_effective_path_for_current_path( + new_effective_path = self._adjust_effective_path_for_current_path( effective_path, current_path ) - if effective_path is None: + if new_effective_path is None: return [] + effective_path = new_effective_path # for a wildcard token, expand out the path we set # to encompass everything from the query entity on @@ -2048,19 +2181,25 @@ class _ClassStrategyLoad(_LoadElement): effective_path = self.path if current_path: - effective_path = self._adjust_effective_path_for_current_path( + new_effective_path = self._adjust_effective_path_for_current_path( effective_path, current_path ) - if effective_path is None: + if new_effective_path is None: return [] + effective_path = new_effective_path - return [("loader", cast(PathRegistry, effective_path).natural_path)] + return [("loader", effective_path.natural_path)] -def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad: - - lead_element = None +def _generate_from_keys( + meth: Callable[..., _AbstractLoad], + keys: Tuple[_AttrType, ...], + chained: bool, + kw: Any, +) -> _AbstractLoad: + lead_element: Optional[_AbstractLoad] = None + attr: Any for is_default, _keys in (True, keys[0:-1]), (False, keys[-1:]): for attr in _keys: if isinstance(attr, str): @@ -2116,7 +2255,9 @@ def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad: return lead_element -def _parse_attr_argument(attr): +def _parse_attr_argument( + attr: _AttrType, +) -> Tuple[InspectionAttr, _InternalEntityType[Any], MapperProperty[Any]]: """parse an attribute or wildcard argument to produce an :class:`._AbstractLoad` instance. @@ -2126,16 +2267,21 @@ def _parse_attr_argument(attr): """ try: - insp = inspect(attr) + # TODO: need to figure out this None thing being returned by + # inspect(), it should not have None as an option in most cases + # if at all + insp: InspectionAttr = inspect(attr) # type: ignore except sa_exc.NoInspectionAvailable as err: raise sa_exc.ArgumentError( "expected ORM mapped attribute for loader strategy argument" ) from err - if insp.is_property: + lead_entity: _InternalEntityType[Any] + + if insp_is_mapper_property(insp): lead_entity = insp.parent prop = insp - elif insp.is_attribute: + elif insp_is_attribute(insp): lead_entity = insp.parent prop = insp.prop else: @@ -2146,7 +2292,7 @@ def _parse_attr_argument(attr): return insp, lead_entity, prop -def loader_unbound_fn(fn): +def loader_unbound_fn(fn: _FN) -> _FN: """decorator that applies docstrings between standalone loader functions and the loader methods on :class:`._AbstractLoad`. @@ -2169,12 +2315,12 @@ See :func:`_orm.{fn.__name__}` for usage examples. @loader_unbound_fn -def contains_eager(*keys, **kw) -> _AbstractLoad: +def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad: return _generate_from_keys(Load.contains_eager, keys, True, kw) @loader_unbound_fn -def load_only(*attrs) -> _AbstractLoad: +def load_only(*attrs: _AttrType) -> _AbstractLoad: # TODO: attrs against different classes. we likely have to # add some extra state to Load of some kind _, lead_element, _ = _parse_attr_argument(attrs[0]) @@ -2182,47 +2328,47 @@ def load_only(*attrs) -> _AbstractLoad: @loader_unbound_fn -def joinedload(*keys, **kw) -> _AbstractLoad: +def joinedload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: return _generate_from_keys(Load.joinedload, keys, False, kw) @loader_unbound_fn -def subqueryload(*keys) -> _AbstractLoad: +def subqueryload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.subqueryload, keys, False, {}) @loader_unbound_fn -def selectinload(*keys) -> _AbstractLoad: +def selectinload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.selectinload, keys, False, {}) @loader_unbound_fn -def lazyload(*keys) -> _AbstractLoad: +def lazyload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.lazyload, keys, False, {}) @loader_unbound_fn -def immediateload(*keys) -> _AbstractLoad: +def immediateload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.immediateload, keys, False, {}) @loader_unbound_fn -def noload(*keys) -> _AbstractLoad: +def noload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.noload, keys, False, {}) @loader_unbound_fn -def raiseload(*keys, **kw) -> _AbstractLoad: +def raiseload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: return _generate_from_keys(Load.raiseload, keys, False, kw) @loader_unbound_fn -def defaultload(*keys) -> _AbstractLoad: +def defaultload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.defaultload, keys, False, {}) @loader_unbound_fn -def defer(key, *addl_attrs, **kw) -> _AbstractLoad: +def defer(key: _AttrType, *addl_attrs: _AttrType, **kw: Any) -> _AbstractLoad: if addl_attrs: util.warn_deprecated( "The *addl_attrs on orm.defer is deprecated. Please use " @@ -2234,7 +2380,7 @@ def defer(key, *addl_attrs, **kw) -> _AbstractLoad: @loader_unbound_fn -def undefer(key, *addl_attrs) -> _AbstractLoad: +def undefer(key: _AttrType, *addl_attrs: _AttrType) -> _AbstractLoad: if addl_attrs: util.warn_deprecated( "The *addl_attrs on orm.undefer is deprecated. Please use " @@ -2246,19 +2392,23 @@ def undefer(key, *addl_attrs) -> _AbstractLoad: @loader_unbound_fn -def undefer_group(name) -> _AbstractLoad: +def undefer_group(name: str) -> _AbstractLoad: element = _WildcardLoad() return element.undefer_group(name) @loader_unbound_fn -def with_expression(key, expression) -> _AbstractLoad: +def with_expression( + key: _AttrType, expression: _ColumnExpressionArgument[Any] +) -> _AbstractLoad: return _generate_from_keys( Load.with_expression, (key,), False, {"expression": expression} ) @loader_unbound_fn -def selectin_polymorphic(base_cls, classes) -> _AbstractLoad: +def selectin_polymorphic( + base_cls: _EntityType[Any], classes: Iterable[Type[Any]] +) -> _AbstractLoad: ul = Load(base_cls) return ul.selectin_polymorphic(classes) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 4f63e241ba..4f1eeb39b6 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +# mypy: allow-untyped-defs, allow-untyped-calls """private module containing functions used for copying data @@ -14,9 +14,9 @@ between instances based on join conditions. from __future__ import annotations -from . import attributes from . import exc from . import util as orm_util +from .base import PassiveFlag def populate( @@ -36,7 +36,7 @@ def populate( # inline of source_mapper._get_state_attr_by_column prop = source_mapper._columntoproperty[l] value = source.manager[prop.key].impl.get( - source, source_dict, attributes.PASSIVE_OFF + source, source_dict, PassiveFlag.PASSIVE_OFF ) except exc.UnmappedColumnError as err: _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err) @@ -74,8 +74,8 @@ def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs): try: prop = source_mapper._columntoproperty[r] source_dict[prop.key] = value - except exc.UnmappedColumnError: - _raise_col_to_prop(True, source_mapper, l, source_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err) def clear(dest, dest_mapper, synchronize_pairs): @@ -103,7 +103,7 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs): source.obj(), l ) value = source_mapper._get_state_attr_by_column( - source, source.dict, l, passive=attributes.PASSIVE_OFF + source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF ) except exc.UnmappedColumnError as err: _raise_col_to_prop(False, source_mapper, l, None, r, err) @@ -115,7 +115,7 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs): for l, r in synchronize_pairs: try: value = source_mapper._get_state_attr_by_column( - source, source.dict, l, passive=attributes.PASSIVE_OFF + source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF ) except exc.UnmappedColumnError as err: _raise_col_to_prop(False, source_mapper, l, None, r, err) @@ -134,7 +134,7 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs): except exc.UnmappedColumnError as err: _raise_col_to_prop(False, source_mapper, l, None, r, err) history = uowcommit.get_attribute_history( - source, prop.key, attributes.PASSIVE_NO_INITIALIZE + source, prop.key, PassiveFlag.PASSIVE_NO_INITIALIZE ) if bool(history.deleted): return True diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4da0b77737..c50cc5bac8 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -12,6 +12,7 @@ import re import types import typing from typing import Any +from typing import Callable from typing import cast from typing import Dict from typing import FrozenSet @@ -82,24 +83,29 @@ if typing.TYPE_CHECKING: from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InternalEntityType - from ._typing import _ORMColumnExprArgument + from ._typing import _ORMCOLEXPR from .context import _MapperEntity from .context import ORMCompileState from .mapper import Mapper + from .query import Query from .relationships import Relationship from ..engine import Row from ..engine import RowMapping + from ..sql._typing import _CE from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _EquivalentColumnMap from ..sql._typing import _FromClauseArgument from ..sql._typing import _OnClauseArgument from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _SA from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import BindParameter from ..sql.selectable import _ColumnsClauseElement from ..sql.selectable import Alias + from ..sql.selectable import Select from ..sql.selectable import Subquery from ..sql.visitors import anon_map + from ..util.typing import _AnnotationScanType _T = TypeVar("_T", bound=Any) @@ -144,9 +150,11 @@ class CascadeOptions(FrozenSet[str]): expunge: bool delete_orphan: bool - def __new__(cls, value_list): + def __new__( + cls, value_list: Optional[Union[Iterable[str], str]] + ) -> CascadeOptions: if isinstance(value_list, str) or value_list is None: - return cls.from_string(value_list) + return cls.from_string(value_list) # type: ignore values = set(value_list) if values.difference(cls._allowed_cascades): raise sa_exc.ArgumentError( @@ -864,7 +872,7 @@ class AliasedInsp( def _with_polymorphic_factory( cls, base: Union[_O, Mapper[_O]], - classes: Iterable[Type[Any]], + classes: Iterable[_EntityType[Any]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, polymorphic_on: Optional[ColumnElement[Any]] = None, @@ -1011,23 +1019,40 @@ class AliasedInsp( )._aliased_insp def _adapt_element( - self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None - ) -> _ORMColumnExprArgument[_T]: - assert isinstance(elem, ColumnElement) + self, expr: _ORMCOLEXPR, key: Optional[str] = None + ) -> _ORMCOLEXPR: + assert isinstance(expr, ColumnElement) d: Dict[str, Any] = { "parententity": self, "parentmapper": self.mapper, } if key: d["proxy_key"] = key + + # IMO mypy should see this one also as returning the same type + # we put into it, but it's not return ( - self._adapter.traverse(elem) + self._adapter.traverse(expr) # type: ignore ._annotate(d) ._set_propagate_attrs( {"compile_state_plugin": "orm", "plugin_subject": self} ) ) + if TYPE_CHECKING: + # establish compatibility with the _ORMAdapterProto protocol, + # which in turn is compatible with _CoreAdapterProto. + + def _orm_adapt_element( + self, + obj: _CE, + key: Optional[str] = None, + ) -> _CE: + ... + + else: + _orm_adapt_element = _adapt_element + def _entity_for_mapper(self, mapper): self_poly = self.with_polymorphic_mappers if mapper in self_poly: @@ -1469,7 +1494,12 @@ class Bundle( cloned.name = name return cloned - def create_row_processor(self, query, procs, labels): + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: """Produce the "row processing" function for this :class:`.Bundle`. May be overridden by subclasses. @@ -1481,13 +1511,13 @@ class Bundle( """ keyed_tuple = result_tuple(labels, [() for l in labels]) - def proc(row): + def proc(row: Row[Any]) -> Any: return keyed_tuple([proc(row) for proc in procs]) return proc -def _orm_annotate(element, exclude=None): +def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA: """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag. @@ -1497,7 +1527,7 @@ def _orm_annotate(element, exclude=None): return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude) -def _orm_deannotate(element): +def _orm_deannotate(element: _SA) -> _SA: """Remove annotations that link a column to a particular mapping. Note this doesn't affect "remote" and "foreign" annotations @@ -1511,7 +1541,7 @@ def _orm_deannotate(element): ) -def _orm_full_deannotate(element): +def _orm_full_deannotate(element: _SA) -> _SA: return sql_util._deep_deannotate(element) @@ -1560,13 +1590,15 @@ class _ORMJoin(expression.Join): on_selectable = prop.parent.selectable else: prop = None + on_selectable = None if prop: left_selectable = left_info.selectable - + adapt_from: Optional[FromClause] if sql_util.clause_is_present(on_selectable, left_selectable): adapt_from = on_selectable else: + assert isinstance(left_selectable, FromClause) adapt_from = left_selectable ( @@ -1855,7 +1887,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: return given.isa(mapper) -def _getitem(iterable_query, item): +def _getitem(iterable_query: Query[Any], item: Any) -> Any: """calculate __getitem__ in terms of an iterable query object that also has a slice() method. @@ -1881,17 +1913,15 @@ def _getitem(iterable_query, item): isinstance(stop, int) and stop < 0 ): _no_negative_indexes() - return list(iterable_query)[item] res = iterable_query.slice(start, stop) if step is not None: - return list(res)[None : None : item.step] + return list(res)[None : None : item.step] # type: ignore else: - return list(res) + return list(res) # type: ignore else: if item == -1: _no_negative_indexes() - return list(iterable_query)[-1] else: return list(iterable_query[item : item + 1])[0] @@ -1933,7 +1963,7 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str: def _extract_mapped_subtype( - raw_annotation: Union[type, str], + raw_annotation: Optional[_AnnotationScanType], cls: type, key: str, attr_cls: Type[Any], diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index f49a6d3ec5..ed1bd28322 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -61,6 +61,9 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) +_CE = TypeVar("_CE", bound="ColumnElement[Any]") + + class _HasClauseElement(Protocol): """indicates a class that has a __clause_element__() method""" @@ -68,6 +71,13 @@ class _HasClauseElement(Protocol): ... +class _CoreAdapterProto(Protocol): + """protocol for the ClauseAdapter/ColumnAdapter.traverse() method.""" + + def __call__(self, obj: _CE) -> _CE: + ... + + # match column types that are not ORM entities _NOT_ENTITY = TypeVar( "_NOT_ENTITY", diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index fa36c09fcf..56d88bc2fb 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -454,9 +454,23 @@ def _deep_annotate( return element +@overload +def _deep_deannotate( + element: Literal[None], values: Optional[Sequence[str]] = None +) -> Literal[None]: + ... + + +@overload def _deep_deannotate( element: _SA, values: Optional[Sequence[str]] = None ) -> _SA: + ... + + +def _deep_deannotate( + element: Optional[_SA], values: Optional[Sequence[str]] = None +) -> Optional[_SA]: """Deep copy the given element, removing annotations.""" cloned: Dict[Any, SupportsAnnotations] = {} @@ -482,9 +496,7 @@ def _deep_deannotate( return element -def _shallow_annotate( - element: SupportsAnnotations, annotations: _AnnotationDict -) -> SupportsAnnotations: +def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA: """Annotate the given ClauseElement and copy its internals so that internal objects refer to the new annotated object. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 248b48a250..f5a9c10c0c 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -750,6 +750,17 @@ class _MetaOptions(type): o1.__dict__.update(other) return o1 + if TYPE_CHECKING: + + def __getattr__(self, key: str) -> Any: + ... + + def __setattr__(self, key: str, value: Any) -> None: + ... + + def __delattr__(self, key: str) -> None: + ... + class Options(metaclass=_MetaOptions): """A cacheable option dictionary with defaults.""" @@ -904,6 +915,17 @@ class Options(metaclass=_MetaOptions): else: return existing_options, exec_options + if TYPE_CHECKING: + + def __getattr__(self, key: str) -> Any: + ... + + def __setattr__(self, key: str, value: Any) -> None: + ... + + def __delattr__(self, key: str) -> None: + ... + class CacheableOptions(Options, HasCacheKey): __slots__ = () diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index eef5cf211e..501188b127 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -56,6 +56,7 @@ if typing.TYPE_CHECKING: from .elements import ColumnClause from .elements import ColumnElement from .elements import DQLDMLClauseElement + from .elements import NamedColumn from .elements import SQLCoreOperations from .schema import Column from .selectable import _ColumnsClauseElement @@ -197,6 +198,15 @@ def expect( ... +@overload +def expect( + role: Type[roles.LabeledColumnExprRole[Any]], + element: _ColumnExpressionArgument[_T], + **kw: Any, +) -> NamedColumn[_T]: + ... + + @overload def expect( role: Union[ @@ -217,6 +227,7 @@ def expect( Type[roles.LimitOffsetRole], Type[roles.WhereHavingRole], Type[roles.OnClauseRole], + Type[roles.ColumnArgumentRole], ], element: Any, **kw: Any, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 41b7f6392e..61c5379d8a 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -503,7 +503,7 @@ class ClauseElement( def params( self: SelfClauseElement, - __optionaldict: Optional[Dict[str, Any]] = None, + __optionaldict: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> SelfClauseElement: """Return a copy with :func:`_expression.bindparam` elements @@ -525,7 +525,7 @@ class ClauseElement( def _replace_params( self: SelfClauseElement, unique: bool, - optionaldict: Optional[Dict[str, Any]], + optionaldict: Optional[Mapping[str, Any]], kwargs: Dict[str, Any], ) -> SelfClauseElement: @@ -545,7 +545,7 @@ class ClauseElement( {"bindparam": visit_bindparam}, ) - def compare(self, other, **kw): + def compare(self, other: ClauseElement, **kw: Any) -> bool: r"""Compare this :class:`_expression.ClauseElement` to the given :class:`_expression.ClauseElement`. @@ -2516,7 +2516,9 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): return False_._singleton @classmethod - def _ifnone(cls, other): + def _ifnone( + cls, other: Optional[ColumnElement[Any]] + ) -> ColumnElement[Any]: if other is None: return cls._instance() else: @@ -4226,7 +4228,13 @@ class NamedColumn(KeyedColumnElement[_T]): ) -> Optional[str]: return name - def _bind_param(self, operator, obj, type_=None, expanding=False): + def _bind_param( + self, + operator: OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + ) -> BindParameter[_T]: return BindParameter( self.key, obj, diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d0b0f14761..fd98f17e32 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -64,6 +64,7 @@ from .base import _EntityNamespace from .base import _expand_cloned from .base import _from_objects from .base import _generative +from .base import _NoArg from .base import _select_iterables from .base import CacheableOptions from .base import ColumnCollection @@ -131,6 +132,7 @@ if TYPE_CHECKING: from .dml import Insert from .dml import Update from .elements import KeyedColumnElement + from .elements import Label from .elements import NamedColumn from .elements import TextClause from .functions import Function @@ -212,7 +214,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): """ raise NotImplementedError() - def is_derived_from(self, fromclause: FromClause) -> bool: + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: """Return ``True`` if this :class:`.ReturnsRows` is 'derived' from the given :class:`.FromClause`. @@ -778,7 +780,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return TableSample._construct(self, sampling, name, seed) - def is_derived_from(self, fromclause: FromClause) -> bool: + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: """Return ``True`` if this :class:`_expression.FromClause` is 'derived' from the given ``FromClause``. @@ -1128,11 +1130,14 @@ class SelectLabelStyle(Enum): """ + LABEL_STYLE_LEGACY_ORM = 3 + ( LABEL_STYLE_NONE, LABEL_STYLE_TABLENAME_PLUS_COL, LABEL_STYLE_DISAMBIGUATE_ONLY, + _, ) = list(SelectLabelStyle) LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY @@ -1231,7 +1236,7 @@ class Join(roles.DMLTableRole, FromClause): id(self.right), ) - def is_derived_from(self, fromclause: FromClause) -> bool: + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: return ( # use hash() to ensure direct comparison to annotated works # as well @@ -1635,7 +1640,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): """Legacy for dialects that are referring to Alias.original.""" return self.element - def is_derived_from(self, fromclause: FromClause) -> bool: + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: if fromclause in self._cloned_set: return True return self.element.is_derived_from(fromclause) @@ -2840,7 +2845,7 @@ class FromGrouping(GroupedElement, FromClause): def foreign_keys(self): return self.element.foreign_keys - def is_derived_from(self, fromclause: FromClause) -> bool: + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: return self.element.is_derived_from(fromclause) def alias( @@ -3080,11 +3085,17 @@ class ForUpdateArg(ClauseElement): def __init__( self, - nowait=False, - read=False, - of=None, - skip_locked=False, - key_share=False, + *, + nowait: bool = False, + read: bool = False, + of: Optional[ + Union[ + _ColumnExpressionArgument[Any], + Sequence[_ColumnExpressionArgument[Any]], + ] + ] = None, + skip_locked: bool = False, + key_share: bool = False, ): """Represents arguments specified to :meth:`_expression.Select.for_update`. @@ -3455,7 +3466,7 @@ class SelectBase( return ScalarSelect(self) - def label(self, name): + def label(self, name: Optional[str]) -> Label[Any]: """Return a 'scalar' representation of this selectable, embedded as a subquery with a label. @@ -3667,6 +3678,7 @@ class GenerativeSelect(SelectBase, Generative): @_generative def with_for_update( self: SelfGenerativeSelect, + *, nowait: bool = False, read: bool = False, of: Optional[ @@ -4064,7 +4076,11 @@ class GenerativeSelect(SelectBase, Generative): @_generative def order_by( - self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + self: SelfGenerativeSelect, + __first: Union[ + Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionArgument[Any], ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of ORDER BY criteria applied. @@ -4092,18 +4108,22 @@ class GenerativeSelect(SelectBase, Generative): """ - if len(clauses) == 1 and clauses[0] is None: + if not clauses and __first is None: self._order_by_clauses = () - else: + elif __first is not _NoArg.NO_ARG: self._order_by_clauses += tuple( coercions.expect(roles.OrderByRole, clause) - for clause in clauses + for clause in (__first,) + clauses ) return self @_generative def group_by( - self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + self: SelfGenerativeSelect, + __first: Union[ + Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionArgument[Any], ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of GROUP BY criterion applied. @@ -4128,12 +4148,12 @@ class GenerativeSelect(SelectBase, Generative): """ - if len(clauses) == 1 and clauses[0] is None: + if not clauses and __first is None: self._group_by_clauses = () - else: + elif __first is not _NoArg.NO_ARG: self._group_by_clauses += tuple( coercions.expect(roles.GroupByRole, clause) - for clause in clauses + for clause in (__first,) + clauses ) return self @@ -4257,7 +4277,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): ) -> GroupedElement: return SelectStatementGrouping(self) - def is_derived_from(self, fromclause: FromClause) -> bool: + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: for s in self.selects: if s.is_derived_from(fromclause): return True @@ -4959,7 +4979,7 @@ class Select( _raw_columns: List[_ColumnsClauseElement] - _distinct = False + _distinct: bool = False _distinct_on: Tuple[ColumnElement[Any], ...] = () _correlate: Tuple[FromClause, ...] = () _correlate_except: Optional[Tuple[FromClause, ...]] = None @@ -5478,8 +5498,8 @@ class Select( return iter(self._all_selected_columns) - def is_derived_from(self, fromclause: FromClause) -> bool: - if self in fromclause._cloned_set: + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: + if fromclause is not None and self in fromclause._cloned_set: return True for f in self._iterate_from_elements(): diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index aceed99a5d..94e635740e 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -19,6 +19,7 @@ from typing import Callable from typing import Deque from typing import Dict from typing import Iterable +from typing import Optional from typing import Set from typing import Tuple from typing import Type @@ -39,7 +40,7 @@ COMPARE_FAILED = False COMPARE_SUCCEEDED = True -def compare(obj1, obj2, **kw): +def compare(obj1: Any, obj2: Any, **kw: Any) -> bool: strategy: TraversalComparatorStrategy if kw.get("use_proxies", False): strategy = ColIdentityComparatorStrategy() @@ -49,7 +50,7 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) -def _preconfigure_traversals(target_hierarchy): +def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None: for cls in util.walk_subclasses(target_hierarchy): if hasattr(cls, "_generate_cache_attrs") and hasattr( cls, "_traverse_internals" @@ -482,14 +483,22 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): def __init__(self): self.stack: Deque[ - Tuple[ExternallyTraversible, ExternallyTraversible] + Tuple[ + Optional[ExternallyTraversible], + Optional[ExternallyTraversible], + ] ] = deque() self.cache = set() def _memoized_attr_anon_map(self): return (anon_map(), anon_map()) - def compare(self, obj1, obj2, **kw): + def compare( + self, + obj1: ExternallyTraversible, + obj2: ExternallyTraversible, + **kw: Any, + ) -> bool: stack = self.stack cache = self.cache @@ -551,6 +560,10 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): elif left_attrname in attributes_compared: continue + assert left_visit_sym is not None + assert left_attrname is not None + assert right_attrname is not None + dispatch = self.dispatch(left_visit_sym) assert dispatch, ( f"{self.__class__} has no dispatch for " @@ -595,6 +608,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in zip_longest(left, right, fillvalue=None): + if l is None: + if r is not None: + return COMPARE_FAILED + else: + continue + elif r is None: + return COMPARE_FAILED + if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key( self.anon_map[1], [] ): @@ -604,6 +625,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in zip_longest(left, right, fillvalue=None): + if l is None: + if r is not None: + return COMPARE_FAILED + else: + continue + elif r is None: + return COMPARE_FAILED + if ( l._gen_cache_key(self.anon_map[0], []) if l._is_has_cache_key diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 262689128d..390e23952f 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -73,6 +73,7 @@ if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument + from .elements import BinaryExpression from .elements import TextClause from .selectable import _JoinTargetElement from .selectable import _SelectIterable @@ -86,8 +87,15 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.row import Row +_CE = TypeVar("_CE", bound="ColumnElement[Any]") -def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None): + +def join_condition( + a: FromClause, + b: FromClause, + a_subset: Optional[FromClause] = None, + consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None, +) -> ColumnElement[bool]: """Create a join condition between two tables or selectables. e.g.:: @@ -118,7 +126,9 @@ def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None): ) -def find_join_source(clauses, join_to): +def find_join_source( + clauses: List[FromClause], join_to: FromClause +) -> List[int]: """Given a list of FROM clauses and a selectable, return the first index and element from the list of clauses which can be joined against the selectable. returns @@ -144,7 +154,9 @@ def find_join_source(clauses, join_to): return idx -def find_left_clause_that_matches_given(clauses, join_from): +def find_left_clause_that_matches_given( + clauses: Sequence[FromClause], join_from: FromClause +) -> List[int]: """Given a list of FROM clauses and a selectable, return the indexes from the list of clauses which is derived from the selectable. @@ -243,7 +255,12 @@ def find_left_clause_to_join_from( return idx -def visit_binary_product(fn, expr): +def visit_binary_product( + fn: Callable[ + [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None + ], + expr: ColumnElement[Any], +) -> None: """Produce a traversal of the given expression, delivering column comparisons to the given function. @@ -278,19 +295,19 @@ def visit_binary_product(fn, expr): a binary comparison is passed as pairs. """ - stack: List[ClauseElement] = [] + stack: List[BinaryExpression[Any]] = [] - def visit(element): + def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]: if isinstance(element, ScalarSelect): # we don't want to dig into correlated subqueries, # those are just column elements by themselves yield element elif element.__visit_name__ == "binary" and operators.is_comparison( - element.operator + element.operator # type: ignore ): - stack.insert(0, element) - for l in visit(element.left): - for r in visit(element.right): + stack.insert(0, element) # type: ignore + for l in visit(element.left): # type: ignore + for r in visit(element.right): # type: ignore fn(stack[0], l, r) stack.pop(0) for elem in element.get_children(): @@ -502,7 +519,7 @@ def extract_first_column_annotation(column, annotation_name): return None -def selectables_overlap(left, right): +def selectables_overlap(left: FromClause, right: FromClause) -> bool: """Return True if left/right have some overlapping selectable""" return bool( @@ -701,7 +718,7 @@ class _repr_params(_repr_base): return "[%s]" % (", ".join(trunc(value) for value in params)) -def adapt_criterion_to_null(crit, nulls): +def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE: """given criterion containing bind params, convert selected elements to IS NULL. @@ -922,9 +939,6 @@ def criterion_as_pairs( return pairs -_CE = TypeVar("_CE", bound="ClauseElement") - - class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 217e2d2ab4..b550f8f286 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -21,7 +21,6 @@ from typing import Any from typing import Callable from typing import cast from typing import ClassVar -from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator @@ -31,6 +30,7 @@ from typing import Optional from typing import overload from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -42,6 +42,10 @@ from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self +if TYPE_CHECKING: + from .annotation import _AnnotationDict + from .elements import ColumnElement + if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import prefix_anon_map as prefix_anon_map from ._py_util import cache_anon_map as anon_map @@ -590,13 +594,23 @@ _dispatch_lookup = HasTraversalDispatch._dispatch_lookup _generate_traversal_dispatch() +SelfExternallyTraversible = TypeVar( + "SelfExternallyTraversible", bound="ExternallyTraversible" +) + + class ExternallyTraversible(HasTraverseInternals, Visitable): __slots__ = () - _annotations: Collection[Any] = () + _annotations: Mapping[Any, Any] = util.EMPTY_DICT if typing.TYPE_CHECKING: + def _annotate( + self: SelfExternallyTraversible, values: _AnnotationDict + ) -> SelfExternallyTraversible: + ... + def get_children( self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[ExternallyTraversible]: @@ -624,6 +638,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): _ET = TypeVar("_ET", bound=ExternallyTraversible) +_CE = TypeVar("_CE", bound="ColumnElement[Any]") _TraverseCallableType = Callable[[_ET], None] @@ -633,10 +648,8 @@ class _CloneCallableType(Protocol): ... -class _TraverseTransformCallableType(Protocol): - def __call__( - self, element: ExternallyTraversible, **kw: Any - ) -> Optional[ExternallyTraversible]: +class _TraverseTransformCallableType(Protocol[_ET]): + def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]: ... @@ -1074,16 +1087,25 @@ def cloned_traverse( def replacement_traverse( obj: Literal[None], opts: Mapping[str, Any], - replace: _TraverseTransformCallableType, + replace: _TraverseTransformCallableType[Any], ) -> None: ... +@overload +def replacement_traverse( + obj: _CE, + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType[Any], +) -> _CE: + ... + + @overload def replacement_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], - replace: _TraverseTransformCallableType, + replace: _TraverseTransformCallableType[Any], ) -> ExternallyTraversible: ... @@ -1091,7 +1113,7 @@ def replacement_traverse( def replacement_traverse( obj: Optional[ExternallyTraversible], opts: Mapping[str, Any], - replace: _TraverseTransformCallableType, + replace: _TraverseTransformCallableType[Any], ) -> Optional[ExternallyTraversible]: """Clone the given expression structure, allowing element replacement by a given replacement function. @@ -1134,7 +1156,7 @@ def replacement_traverse( newelem = replace(elem) if newelem is not None: stop_on.add(id(newelem)) - return newelem + return newelem # type: ignore else: # base "already seen" on id(), not hash, so that we don't # replace an Annotated element with its non-annotated one, and @@ -1145,11 +1167,11 @@ def replacement_traverse( newelem = kw["replace"](elem) if newelem is not None: cloned[id_elem] = newelem - return newelem + return newelem # type: ignore cloned[id_elem] = newelem = elem._clone(**kw) newelem._copy_internals(clone=clone, **kw) - return cloned[id_elem] + return cloned[id_elem] # type: ignore if obj is not None: obj = clone( diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 7150dedcf8..54be2e4e5b 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -71,7 +71,7 @@ _T_co = TypeVar("_T_co", covariant=True) EMPTY_SET: FrozenSet[Any] = frozenset() -def merge_lists_w_ordering(a, b): +def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]: """merge two lists, maintaining ordering as much as possible. this is to reconcile vars(cls) with cls.__annotations__. @@ -450,7 +450,7 @@ def to_set(x): return x -def to_column_set(x): +def to_column_set(x: Any) -> Set[Any]: if x is None: return column_set() if not isinstance(x, column_set): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 24fa0f3e38..adbbf143f9 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -20,11 +20,14 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Iterable from typing import List from typing import Mapping from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type py311 = sys.version_info >= (3, 11) @@ -225,7 +228,7 @@ def inspect_formatargspec( return result -def dataclass_fields(cls): +def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated with a class.""" @@ -235,12 +238,12 @@ def dataclass_fields(cls): return [] -def local_dataclass_fields(cls): +def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated with a class, excluding those that originate from a superclass.""" if dataclasses.is_dataclass(cls): - super_fields = set() + super_fields: Set[dataclasses.Field[Any]] = set() for sup in cls.__bases__: super_fields.update(dataclass_fields(sup)) return [f for f in dataclasses.fields(cls) if f not in super_fields] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 24c66bfa4e..e54f334758 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -266,13 +266,31 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]: metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name) metadata.update(format_argspec_plus(spec, grouped=False)) metadata["name"] = fn.__name__ - code = ( - """\ + + # look for __ positional arguments. This is a convention in + # SQLAlchemy that arguments should be passed positionally + # rather than as keyword + # arguments. note that apply_pos doesn't currently work in all cases + # such as when a kw-only indicator "*" is present, which is why + # we limit the use of this to just that case we can detect. As we add + # more kinds of methods that use @decorator, things may have to + # be further improved in this area + if "__" in repr(spec[0]): + code = ( + """\ +def %(name)s%(grouped_args)s: + return %(target)s(%(fn)s, %(apply_pos)s) +""" + % metadata + ) + else: + code = ( + """\ def %(name)s%(grouped_args)s: return %(target)s(%(fn)s, %(apply_kw)s) """ - % metadata - ) + % metadata + ) env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__}) decorated = cast( @@ -1235,10 +1253,10 @@ class HasMemoized: return result @classmethod - def memoized_instancemethod(cls, fn: Any) -> Any: + def memoized_instancemethod(cls, fn: _F) -> _F: """Decorate a method memoize its return value.""" - def oneshot(self, *args, **kw): + def oneshot(self: Any, *args: Any, **kw: Any) -> Any: result = fn(self, *args, **kw) def memo(*a, **kw): @@ -1250,7 +1268,7 @@ class HasMemoized: self._memoized_keys |= {fn.__name__} return result - return update_wrapper(oneshot, fn) + return update_wrapper(oneshot, fn) # type: ignore if TYPE_CHECKING: diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index fce3cd3b0b..67394c9a3a 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -25,8 +25,12 @@ _FN = TypeVar("_FN", bound=Callable[..., Any]) if TYPE_CHECKING: from sqlalchemy.engine import default as engine_default # noqa + from sqlalchemy.orm import clsregistry as orm_clsregistry # noqa + from sqlalchemy.orm import decl_api as orm_decl_api # noqa + from sqlalchemy.orm import properties as orm_properties # noqa from sqlalchemy.orm import relationships as orm_relationships # noqa from sqlalchemy.orm import session as orm_session # noqa + from sqlalchemy.orm import state as orm_state # noqa from sqlalchemy.orm import util as orm_util # noqa from sqlalchemy.sql import dml as sql_dml # noqa from sqlalchemy.sql import functions as sql_functions # noqa diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 37297103ef..24e478b573 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -4,21 +4,33 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls """Topological sorting algorithms.""" from __future__ import annotations +from typing import Any +from typing import DefaultDict +from typing import Iterable +from typing import Iterator +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TypeVar + from .. import util from ..exc import CircularDependencyError +_T = TypeVar("_T", bound=Any) + __all__ = ["sort", "sort_as_subsets", "find_cycles"] -def sort_as_subsets(tuples, allitems): +def sort_as_subsets( + tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T] +) -> Iterator[Sequence[_T]]: - edges = util.defaultdict(set) + edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set) for parent, child in tuples: edges[child].add(parent) @@ -43,7 +55,11 @@ def sort_as_subsets(tuples, allitems): yield output -def sort(tuples, allitems, deterministic_order=True): +def sort( + tuples: Iterable[Tuple[_T, _T]], + allitems: Iterable[_T], + deterministic_order: bool = True, +) -> Iterator[_T]: """sort the given list of items by dependency. 'tuples' is a list of tuples representing a partial ordering. @@ -59,11 +75,14 @@ def sort(tuples, allitems, deterministic_order=True): yield s -def find_cycles(tuples, allitems): +def find_cycles( + tuples: Iterable[Tuple[_T, _T]], + allitems: Iterable[_T], +) -> Set[_T]: # adapted from: # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html - edges = util.defaultdict(set) + edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set) for parent, child in tuples: edges[parent].add(child) nodes_to_test = set(edges) @@ -99,5 +118,5 @@ def find_cycles(tuples, allitems): return output -def _gen_edges(edges): - return set([(right, left) for left in edges for right in edges[left]]) +def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]: + return {(right, left) for left in edges for right in edges[left]} diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index ebcae28a7a..44e26f6094 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -11,7 +11,9 @@ from typing import Dict from typing import ForwardRef from typing import Generic from typing import Iterable +from typing import NoReturn from typing import Optional +from typing import overload from typing import Tuple from typing import Type from typing import TypeVar @@ -33,7 +35,7 @@ Self = TypeVar("Self", bound=Any) if compat.py310: # why they took until py310 to put this in stdlib is beyond me, # I've been wanting it since py27 - from types import NoneType + from types import NoneType as NoneType else: NoneType = type(None) # type: ignore @@ -68,6 +70,8 @@ else: # copied from TypeShed, required in order to implement # MutableMapping.update() +_AnnotationScanType = Union[Type[Any], str] + class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): def keys(self) -> Iterable[_KT]: @@ -90,9 +94,9 @@ else: def de_stringify_annotation( cls: Type[Any], - annotation: Union[str, Type[Any]], + annotation: _AnnotationScanType, str_cleanup_fn: Optional[Callable[[str], str]] = None, -) -> Union[str, Type[Any]]: +) -> Type[Any]: """Resolve annotations that may be string based into real objects. This is particularly important if a module defines "from __future__ import @@ -125,20 +129,32 @@ def de_stringify_annotation( annotation = eval(annotation, base_globals, None) except NameError: pass - return annotation + return annotation # type: ignore -def is_fwd_ref(type_): +def is_fwd_ref(type_: _AnnotationScanType) -> bool: return isinstance(type_, ForwardRef) -def de_optionalize_union_types(type_): +@overload +def de_optionalize_union_types(type_: str) -> str: + ... + + +@overload +def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: + ... + + +def de_optionalize_union_types( + type_: _AnnotationScanType, +) -> _AnnotationScanType: """Given a type, filter out ``Union`` types that include ``NoneType`` to not include the ``NoneType``. """ if is_optional(type_): - typ = set(type_.__args__) + typ = set(type_.__args__) # type: ignore typ.discard(NoneType) @@ -148,14 +164,14 @@ def de_optionalize_union_types(type_): return type_ -def make_union_type(*types): +def make_union_type(*types: _AnnotationScanType) -> Type[Any]: """Make a Union type. This is needed by :func:`.de_optionalize_union_types` which removes ``NoneType`` from a ``Union``. """ - return cast(Any, Union).__getitem__(types) + return cast(Any, Union).__getitem__(types) # type: ignore def expand_unions( @@ -251,4 +267,47 @@ class DescriptorReference(Generic[_DESC]): ... +_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True) + + +class RODescriptorReference(Generic[_DESC_co]): + """a descriptor that refers to a descriptor. + + same as :class:`.DescriptorReference` but is read-only, so that subclasses + can define a subtype as the generically contained element + + """ + + def __get__(self, instance: object, owner: Any) -> _DESC_co: + ... + + def __set__(self, instance: Any, value: Any) -> NoReturn: + ... + + def __delete__(self, instance: Any) -> NoReturn: + ... + + +_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]]) + + +class CallableReference(Generic[_FN]): + """a descriptor that refers to a callable. + + works around mypy's limitation of not allowing callables assigned + as instance variables + + + """ + + def __get__(self, instance: object, owner: Any) -> _FN: + ... + + def __set__(self, instance: Any, value: _FN) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + # $def ro_descriptor_reference(fn: Callable[]) diff --git a/pyproject.toml b/pyproject.toml index f8498fde94..29d59ea698 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,21 +68,6 @@ strict = true # pass module = [ - # TODO for ORM, non-strict - "sqlalchemy.orm.base", - "sqlalchemy.orm.decl_base", - "sqlalchemy.orm.descriptor_props", - "sqlalchemy.orm.identity", - "sqlalchemy.orm.mapped_collection", - "sqlalchemy.orm.properties", - "sqlalchemy.orm.relationships", - "sqlalchemy.orm.strategy_options", - "sqlalchemy.orm.state_changes", - - # would ideally be strict - "sqlalchemy.orm.decl_api", - "sqlalchemy.orm.events", - "sqlalchemy.orm.query", "sqlalchemy.engine.reflection", ] diff --git a/test/ext/mypy/plain_files/composite.py b/test/ext/mypy/plain_files/composite.py new file mode 100644 index 0000000000..c69963314e --- /dev/null +++ b/test/ext/mypy/plain_files/composite.py @@ -0,0 +1,67 @@ +from typing import Any +from typing import Tuple + +from sqlalchemy import select +from sqlalchemy.orm import composite +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Point: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + def __composite_values__(self) -> Tuple[int, int]: + return self.x, self.y + + def __repr__(self) -> str: + return "Point(x=%r, y=%r)" % (self.x, self.y) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, Point) + and other.x == self.x + and other.y == self.y + ) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +class Vertex(Base): + __tablename__ = "vertices" + + id: Mapped[int] = mapped_column(primary_key=True) + x1: Mapped[int] + y1: Mapped[int] + x2: Mapped[int] + y2: Mapped[int] + + # inferred from right hand side + start = composite(Point, "x1", "y1") + + # taken from left hand side + end: Mapped[Point] = composite(Point, "x2", "y2") + + +v1 = Vertex(start=Point(3, 4), end=Point(5, 6)) + +stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)])) + +# EXPECTED_TYPE: Select[Tuple[Vertex]] +reveal_type(stmt) + +# EXPECTED_TYPE: composite.Point +reveal_type(v1.start) + +# EXPECTED_TYPE: composite.Point +reveal_type(v1.end) + +# EXPECTED_TYPE: int +reveal_type(v1.end.y) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index 1086f187af..37f99502db 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -258,7 +258,7 @@ class MypyPluginTest(fixtures.TestBase): ) expected_msg = re.sub( - r"(int|str|float|bool)", + r"\b(int|str|float|bool)\b", lambda m: rf"builtins.{m.group(0)}\*?", expected_msg, ) diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 9f9f8e601e..4990056c3e 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -133,7 +133,9 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): reg = registry(metadata=metadata) - reg.map_declaratively(User) + mp = reg.map_declaratively(User) + assert mp is inspect(User) + assert mp is User.__mapper__ def test_undefer_column_name(self): # TODO: not sure if there was an explicit @@ -186,6 +188,53 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): class_mapper(User).get_property("props").secondary is user_to_prop ) + def test_string_dependency_resolution_schemas_no_base(self): + """ + + found_during_type_annotation + + """ + + reg = registry() + + @reg.mapped + class User: + + __tablename__ = "users" + __table_args__ = {"schema": "fooschema"} + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + props = relationship( + "Prop", + secondary="fooschema.user_to_prop", + primaryjoin="User.id==fooschema.user_to_prop.c.user_id", + secondaryjoin="fooschema.user_to_prop.c.prop_id==Prop.id", + backref="users", + ) + + @reg.mapped + class Prop: + + __tablename__ = "props" + __table_args__ = {"schema": "fooschema"} + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + user_to_prop = Table( + "user_to_prop", + reg.metadata, + Column("user_id", Integer, ForeignKey("fooschema.users.id")), + Column("prop_id", Integer, ForeignKey("fooschema.props.id")), + schema="fooschema", + ) + configure_mappers() + + assert ( + class_mapper(User).get_property("props").secondary is user_to_prop + ) + def test_string_dependency_resolution_annotations(self): Base = declarative_base() @@ -290,6 +339,51 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): reg = registry(metadata=metadata) reg.mapped(User) reg.mapped(Address) + + reg.metadata.create_all(testing.db) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) + with Session(testing.db) as sess: + sess.add(u1) + sess.commit() + with Session(testing.db) as sess: + eq_( + sess.query(User).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) + + def test_map_declaratively(self, metadata): + class User(fixtures.ComparableEntity): + + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship("Address", backref="user") + + class Address(fixtures.ComparableEntity): + + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) + + reg = registry(metadata=metadata) + um = reg.map_declaratively(User) + am = reg.map_declaratively(Address) + + is_(User.__mapper__, um) + is_(Address.__mapper__, am) + reg.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index ca3a6e6089..fb27c910ff 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -1,6 +1,7 @@ import sqlalchemy as sa from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import class_mapper @@ -14,6 +15,7 @@ from sqlalchemy.orm.decl_api import registry from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false @@ -653,6 +655,25 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Engineer(name="vlad", primary_language="cobol"), ) + def test_single_cols_on_sub_base_of_subquery(self): + """ + found_during_type_annotation + + """ + t = Table("t", Base.metadata, Column("id", Integer, primary_key=True)) + + class Person(Base): + __table__ = select(t).subquery() + + with expect_raises_message( + sa.exc.ArgumentError, + r"Can't declare columns on single-table-inherited subclass " + r".*Contractor.*; superclass .*Person.* is not mapped to a Table", + ): + + class Contractor(Person): + contractor_field = Column(String) + def test_single_cols_on_sub_base_of_joined(self): """test [ticket:3895]""" diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 2c9cc3b21b..e72c181100 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -965,6 +965,45 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): '"user".state, "user".zip FROM "user"', ) + def test_name_cols_by_str(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + street: Mapped[str] + state: Mapped[str] + + # TODO: this needs to be improved, we should be able to say: + # zip_: Mapped[str] = mapped_column("zip") + # and it should assign to "zip_" for the attribute. not working + + zip_: Mapped[str] = mapped_column(name="zip", key="zip_") + + address: Mapped["Address"] = composite( + Address, "street", "state", "zip_" + ) + + eq_( + User.__mapper__.attrs["address"].props, + [ + User.__mapper__.attrs["street"], + User.__mapper__.attrs["state"], + User.__mapper__.attrs["zip_"], + ], + ) + self.assert_compile( + select(User), + 'SELECT "user".id, "user".name, "user".street, ' + '"user".state, "user".zip FROM "user"', + ) + def test_cls_annotated_setup(self, decl_base): @dataclasses.dataclass class Address: diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index ab6d79383c..78b503873f 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -24,6 +24,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock @@ -35,7 +36,9 @@ from sqlalchemy.testing.schema import Table from test.orm.test_events import _RemoveListeners -class ConcreteTest(fixtures.MappedTest): +class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest): + __dialect__ = "default" + @classmethod def define_tables(cls, metadata): Table( @@ -265,6 +268,10 @@ class ConcreteTest(fixtures.MappedTest): "sometype", ) + # found_during_type_annotation + # test the comparator returned by ConcreteInheritedProperty + self.assert_compile(Manager.type == "x", "pjoin.type = :type_1") + jenn = Engineer("Jenn", "knows how to program") hacker = Hacker("Karina", "Badass", "knows how to hack") diff --git a/test/orm/test_query.py b/test/orm/test_query.py index e3732bef50..d9013b2c4a 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -2259,6 +2259,22 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): ) assert a1.c.users_id is not None + def test_no_subquery_for_from_statement(self): + """ + found_during_typing + + """ + User = self.classes.User + + session = fixture_session() + q = session.query(User.id).from_statement(text("select * from user")) + + with expect_raises_message( + sa.exc.InvalidRequestError, + r"Can't call this method on a Query that uses from_statement\(\)", + ): + q.subquery() + def test_reduced_subquery(self): User = self.classes.User ua = aliased(User) @@ -6183,12 +6199,6 @@ class TextTest(QueryTest, AssertsCompiledSQL): "FROM users GROUP BY name", ) - def test_orm_columns_accepts_text(self): - from sqlalchemy.orm.base import _orm_columns - - t = text("x") - eq_(_orm_columns(t), [t]) - def test_order_by_w_eager_one(self): User = self.classes.User s = fixture_session() diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index fd192ab519..22bc549085 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -4324,7 +4324,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest): assert_raises_message( sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " - r"both of the same direction symbol\('ONETOMANY'\). Did you " + r"both of the same " + r"direction .*RelationshipDirection.ONETOMANY.*. Did you " "mean to set remote_side on the many-to-one side ?", configure_mappers, ) @@ -4347,7 +4348,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest): assert_raises_message( sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " - r"both of the same direction symbol\('MANYTOONE'\). Did you " + r"both of the same direction .*RelationshipDirection.MANYTOONE.*." + "Did you " "mean to set remote_side on the many-to-one side ?", configure_mappers, ) @@ -4367,7 +4369,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest): # can't be sure of ordering here assert_raises_message( sa.exc.ArgumentError, - r"both of the same direction symbol\('ONETOMANY'\). Did you " + r"both of the same direction " + r".*RelationshipDirection.ONETOMANY.*. Did you " "mean to set remote_side on the many-to-one side ?", configure_mappers, ) @@ -4391,7 +4394,8 @@ class InvalidRemoteSideTest(fixtures.MappedTest): # can't be sure of ordering here assert_raises_message( sa.exc.ArgumentError, - r"both of the same direction symbol\('MANYTOONE'\). Did you " + r"both of the same direction " + r".*RelationshipDirection.MANYTOONE.*. Did you " "mean to set remote_side on the many-to-one side ?", configure_mappers, )