--- /dev/null
+.. change::
+ :tags: bug, orm
+
+ Fixed issue where the :meth:`_orm.registry.map_declaratively` method
+ would return an internal "mapper config" object and not the
+ :class:`.Mapper` object as stated in the API documentation.
from typing import Dict
from typing import Iterator
from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Tuple
if typing.TYPE_CHECKING:
+ from .base import Connection
+ from .default import DefaultExecutionContext
from .interfaces import _DBAPICursorDescription
+ from .interfaces import DBAPICursor
+ from .interfaces import Dialect
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
from .result import _ProcessorsType
from ..sql.type_api import _ResultProcessorType
+
_T = TypeVar("_T", bound=Any)
# metadata entry tuple indexes.
) = context.result_column_struct
num_ctx_cols = len(result_columns)
else:
- result_columns = (
+ result_columns = ( # type: ignore
cols_are_ordered
) = (
num_ctx_cols
alternate_cursor_description: Optional[_DBAPICursorDescription] = None
- def soft_close(self, result, dbapi_cursor):
+ def soft_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
raise NotImplementedError()
- def hard_close(self, result, dbapi_cursor):
+ def hard_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
raise NotImplementedError()
- def yield_per(self, result, dbapi_cursor, num):
+ def yield_per(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ num: int,
+ ) -> None:
return
- def fetchone(self, result, dbapi_cursor, hard_close=False):
+ def fetchone(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ hard_close: bool = False,
+ ) -> Any:
raise NotImplementedError()
- def fetchmany(self, result, dbapi_cursor, size=None):
+ def fetchmany(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ size: Optional[int] = None,
+ ) -> Any:
raise NotImplementedError()
- def fetchall(self, result):
+ def fetchall(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ ) -> Any:
raise NotImplementedError()
- def handle_exception(self, result, dbapi_cursor, err):
+ def handle_exception(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ err: BaseException,
+ ) -> NoReturn:
raise err
__slots__ = ()
- def soft_close(self, result, dbapi_cursor):
+ def soft_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
result.cursor_strategy = _NO_CURSOR_DQL
- def hard_close(self, result, dbapi_cursor):
+ def hard_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
result.cursor_strategy = _NO_CURSOR_DQL
- def handle_exception(self, result, dbapi_cursor, err):
+ def handle_exception(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ err: BaseException,
+ ) -> NoReturn:
result.connection._handle_dbapi_exception(
err, None, None, dbapi_cursor, result.context
)
- def yield_per(self, result, dbapi_cursor, num):
+ def yield_per(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ num: int,
+ ) -> None:
result.cursor_strategy = BufferedRowCursorFetchStrategy(
dbapi_cursor,
{"max_row_buffer": num},
growth_factor=0,
)
- def fetchone(self, result, dbapi_cursor, hard_close=False):
+ def fetchone(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ hard_close: bool = False,
+ ) -> Any:
try:
row = dbapi_cursor.fetchone()
if row is None:
except BaseException as e:
self.handle_exception(result, dbapi_cursor, e)
- def fetchmany(self, result, dbapi_cursor, size=None):
+ def fetchmany(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ size: Optional[int] = None,
+ ) -> Any:
try:
if size is None:
l = dbapi_cursor.fetchmany()
except BaseException as e:
self.handle_exception(result, dbapi_cursor, e)
- def fetchall(self, result, dbapi_cursor):
+ def fetchall(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ ) -> Any:
try:
rows = dbapi_cursor.fetchall()
result._soft_close()
_NO_RESULT_METADATA = _NoResultMetaData()
+SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
+
+
class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
closed: bool = False
_is_cursor = True
- def __init__(self, context, cursor_strategy, cursor_description):
+ context: DefaultExecutionContext
+ dialect: Dialect
+ cursor_strategy: ResultFetchStrategy
+ connection: Connection
+
+ def __init__(
+ self,
+ context: DefaultExecutionContext,
+ cursor_strategy: ResultFetchStrategy,
+ cursor_description: Optional[_DBAPICursorDescription],
+ ):
self.context = context
self.dialect = context.dialect
self.cursor = context.cursor
if not self._soft_closed:
cursor = self.cursor
- self.cursor = None
+ self.cursor = None # type: ignore
self.connection._safe_close_cursor(cursor)
self._soft_closed = True
return self.dialect.supports_sane_multi_rowcount
@util.memoized_property
- def rowcount(self):
+ def rowcount(self) -> int:
"""Return the 'rowcount' for this result.
The 'rowcount' reports the number of rows *matched*
return self.context.rowcount
except BaseException as e:
self.cursor_strategy.handle_exception(self, self.cursor, e)
+ raise # not called
@property
def lastrowid(self):
)
return merged_result
- def close(self):
+ def close(self) -> Any:
"""Close this :class:`_engine.CursorResult`.
This closes out the underlying DBAPI cursor corresponding to the
self._soft_close(hard=True)
@_generative
- def yield_per(self, num):
+ def yield_per(self: SelfCursorResult, num: int) -> SelfCursorResult:
self._yield_per = num
self.cursor_strategy.yield_per(self, self.cursor, num)
return self
from .base import Engine
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
+ from .interfaces import _DBAPICursorDescription
from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _IsolationLevel
def handle_dbapi_exception(self, e):
pass
- @property
- def rowcount(self):
+ @util.non_memoized_property
+ def rowcount(self) -> int:
return self.cursor.rowcount
def supports_sane_rowcount(self):
strategy = _cursor.BufferedRowCursorFetchStrategy(
self.cursor, self.execution_options
)
- cursor_description = (
+ cursor_description: _DBAPICursorDescription = (
strategy.alternate_cursor_description
or self.cursor.description
)
@property
def description(
self,
- ) -> Sequence[
- Tuple[
- str,
- "DBAPIType",
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[bool],
- ]
- ]:
+ ) -> _DBAPICursorDescription:
"""The description attribute of the Cursor.
.. seealso::
_DBAPIAnyExecuteParams = Union[
_DBAPIMultiExecuteParams, _DBAPISingleExecuteParams
]
-_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any]
+_DBAPICursorDescription = Tuple[
+ str,
+ "DBAPIType",
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[bool],
+]
_AnySingleExecuteParams = _DBAPISingleExecuteParams
_AnyMultiExecuteParams = _DBAPIMultiExecuteParams
"""
+ engine: Engine
+ """engine which the Connection is associated with"""
+
connection: Connection
"""Connection object which can be freely used by default value
generators to execute SQL. This Connection should reference the
from ..orm.base import SQLORMOperations
from ..sql import operators
from ..sql import or_
-from ..sql.elements import SQLCoreOperations
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
from ..orm.interfaces import MapperProperty
from ..orm.interfaces import PropComparator
from ..orm.mapper import Mapper
+ from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _InfoType
+
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
@property
def _comparator(self) -> PropComparator[Any]:
- return self._get_property().comparator
+ return getattr( # type: ignore
+ self.owning_class, self.target_collection
+ ).comparator
def __clause_element__(self) -> NoReturn:
raise NotImplementedError(
proxy.setter = setter
def _criterion_exists(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> ColumnElement[bool]:
is_has = kwargs.pop("is_has", None)
return self._comparator._criterion_exists(inner)
if self._target_is_object:
- prop = getattr(self.target_class, self.value_attr)
- value_expr = prop._criterion_exists(criterion, **kwargs)
+ attr = getattr(self.target_class, self.value_attr)
+ value_expr = attr.comparator._criterion_exists(criterion, **kwargs)
else:
if kwargs:
raise exc.ArgumentError(
return self._comparator._criterion_exists(value_expr)
def any(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
- ) -> SQLCoreOperations[Any]:
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce a proxied 'any' expression using EXISTS.
This expression will be a composed product
)
def has(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
- ) -> SQLCoreOperations[Any]:
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce a proxied 'has' expression using EXISTS.
This expression will be a composed product
self._ambiguous()
def any(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> NoReturn:
self._ambiguous()
def has(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> NoReturn:
self._ambiguous()
"""Public API functions and helpers for declarative."""
+from __future__ import annotations
+from typing import Callable
+from typing import TYPE_CHECKING
from ... import inspection
from ...orm import exc as orm_exc
from ...schema import Table
from ...util import OrderedDict
+if TYPE_CHECKING:
+ from ...engine.reflection import Inspector
+ from ...sql.schema import MetaData
+
class ConcreteBase:
"""A helper class for 'concrete' declarative mappings.
mapper = thingy.cls.__mapper__
metadata = mapper.class_.metadata
for rel in mapper._props.values():
+
if (
isinstance(rel, relationships.Relationship)
- and rel.secondary is not None
+ and rel._init_args.secondary._is_populated()
):
- if isinstance(rel.secondary, Table):
- cls._reflect_table(rel.secondary, insp)
- elif isinstance(rel.secondary, str):
+
+ secondary_arg = rel._init_args.secondary
+
+ if isinstance(secondary_arg.argument, Table):
+ cls._reflect_table(secondary_arg.argument, insp)
+ elif isinstance(secondary_arg.argument, str):
_, resolve_arg = _resolver(rel.parent.class_, rel)
- rel.secondary = resolve_arg(rel.secondary)
- rel.secondary._resolvers += (
+ resolver = resolve_arg(
+ secondary_arg.argument, True
+ )
+ resolver._resolvers += (
cls._sa_deferred_table_resolver(
insp, metadata
),
)
- # controversy! do we resolve it here? or leave
- # it deferred? I think doing it here is necessary
- # so the connection does not leak.
- rel.secondary = rel.secondary()
+ secondary_arg.argument = resolver()
@classmethod
- def _sa_deferred_table_resolver(cls, inspector, metadata):
- def _resolve(key):
+ def _sa_deferred_table_resolver(
+ cls, inspector: Inspector, metadata: MetaData
+ ) -> Callable[[str], Table]:
+ def _resolve(key: str) -> Table:
t1 = Table(key, metadata)
cls._reflect_table(t1, inspector)
return t1
return ret_expr
@util.non_memoized_property
- def property(self) -> Optional[interfaces.MapperProperty[_T]]:
- return None
+ def property(self) -> interfaces.MapperProperty[_T]:
+ raise NotImplementedError()
def adapt_to_entity(
self, adapt_to_entity: AliasedInsp[Any]
return [(self.expression, value)]
@util.non_memoized_property
- def property(self) -> Optional[MapperProperty[_T]]:
+ def property(self) -> MapperProperty[_T]:
# this accessor is not normally used, however is accessed by things
# like ORM synonyms if the hybrid is used in this context; the
# .property attribute is not necessarily accessible
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Collection
+from typing import Iterable
+from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Type
if TYPE_CHECKING:
from ._typing import _EntityType
from ._typing import _ORMColumnExprArgument
+ from .descriptor_props import _CC
from .descriptor_props import _CompositeAttrType
from .interfaces import PropComparator
from .mapper import Mapper
from .relationships import _ORMColCollectionArgument
from .relationships import _ORMOrderByArgument
from .relationships import _RelationshipJoinConditionArgument
+ from .session import _SessionBind
from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _FromClauseArgument
from ..sql._typing import _InfoType
+ from ..sql._typing import _OnClauseArgument
from ..sql._typing import _TypeEngineArgument
+ from ..sql.elements import ColumnElement
from ..sql.schema import _ServerDefaultType
from ..sql.schema import FetchedValue
from ..sql.selectable import Alias
from ..sql.selectable import Subquery
+
_T = typing.TypeVar("_T")
@overload
def composite(
- class_: Type[_T],
+ class_: Type[_CC],
*attrs: _CompositeAttrType[Any],
**kwargs: Any,
-) -> Composite[_T]:
+) -> Composite[_CC]:
...
def relationship(
argument: Optional[_RelationshipArgumentType[Any]] = None,
- secondary: Optional[FromClause] = None,
+ secondary: Optional[Union[FromClause, str]] = None,
*,
uselist: Optional[bool] = None,
collection_class: Optional[
cascade: str = "save-update, merge",
viewonly: bool = False,
lazy: _LazyLoadArgumentType = "select",
- passive_deletes: bool = False,
+ passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
active_history: bool = False,
enable_typechecks: bool = True,
foreign_keys: Optional[_ORMColCollectionArgument] = None,
remote_side: Optional[_ORMColCollectionArgument] = None,
join_depth: Optional[int] = None,
- comparator_factory: Optional[Type[PropComparator[Any]]] = None,
+ comparator_factory: Optional[Type[Relationship.Comparator[Any]]] = None,
single_parent: bool = False,
innerjoin: bool = False,
distinct_target_key: Optional[bool] = None,
than can be achieved with synonyms.
"""
- return Synonym(name, map_column, descriptor, comparator_factory, doc, info)
+ return Synonym(
+ name,
+ map_column=map_column,
+ descriptor=descriptor,
+ comparator_factory=comparator_factory,
+ doc=doc,
+ info=info,
+ )
-def create_session(bind=None, **kwargs):
+def create_session(
+ bind: Optional[_SessionBind] = None, **kwargs: Any
+) -> Session:
r"""Create a new :class:`.Session`
with no automation enabled by default.
return Session(bind=bind, **kwargs)
-def _mapper_fn(*arg, **kw):
+def _mapper_fn(*arg: Any, **kw: Any) -> NoReturn:
"""Placeholder for the now-removed ``mapper()`` function.
Classical mappings should be performed using the
)
-def dynamic_loader(argument, **kw):
+def dynamic_loader(
+ argument: Optional[_RelationshipArgumentType[Any]] = None, **kw: Any
+) -> Relationship[Any]:
"""Construct a dynamically-loading mapper property.
This is essentially the same as
return relationship(argument, **kw)
-def backref(name, **kwargs):
+def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
"""Create a back reference with explicit keyword arguments, which are the
same arguments one can send to :func:`relationship`.
return (name, kwargs)
-def deferred(*columns, **kw):
+def deferred(
+ column: _ORMColumnExprArgument[_T],
+ *additional_columns: _ORMColumnExprArgument[Any],
+ **kw: Any,
+) -> ColumnProperty[_T]:
r"""Indicate a column-based mapped attribute that by default will
not load unless accessed.
:ref:`deferred`
"""
- return ColumnProperty(deferred=True, *columns, **kw)
+ kw["deferred"] = True
+ return ColumnProperty(column, *additional_columns, **kw)
def query_expression(
return prop
-def clear_mappers():
+def clear_mappers() -> None:
"""Remove all mappers from all classes.
.. versionchanged:: 1.4 This function now locates all
def with_polymorphic(
- base,
- classes,
- selectable=False,
- flat=False,
- polymorphic_on=None,
- aliased=False,
- adapt_on_names=False,
- innerjoin=False,
- _use_mapper_path=False,
-):
+ base: Union[_O, Mapper[_O]],
+ classes: Iterable[Type[Any]],
+ selectable: Union[Literal[False, None], FromClause] = False,
+ flat: bool = False,
+ polymorphic_on: Optional[ColumnElement[Any]] = None,
+ aliased: bool = False,
+ innerjoin: bool = False,
+ adapt_on_names: bool = False,
+ _use_mapper_path: bool = False,
+) -> AliasedClass[_O]:
"""Produce an :class:`.AliasedClass` construct which specifies
columns for descendant mappers of the given base.
)
-def join(left, right, onclause=None, isouter=False, full=False):
+def join(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+) -> _ORMJoin:
r"""Produce an inner join between left and right clauses.
:func:`_orm.join` is an extension to the core join interface
return _ORMJoin(left, right, onclause, isouter, full)
-def outerjoin(left, right, onclause=None, full=False):
+def outerjoin(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ full: bool = False,
+) -> _ORMJoin:
"""Produce a left outer join between left and right clauses.
This is the "outer join" version of the :func:`_orm.join` function,
import operator
from typing import Any
-from typing import Callable
from typing import Dict
+from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Type
if TYPE_CHECKING:
from .attributes import AttributeImpl
from .attributes import CollectionAttributeImpl
+ from .attributes import HasCollectionAdapter
+ from .attributes import QueryableAttribute
from .base import PassiveFlag
from .decl_api import registry as _registry_type
from .descriptor_props import _CompositeClassProto
+ from .interfaces import InspectionAttr
from .interfaces import MapperProperty
from .interfaces import UserDefinedOption
from .mapper import Mapper
from .state import InstanceState
from .util import AliasedClass
from .util import AliasedInsp
+ from ..sql._typing import _CE
from ..sql.base import ExecutableOption
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+
# I would have preferred this were bound=object however it seems
# to not travel in all situations when defined in that way.
_O = TypeVar("_O", bound=Any)
"""
+_OO = TypeVar("_OO", bound=object)
+"""The 'ORM mapped object, that's definitely object' type.
+
+"""
+
+
if TYPE_CHECKING:
_RegistryType = _registry_type
]
+_ClassDict = Mapping[str, Any]
_InstanceDict = Dict[str, Any]
_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
roles.ExpressionElementRole[_T],
]
-# somehow Protocol didn't want to work for this one
-_ORMAdapterProto = Callable[
- [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T]
-]
+
+_ORMCOLEXPR = TypeVar("_ORMCOLEXPR", bound=ColumnElement[Any])
+
+
+class _ORMAdapterProto(Protocol):
+ """protocol for the :class:`.AliasedInsp._orm_adapt_element` method
+ which is a synonym for :class:`.AliasedInsp._adapt_element`.
+
+
+ """
+
+ def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE:
+ ...
class _LoaderCallable(Protocol):
def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]:
...
+ def insp_is_attribute(
+ obj: InspectionAttr,
+ ) -> TypeGuard[QueryableAttribute[Any]]:
+ ...
+
+ def attr_is_internal_proxy(
+ obj: InspectionAttr,
+ ) -> TypeGuard[QueryableAttribute[Any]]:
+ ...
+
def prop_is_relationship(
prop: MapperProperty[Any],
) -> TypeGuard[Relationship[Any]]:
) -> TypeGuard[CollectionAttributeImpl]:
...
+ def is_has_collection_adapter(
+ impl: AttributeImpl,
+ ) -> TypeGuard[HasCollectionAdapter]:
+ ...
+
else:
insp_is_mapper_property = operator.attrgetter("is_property")
insp_is_mapper = operator.attrgetter("is_mapper")
insp_is_aliased_class = operator.attrgetter("is_aliased_class")
+ insp_is_attribute = operator.attrgetter("is_attribute")
+ attr_is_internal_proxy = operator.attrgetter("_is_internal_proxy")
is_collection_impl = operator.attrgetter("collection")
prop_is_relationship = operator.attrgetter("_is_relationship")
+ is_has_collection_adapter = operator.attrgetter(
+ "_is_has_collection_adapter"
+ )
pass
-_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]]
+_AllPendingType = Sequence[
+ Tuple[Optional["InstanceState[Any]"], Optional[object]]
+]
NO_KEY = NoKey("no name")
supports_population: bool
dynamic: bool
+ _is_has_collection_adapter = False
+
_replace_token: AttributeEventToken
_remove_token: AttributeEventToken
_append_token: AttributeEventToken
state: InstanceState[Any],
dict_: _InstanceDict,
value: Any,
- initiator: Optional[AttributeEventToken],
+ initiator: Optional[AttributeEventToken] = None,
passive: PassiveFlag = PASSIVE_OFF,
check_old: Any = None,
pop: bool = False,
state: InstanceState[Any],
dict_: Dict[str, Any],
value: Any,
- initiator: Optional[AttributeEventToken],
+ initiator: Optional[AttributeEventToken] = None,
passive: PassiveFlag = PASSIVE_OFF,
check_old: Optional[object] = None,
pop: bool = False,
state: InstanceState[Any],
dict_: _InstanceDict,
value: Any,
- initiator: Optional[AttributeEventToken],
+ initiator: Optional[AttributeEventToken] = None,
passive: PassiveFlag = PASSIVE_OFF,
check_old: Any = None,
pop: bool = False,
class HasCollectionAdapter:
__slots__ = ()
+ collection: bool
+ _is_has_collection_adapter = True
+
def _dispose_previous_collection(
self,
state: InstanceState[Any],
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
+ user_data: Literal[None] = ...,
passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
) -> CollectionAdapter:
...
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
- passive: PassiveFlag = PASSIVE_OFF,
+ user_data: _AdaptedCollectionProtocol = ...,
+ passive: PassiveFlag = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = ...,
+ passive: PassiveFlag = ...,
) -> Union[
Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
]:
state: InstanceState[Any],
dict_: _InstanceDict,
user_data: Optional[_AdaptedCollectionProtocol] = None,
- passive: PassiveFlag = PASSIVE_OFF,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
) -> Union[
Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
]:
raise NotImplementedError()
+ def set(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ _adapt: bool = True,
+ ) -> None:
+ raise NotImplementedError()
+
if TYPE_CHECKING:
initiator: Optional[AttributeEventToken],
passive: PassiveFlag = PASSIVE_OFF,
) -> None:
- collection = self.get_collection(state, dict_, passive=passive)
+ collection = self.get_collection(
+ state, dict_, user_data=None, passive=passive
+ )
if collection is PASSIVE_NO_RESULT:
value = self.fire_append_event(state, dict_, value, initiator)
assert (
initiator: Optional[AttributeEventToken],
passive: PassiveFlag = PASSIVE_OFF,
) -> None:
- collection = self.get_collection(state, state.dict, passive=passive)
+ collection = self.get_collection(
+ state, state.dict, user_data=None, passive=passive
+ )
if collection is PASSIVE_NO_RESULT:
self.fire_remove_event(state, dict_, value, initiator)
assert (
dict_: _InstanceDict,
value: Any,
initiator: Optional[AttributeEventToken] = None,
- passive: PassiveFlag = PASSIVE_OFF,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
check_old: Any = None,
pop: bool = False,
_adapt: bool = True,
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
+ user_data: Literal[None] = ...,
passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
) -> CollectionAdapter:
...
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
+ user_data: _AdaptedCollectionProtocol = ...,
+ passive: PassiveFlag = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = ...,
passive: PassiveFlag = PASSIVE_OFF,
) -> Union[
Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
impl_class: Optional[Type[AttributeImpl]] = None,
backref: Optional[str] = None,
**kw: Any,
-) -> InstrumentedAttribute[Any]:
+) -> QueryableAttribute[Any]:
manager = manager_of_class(class_)
if uselist:
attr._dispose_previous_collection(state, old, old_collection, False)
user_data = attr._default_value(state, dict_)
- adapter = attr.get_collection(state, dict_, user_data)
+ adapter: CollectionAdapter = attr.get_collection(state, dict_, user_data)
adapter._reset_empty()
return adapter
from typing import Callable
from typing import Dict
from typing import Generic
+from typing import no_type_check
from typing import Optional
from typing import overload
from typing import Type
from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
from ..util.typing import Literal
-from ..util.typing import Self
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _ExternalEntityType
from ._typing import _InternalEntityType
from .attributes import InstrumentedAttribute
from .instrumentation import ClassManager
+ from .interfaces import PropComparator
from .mapper import Mapper
from .state import InstanceState
from .util import AliasedClass
+ from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _InfoType
+ from ..sql.elements import ColumnElement
_T = TypeVar("_T", bound=Any)
EXT_STOP = util.symbol("EXT_STOP")
EXT_SKIP = util.symbol("EXT_SKIP")
-ONETOMANY = util.symbol(
- "ONETOMANY",
+
+class RelationshipDirection(Enum):
+ ONETOMANY = 1
"""Indicates the one-to-many direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
-MANYTOONE = util.symbol(
- "MANYTOONE",
+ MANYTOONE = 2
"""Indicates the many-to-one direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
-MANYTOMANY = util.symbol(
- "MANYTOMANY",
+ MANYTOMANY = 3
"""Indicates the many-to-many direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
+
+
+ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection)
class InspectionAttrExtensionType(Enum):
_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
-_F = TypeVar("_F", bound=Callable)
+_F = TypeVar("_F", bound=Callable[..., Any])
_Self = TypeVar("_Self")
return None
-def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]:
+def _class_to_mapper(
+ class_or_mapper: Union[Mapper[_T], Type[_T]]
+) -> Mapper[_T]:
+ # can't get mypy to see an overload for this
insp = inspection.inspect(class_or_mapper, False)
if insp is not None:
- return insp.mapper
+ return insp.mapper # type: ignore
else:
+ assert isinstance(class_or_mapper, type)
raise exc.UnmappedClassError(class_or_mapper)
def _mapper_or_none(
- entity: Union[_T, _InternalEntityType[_T]]
+ entity: Union[Type[_T], _InternalEntityType[_T]]
) -> Optional[Mapper[_T]]:
"""Return the :class:`_orm.Mapper` for the given class or None if the
class is not mapped.
"""
+ # can't get mypy to see an overload for this
insp = inspection.inspect(entity, False)
if insp is not None:
- return insp.mapper
+ return insp.mapper # type: ignore
else:
return None
-def _is_mapped_class(entity):
+def _is_mapped_class(entity: Any) -> bool:
"""Return True if the given object is a mapped class,
:class:`_orm.Mapper`, or :class:`.AliasedClass`.
"""
)
-def _orm_columns(entity):
- insp = inspection.inspect(entity, False)
- if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
- return [c for c in insp.selectable.c]
- else:
- return [entity]
-
-
-def _is_aliased_class(entity):
+def _is_aliased_class(entity: Any) -> bool:
insp = inspection.inspect(entity, False)
return insp is not None and getattr(insp, "is_aliased_class", False)
-def _entity_descriptor(entity, key):
+@no_type_check
+def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any:
"""Return a class attribute given an entity and string name.
May return :class:`.InstrumentedAttribute` or user-defined
if typing.TYPE_CHECKING:
- def of_type(self, class_):
+ def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
...
- def and_(self, *criteria):
+ def and_(
+ self, *criteria: _ColumnExpressionArgument[bool]
+ ) -> PropComparator[bool]:
...
- def any(self, criterion=None, **kwargs): # noqa: A001
+ def any( # noqa: A001
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
...
- def has(self, criterion=None, **kwargs):
+ def has(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
...
if typing.TYPE_CHECKING:
@overload
- def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self:
+ def __get__(
+ self, instance: Any, owner: Literal[None]
+ ) -> ORMDescriptor[_T]:
...
@overload
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
"""Routines to handle the string class registry used by declarative.
from __future__ import annotations
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
from .. import inspection
from .. import util
from ..sql.schema import _get_table_key
+from ..util.typing import CallableReference
+
+if TYPE_CHECKING:
+ from .relationships import Relationship
+ from ..sql.schema import MetaData
+ from ..sql.schema import Table
+
+_T = TypeVar("_T", bound=Any)
_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
-_registries = set()
+_registries: Set[ClsRegistryToken] = set()
-def add_class(classname, cls, decl_class_registry):
+def add_class(
+ classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
+) -> None:
"""Add a class to the _decl_class_registry associated with the
given declarative class.
existing = decl_class_registry[classname]
if not isinstance(existing, _MultipleClassMarker):
existing = decl_class_registry[classname] = _MultipleClassMarker(
- [cls, existing]
+ [cls, cast("Type[Any]", existing)]
)
else:
decl_class_registry[classname] = cls
try:
- root_module = decl_class_registry["_sa_module_registry"]
+ root_module = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
except KeyError:
decl_class_registry[
"_sa_module_registry"
module.add_class(classname, cls)
-def remove_class(classname, cls, decl_class_registry):
+def remove_class(
+ classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
+) -> None:
if classname in decl_class_registry:
existing = decl_class_registry[classname]
if isinstance(existing, _MultipleClassMarker):
del decl_class_registry[classname]
try:
- root_module = decl_class_registry["_sa_module_registry"]
+ root_module = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
except KeyError:
return
module.remove_class(classname, cls)
-def _key_is_empty(key, decl_class_registry, test):
+def _key_is_empty(
+ key: str,
+ decl_class_registry: _ClsRegistryType,
+ test: Callable[[Any], bool],
+) -> bool:
"""test if a key is empty of a certain object.
used for unit tests against the registry to see if garbage collection
for sub_thing in thing.contents:
if test(sub_thing):
return False
+ else:
+ raise NotImplementedError("unknown codepath")
else:
return not test(thing)
__slots__ = "on_remove", "contents", "__weakref__"
- def __init__(self, classes, on_remove=None):
+ contents: Set[weakref.ref[Type[Any]]]
+ on_remove: CallableReference[Optional[Callable[[], None]]]
+
+ def __init__(
+ self,
+ classes: Iterable[Type[Any]],
+ on_remove: Optional[Callable[[], None]] = None,
+ ):
self.on_remove = on_remove
self.contents = set(
[weakref.ref(item, self._remove_item) for item in classes]
)
_registries.add(self)
- def remove_item(self, cls):
+ def remove_item(self, cls: Type[Any]) -> None:
self._remove_item(weakref.ref(cls))
- def __iter__(self):
+ def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
return (ref() for ref in self.contents)
- def attempt_get(self, path, key):
+ def attempt_get(self, path: List[str], key: str) -> Type[Any]:
if len(self.contents) > 1:
raise exc.InvalidRequestError(
'Multiple classes found for path "%s" '
raise NameError(key)
return cls
- def _remove_item(self, ref):
+ def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
self.contents.discard(ref)
if not self.contents:
_registries.discard(self)
if self.on_remove:
self.on_remove()
- def add_item(self, item):
+ def add_item(self, item: Type[Any]) -> None:
# protect against class registration race condition against
# asynchronous garbage collection calling _remove_item,
# [ticket:3208]
__slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
- def __init__(self, name, parent):
+ parent: Optional[_ModuleMarker]
+ contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
+ mod_ns: _ModNS
+ path: List[str]
+
+ def __init__(self, name: str, parent: Optional[_ModuleMarker]):
self.parent = parent
self.name = name
self.contents = {}
self.path = []
_registries.add(self)
- def __contains__(self, name):
+ def __contains__(self, name: str) -> bool:
return name in self.contents
- def __getitem__(self, name):
+ def __getitem__(self, name: str) -> ClsRegistryToken:
return self.contents[name]
- def _remove_item(self, name):
+ def _remove_item(self, name: str) -> None:
self.contents.pop(name, None)
if not self.contents and self.parent is not None:
self.parent._remove_item(self.name)
_registries.discard(self)
- def resolve_attr(self, key):
- return getattr(self.mod_ns, key)
+ def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
+ return self.mod_ns.__getattr__(key)
- def get_module(self, name):
+ def get_module(self, name: str) -> _ModuleMarker:
if name not in self.contents:
marker = _ModuleMarker(name, self)
self.contents[name] = marker
else:
- marker = self.contents[name]
+ marker = cast(_ModuleMarker, self.contents[name])
return marker
- def add_class(self, name, cls):
+ def add_class(self, name: str, cls: Type[Any]) -> None:
if name in self.contents:
- existing = self.contents[name]
+ existing = cast(_MultipleClassMarker, self.contents[name])
existing.add_item(cls)
else:
existing = self.contents[name] = _MultipleClassMarker(
[cls], on_remove=lambda: self._remove_item(name)
)
- def remove_class(self, name, cls):
+ def remove_class(self, name: str, cls: Type[Any]) -> None:
if name in self.contents:
- existing = self.contents[name]
+ existing = cast(_MultipleClassMarker, self.contents[name])
existing.remove_item(cls)
class _ModNS:
__slots__ = ("__parent",)
- def __init__(self, parent):
+ __parent: _ModuleMarker
+
+ def __init__(self, parent: _ModuleMarker):
self.__parent = parent
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
try:
value = self.__parent.contents[key]
except KeyError:
class _GetColumns:
__slots__ = ("cls",)
- def __init__(self, cls):
+ cls: Type[Any]
+
+ def __init__(self, cls: Type[Any]):
self.cls = cls
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
mp = class_mapper(self.cls, configure=False)
if mp:
if key not in mp.all_orm_descriptors:
desc = mp.all_orm_descriptors[key]
if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
+ assert isinstance(desc, attributes.QueryableAttribute)
prop = desc.property
if isinstance(prop, Synonym):
key = prop.name
class _GetTable:
__slots__ = "key", "metadata"
- def __init__(self, key, metadata):
+ key: str
+ metadata: MetaData
+
+ def __init__(self, key: str, metadata: MetaData):
self.key = key
self.metadata = metadata
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Table:
return self.metadata.tables[_get_table_key(key, self.key)]
-def _determine_container(key, value):
+def _determine_container(key: str, value: Any) -> _GetColumns:
if isinstance(value, _MultipleClassMarker):
value = value.attempt_get([], key)
return _GetColumns(value)
"favor_tables",
)
- def __init__(self, cls, prop, fallback, arg, favor_tables=False):
+ cls: Type[Any]
+ prop: Relationship[Any]
+ fallback: Mapping[str, Any]
+ arg: str
+ favor_tables: bool
+ _resolvers: Tuple[Callable[[str], Any], ...]
+
+ def __init__(
+ self,
+ cls: Type[Any],
+ prop: Relationship[Any],
+ fallback: Mapping[str, Any],
+ arg: str,
+ favor_tables: bool = False,
+ ):
self.cls = cls
self.prop = prop
self.arg = arg
self._resolvers = ()
self.favor_tables = favor_tables
- def _access_cls(self, key):
+ def _access_cls(self, key: str) -> Any:
cls = self.cls
manager = attributes.manager_of_class(cls)
decl_base = manager.registry
+ assert decl_base is not None
decl_class_registry = decl_base._class_registry
metadata = decl_base.metadata
if key in metadata.tables:
return metadata.tables[key]
elif key in metadata._schemas:
- return _GetTable(key, cls.metadata)
+ return _GetTable(key, getattr(cls, "metadata", metadata))
if key in decl_class_registry:
return _determine_container(key, decl_class_registry[key])
if key in metadata.tables:
return metadata.tables[key]
elif key in metadata._schemas:
- return _GetTable(key, cls.metadata)
+ return _GetTable(key, getattr(cls, "metadata", metadata))
- if (
- "_sa_module_registry" in decl_class_registry
- and key in decl_class_registry["_sa_module_registry"]
+ if "_sa_module_registry" in decl_class_registry and key in cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
):
- registry = decl_class_registry["_sa_module_registry"]
+ registry = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
return registry.resolve_attr(key)
elif self._resolvers:
for resolv in self._resolvers:
return self.fallback[key]
- def _raise_for_name(self, name, err):
+ def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
generic_match = re.match(r"(.+)\[(.+)\]", name)
if generic_match:
% (self.prop.parent, self.arg, name, self.cls)
) from err
- def _resolve_name(self):
+ def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
name = self.arg
d = self._dict
rval = None
if isinstance(rval, _GetColumns):
return rval.cls
else:
+ if TYPE_CHECKING:
+ assert isinstance(rval, (type, Table, _ModNS))
return rval
- def __call__(self):
+ def __call__(self) -> Any:
try:
x = eval(self.arg, globals(), self._dict)
self._raise_for_name(n.args[0], n)
-_fallback_dict = None
+_fallback_dict: Mapping[str, Any] = None # type: ignore
-def _resolver(cls, prop):
+def _resolver(
+ cls: Type[Any], prop: Relationship[Any]
+) -> Tuple[
+ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+ Callable[[str, bool], _class_resolver],
+]:
global _fallback_dict
{"foreign": foreign, "remote": remote}
)
- def resolve_arg(arg, favor_tables=False):
+ def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
return _class_resolver(
cls, prop, _fallback_dict, arg, favor_tables=favor_tables
)
- def resolve_name(arg):
+ def resolve_name(
+ arg: str,
+ ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
return resolve_name, resolve_arg
from typing import Dict
from typing import Iterable
from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Set
from typing import Tuple
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from .attributes import AttributeEventToken
from .attributes import CollectionAttributeImpl
from .mapped_collection import attribute_mapped_collection
from .mapped_collection import column_mapped_collection
self.invalidated = False
self.empty = False
- def _warn_invalidated(self):
+ def _warn_invalidated(self) -> None:
util.warn("This collection has been invalidated.")
@property
return self._data()
@property
- def _referenced_by_owner(self):
+ def _referenced_by_owner(self) -> bool:
"""return True if the owner state still refers to this collection.
This will return False within a bulk replace operation,
def bulk_appender(self):
return self._data()._sa_appender
- def append_with_event(self, item, initiator=None):
+ def append_with_event(
+ self, item: Any, initiator: Optional[AttributeEventToken] = None
+ ) -> None:
"""Add an entity to the collection, firing mutation events."""
self._data()._sa_appender(item, _sa_initiator=initiator)
self.empty = True
self.owner_state._empty_collections[self._key] = user_data
- def _reset_empty(self):
+ def _reset_empty(self) -> None:
assert (
self.empty
), "This collection adapter is not in the 'empty' state"
self._key
] = self.owner_state._empty_collections.pop(self._key)
- def _refuse_empty(self):
+ def _refuse_empty(self) -> NoReturn:
raise sa_exc.InvalidRequestError(
"This is a special 'empty' collection which cannot accommodate "
"internal mutation operations"
)
- def append_without_event(self, item):
+ def append_without_event(self, item: Any) -> None:
"""Add or restore an entity to the collection, firing no events."""
if self.empty:
self._refuse_empty()
self._data()._sa_appender(item, _sa_initiator=False)
- def append_multiple_without_event(self, items):
+ def append_multiple_without_event(self, items: Iterable[Any]) -> None:
"""Add or restore an entity to the collection, firing no events."""
if self.empty:
self._refuse_empty()
def bulk_remover(self):
return self._data()._sa_remover
- def remove_with_event(self, item, initiator=None):
+ def remove_with_event(
+ self, item: Any, initiator: Optional[AttributeEventToken] = None
+ ) -> None:
"""Remove an entity from the collection, firing mutation events."""
self._data()._sa_remover(item, _sa_initiator=initiator)
- def remove_without_event(self, item):
+ def remove_without_event(self, item: Any) -> None:
"""Remove an entity from the collection, firing no events."""
if self.empty:
self._refuse_empty()
self._data()._sa_remover(item, _sa_initiator=False)
- def clear_with_event(self, initiator=None):
+ def clear_with_event(
+ self, initiator: Optional[AttributeEventToken] = None
+ ) -> None:
"""Empty the collection, firing a mutation event for each entity."""
if self.empty:
for item in list(self):
remover(item, _sa_initiator=initiator)
- def clear_without_event(self):
+ def clear_without_event(self) -> None:
"""Empty the collection, firing no events."""
if self.empty:
from typing import Any
from typing import cast
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
+from ..sql._typing import _TP
from ..sql._typing import is_dml
from ..sql._typing import is_insert_update
from ..sql._typing import is_select_base
from ..sql.dml import UpdateBase
from ..sql.elements import GroupedElement
from ..sql.elements import TextClause
-from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import Select
from ..sql.selectable import SelectLabelStyle
from ..sql.selectable import SelectState
+from ..sql.selectable import TypedReturnsRows
from ..sql.visitors import InternalTraversal
if TYPE_CHECKING:
from ._typing import _InternalEntityType
+ from .loading import PostLoad
from .mapper import Mapper
from .query import Query
+ from .session import _BindArguments
+ from .session import Session
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..sql._typing import _ColumnsClauseArgument
+ from ..sql.compiler import SQLCompiler
from ..sql.dml import _DMLTableElement
from ..sql.elements import ColumnElement
+ from ..sql.selectable import _JoinTargetElement
from ..sql.selectable import _LabelConventionCallable
+ from ..sql.selectable import _SetupJoinsElement
+ from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import SelectBase
from ..sql.type_api import TypeEngine
_EMPTY_DICT = util.immutabledict()
-LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM")
+LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
class QueryContext:
"loaders_require_uniquing",
)
+ runid: int
+ post_load_paths: Dict[PathRegistry, PostLoad]
+ compile_state: ORMCompileState
+
class default_load_options(Options):
_only_return_tuples = False
_populate_existing = False
def __init__(
self,
- compile_state,
- statement,
- params,
- session,
- load_options,
- execution_options=None,
- bind_arguments=None,
+ compile_state: CompileState,
+ statement: Union[Select[Any], FromStatement[Any]],
+ params: _CoreSingleExecuteParams,
+ session: Session,
+ load_options: Union[
+ Type[QueryContext.default_load_options],
+ QueryContext.default_load_options,
+ ],
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ bind_arguments: Optional[_BindArguments] = None,
):
self.load_options = load_options
self.execution_options = execution_options or _EMPTY_DICT
attributes: Dict[Any, Any]
global_attributes: Dict[Any, Any]
- statement: Union[Select, FromStatement]
- select_statement: Union[Select, FromStatement]
+ statement: Union[Select[Any], FromStatement[Any]]
+ select_statement: Union[Select[Any], FromStatement[Any]]
_entities: List[_QueryEntity]
_polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
compile_options: Union[
Tuple[Any, ...]
]
current_path: PathRegistry = _path_registry
+ _has_mapper_entities = False
def __init__(self, *arg, **kw):
raise NotImplementedError()
return SelectState._column_naming_convention(label_style)
@classmethod
- def create_for_statement(cls, statement_container, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
This method is always invoked in the context of SQLCompiler.process().
eager_joins = _EMPTY_DICT
@classmethod
- def create_for_statement(cls, statement_container, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement_container: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
if compiler is not None:
toplevel = not compiler.stack
return None
-class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
+class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
:class:`.ReturnsRows` and other classes including:
_for_update_arg = None
- element: Union[SelectBase, TextClause, UpdateBase]
+ element: Union[ExecutableReturnsRows, TextClause]
_traverse_internals = [
("_raw_columns", InternalTraversal.dp_clauseelement_list),
("_compile_options", InternalTraversal.dp_has_cache_key)
]
- def __init__(self, entities, element):
+ def __init__(
+ self,
+ entities: Iterable[_ColumnsClauseArgument[Any]],
+ element: Union[ExecutableReturnsRows, TextClause],
+ ):
self._raw_columns = [
coercions.expect(
roles.ColumnsClauseRole,
_having_criteria = ()
@classmethod
- def create_for_statement(cls, statement, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
)
@classmethod
- @util.preload_module("sqlalchemy.orm.query")
def from_statement(cls, statement, from_statement):
- query = util.preloaded.orm_query
from_statement = coercions.expect(
roles.ReturnsRowsRole,
apply_propagate_attrs=statement,
)
- stmt = query.FromStatement(statement._raw_columns, from_statement)
+ stmt = FromStatement(statement._raw_columns, from_statement)
stmt.__dict__.update(
_with_options=statement._with_options,
return d
-def _legacy_filter_by_entity_zero(query_or_augmented_select):
+def _legacy_filter_by_entity_zero(
+ query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if self._setup_joins:
_last_joined_entity = self._last_joined_entity
return _entity_from_pre_ent_zero(self)
-def _entity_from_pre_ent_zero(query_or_augmented_select):
+def _entity_from_pre_ent_zero(
+ query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if not self._raw_columns:
return None
return ent
-def _determine_last_joined_entity(setup_joins, entity_zero=None):
+def _determine_last_joined_entity(
+ setup_joins: Tuple[_SetupJoinsElement, ...],
+ entity_zero: Optional[_InternalEntityType[Any]] = None,
+) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]:
if not setup_joins:
return None
(target, onclause, from_, flags) = setup_joins[-1]
- if isinstance(target, interfaces.PropComparator):
+ if isinstance(
+ target,
+ attributes.QueryableAttribute,
+ ):
return target.entity
else:
return target
__slots__ = ()
+ supports_single_entity: bool
+
_non_hashable_value = False
_null_column_type = False
use_id_for_hash = False
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
+ def row_processor(self, context, result):
+ raise NotImplementedError()
+
@classmethod
def to_compile_state(
cls, compile_state, entities, entities_collection, is_current_entities
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""Public API functions and helpers for declarative."""
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import ClassVar
+from typing import Dict
+from typing import FrozenSet
+from typing import Iterator
from typing import Mapping
from typing import Optional
+from typing import overload
+from typing import Set
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import weakref
from . import attributes
from . import clsregistry
-from . import exc as orm_exc
from . import instrumentation
from . import interfaces
from . import mapperlib
from .decl_base import _mapper
from .descriptor_props import Synonym as _orm_synonym
from .mapper import Mapper
+from .state import InstanceState
from .. import exc
from .. import inspection
from .. import util
from ..sql.elements import SQLCoreOperations
from ..sql.schema import MetaData
from ..sql.selectable import FromClause
-from ..sql.type_api import TypeEngine
from ..util import hybridmethod
from ..util import hybridproperty
from ..util import typing as compat_typing
+from ..util.typing import CallableReference
+from ..util.typing import Literal
+if TYPE_CHECKING:
+ from ._typing import _O
+ from ._typing import _RegistryType
+ from .descriptor_props import Synonym
+ from .instrumentation import ClassManager
+ from .interfaces import MapperProperty
+ from ..sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
-_TypeAnnotationMapType = Mapping[Type, Union[Type[TypeEngine], TypeEngine]]
+# it's not clear how to have Annotated, Union objects etc. as keys here
+# from a typing perspective so just leave it open ended for now
+_TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"]
+_MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"]
+
+_DeclaredAttrDecorated = Callable[
+ ..., Union[Mapped[_T], SQLCoreOperations[_T]]
+]
-def has_inherited_table(cls):
+def has_inherited_table(cls: Type[_O]) -> bool:
"""Given a class, return True if any of the classes it inherits from has a
mapped table, otherwise return False.
class _DynamicAttributesType(type):
- def __setattr__(cls, key, value):
+ def __setattr__(cls, key: str, value: Any) -> None:
if "__mapper__" in cls.__dict__:
_add_attribute(cls, key, value)
else:
type.__setattr__(cls, key, value)
- def __delattr__(cls, key):
+ def __delattr__(cls, key: str) -> None:
if "__mapper__" in cls.__dict__:
_del_attribute(cls, key)
else:
class DeclarativeAttributeIntercept(
- _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"]
+ _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
"""Metaclass that may be used in conjunction with the
:class:`_orm.DeclarativeBase` class to support addition of class
class DeclarativeMeta(
- _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"]
+ _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
metadata: MetaData
- registry: "RegistryType"
+ registry: RegistryType
def __init__(
cls, classname: Any, bases: Any, dict_: Any, **kw: Any
type.__init__(cls, classname, bases, dict_)
-def synonym_for(name, map_column=False):
+def synonym_for(
+ name: str, map_column: bool = False
+) -> Callable[[Callable[..., Any]], Synonym[Any]]:
"""Decorator that produces an :func:`_orm.synonym`
attribute in conjunction with a Python descriptor.
"""
- def decorate(fn):
+ def decorate(fn: Callable[..., Any]) -> Synonym[Any]:
return _orm_synonym(name, map_column=map_column, descriptor=fn)
return decorate
if typing.TYPE_CHECKING:
- def __set__(self, instance, value):
+ def __set__(self, instance: Any, value: Any) -> None:
...
- def __delete__(self, instance: Any):
+ def __delete__(self, instance: Any) -> None:
...
def __init__(
self,
- fn: Callable[..., Union[Mapped[_T], SQLCoreOperations[_T]]],
- cascading=False,
+ fn: _DeclaredAttrDecorated[_T],
+ cascading: bool = False,
):
self.fget = fn
self._cascading = cascading
def _collect_return_annotation(self) -> Optional[Type[Any]]:
return util.get_annotations(self.fget).get("return")
- def __get__(self, instance, owner) -> InstrumentedAttribute[_T]:
+ # this is the Mapped[] API where at class descriptor get time we want
+ # the type checker to see InstrumentedAttribute[_T]. However the
+ # callable function prior to mapping in fact calls the given
+ # declarative function that does not return InstrumentedAttribute
+ @overload
+ def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]:
+ ...
+
+ @overload
+ def __get__(self, instance: object, owner: Any) -> _T:
+ ...
+
+ def __get__(
+ self, instance: Optional[object], owner: Any
+ ) -> Union[InstrumentedAttribute[_T], _T]:
# the declared_attr needs to make use of a cache that exists
# for the span of the declarative scan_attributes() phase.
# to achieve this we look at the class manager that's configured.
+
+ # note this method should not be called outside of the declarative
+ # setup phase
+
cls = owner
manager = attributes.opt_manager_of_class(cls)
if manager is None:
"Unmanaged access of declarative attribute %s from "
"non-mapped class %s" % (self.fget.__name__, cls.__name__)
)
- return self.fget(cls)
+ return self.fget(cls) # type: ignore
elif manager.is_mapped:
# the class is mapped, which means we're outside of the declarative
# scan setup, just run the function.
- return self.fget(cls)
+ return self.fget(cls) # type: ignore
# here, we are inside of the declarative scan. use the registry
# that is tracking the values of these attributes.
declarative_scan = manager.declarative_scan()
+
+ # assert that we are in fact in the declarative scan
assert declarative_scan is not None
+
reg = declarative_scan.declared_attr_reg
if self in reg:
- return reg[self]
+ return reg[self] # type: ignore
else:
reg[self] = obj = self.fget(cls)
- return obj
+ return obj # type: ignore
@hybridmethod
- def _stateful(cls, **kw):
+ def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]:
return _stateful_declared_attr(**kw)
@hybridproperty
- def cascading(cls):
+ def cascading(cls) -> _stateful_declared_attr[_T]:
"""Mark a :class:`.declared_attr` as cascading.
This is a special-use modifier which indicates that a column
return cls._stateful(cascading=True)
-class _stateful_declared_attr(declared_attr):
- def __init__(self, **kw):
+class _stateful_declared_attr(declared_attr[_T]):
+ kw: Dict[str, Any]
+
+ def __init__(self, **kw: Any):
self.kw = kw
- def _stateful(self, **kw):
+ @hybridmethod
+ def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]:
new_kw = self.kw.copy()
new_kw.update(kw)
return _stateful_declared_attr(**new_kw)
- def __call__(self, fn):
+ def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]:
return declared_attr(fn, **self.kw)
-def declarative_mixin(cls):
+def declarative_mixin(cls: Type[_T]) -> Type[_T]:
"""Mark a class as providing the feature of "declarative mixin".
E.g.::
return cls
-def _setup_declarative_base(cls):
+def _setup_declarative_base(cls: Type[Any]) -> None:
if "metadata" in cls.__dict__:
- metadata = cls.metadata
+ metadata = cls.metadata # type: ignore
else:
metadata = None
reg = registry(
metadata=metadata, type_annotation_map=type_annotation_map
)
- cls.registry = reg
+ cls.registry = reg # type: ignore
- cls._sa_registry = reg
+ cls._sa_registry = reg # type: ignore
if "metadata" not in cls.__dict__:
- cls.metadata = cls.registry.metadata
+ cls.metadata = cls.registry.metadata # type: ignore
-class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
+class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]):
"""Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass
to intercept new attributes.
"""
- registry: ClassVar["registry"]
- _sa_registry: ClassVar["registry"]
+ registry: ClassVar[_RegistryType]
+ _sa_registry: ClassVar[_RegistryType]
metadata: ClassVar[MetaData]
- __mapper__: ClassVar[Mapper]
+ __mapper__: ClassVar[Mapper[Any]]
__table__: Optional[FromClause]
if typing.TYPE_CHECKING:
class DeclarativeBase(
- inspection.Inspectable["InstanceState"],
+ inspection.Inspectable[InstanceState[Any]],
metaclass=DeclarativeAttributeIntercept,
):
"""Base class used for declarative class definitions.
"""
- registry: ClassVar["registry"]
- _sa_registry: ClassVar["registry"]
+ registry: ClassVar[_RegistryType]
+ _sa_registry: ClassVar[_RegistryType]
metadata: ClassVar[MetaData]
- __mapper__: ClassVar[Mapper]
+ __mapper__: ClassVar[Mapper[Any]]
__table__: Optional[FromClause]
if typing.TYPE_CHECKING:
if DeclarativeBase in cls.__bases__:
_setup_declarative_base(cls)
else:
- cls._sa_registry.map_declaratively(cls)
+ _as_declarative(cls._sa_registry, cls, cls.__dict__)
-def add_mapped_attribute(target, key, attr):
+def add_mapped_attribute(
+ target: Type[_O], key: str, attr: MapperProperty[Any]
+) -> None:
"""Add a new mapped attribute to an ORM mapped class.
E.g.::
def declarative_base(
+ *,
metadata: Optional[MetaData] = None,
- mapper=None,
- cls=object,
- name="Base",
+ mapper: Optional[Callable[..., Mapper[Any]]] = None,
+ cls: Type[Any] = object,
+ name: str = "Base",
class_registry: Optional[clsregistry._ClsRegistryType] = None,
type_annotation_map: Optional[_TypeAnnotationMapType] = None,
constructor: Callable[..., None] = _declarative_constructor,
- metaclass=DeclarativeMeta,
+ metaclass: Type[Any] = DeclarativeMeta,
) -> Any:
r"""Construct a base class for declarative class definitions.
"""
+ _class_registry: clsregistry._ClsRegistryType
+ _managers: weakref.WeakKeyDictionary[ClassManager[Any], Literal[True]]
+ _non_primary_mappers: weakref.WeakKeyDictionary[Mapper[Any], Literal[True]]
+ metadata: MetaData
+ constructor: CallableReference[Callable[..., None]]
+ type_annotation_map: _MutableTypeAnnotationMapType
+ _dependents: Set[_RegistryType]
+ _dependencies: Set[_RegistryType]
+ _new_mappers: bool
+
def __init__(
self,
+ *,
metadata: Optional[MetaData] = None,
class_registry: Optional[clsregistry._ClsRegistryType] = None,
type_annotation_map: Optional[_TypeAnnotationMapType] = None,
def update_type_annotation_map(
self,
- type_annotation_map: Mapping[
- Type, Union[Type[TypeEngine], TypeEngine]
- ],
+ type_annotation_map: _TypeAnnotationMapType,
) -> None:
"""update the :paramref:`_orm.registry.type_annotation_map` with new
values."""
)
@property
- def mappers(self):
+ def mappers(self) -> FrozenSet[Mapper[Any]]:
"""read only collection of all :class:`_orm.Mapper` objects."""
return frozenset(manager.mapper for manager in self._managers).union(
self._non_primary_mappers
)
- def _set_depends_on(self, registry):
+ def _set_depends_on(self, registry: RegistryType) -> None:
if registry is self:
return
registry._dependents.add(self)
self._dependencies.add(registry)
- def _flag_new_mapper(self, mapper):
+ def _flag_new_mapper(self, mapper: Mapper[Any]) -> None:
mapper._ready_for_configure = True
if self._new_mappers:
return
reg._new_mappers = True
@classmethod
- def _recurse_with_dependents(cls, registries):
+ def _recurse_with_dependents(
+ cls, registries: Set[RegistryType]
+ ) -> Iterator[RegistryType]:
todo = registries
done = set()
while todo:
todo.update(reg._dependents.difference(done))
@classmethod
- def _recurse_with_dependencies(cls, registries):
+ def _recurse_with_dependencies(
+ cls, registries: Set[RegistryType]
+ ) -> Iterator[RegistryType]:
todo = registries
done = set()
while todo:
# them before
todo.update(reg._dependencies.difference(done))
- def _mappers_to_configure(self):
+ def _mappers_to_configure(self) -> Iterator[Mapper[Any]]:
return itertools.chain(
(
manager.mapper
),
)
- def _add_non_primary_mapper(self, np_mapper):
+ def _add_non_primary_mapper(self, np_mapper: Mapper[Any]) -> None:
self._non_primary_mappers[np_mapper] = True
- def _dispose_cls(self, cls):
+ def _dispose_cls(self, cls: Type[_O]) -> None:
clsregistry.remove_class(cls.__name__, cls, self._class_registry)
- def _add_manager(self, manager):
+ def _add_manager(self, manager: ClassManager[Any]) -> None:
self._managers[manager] = True
if manager.is_mapped:
raise exc.ArgumentError(
assert manager.registry is None
manager.registry = self
- def configure(self, cascade=False):
+ def configure(self, cascade: bool = False) -> None:
"""Configure all as-yet unconfigured mappers in this
:class:`_orm.registry`.
"""
mapperlib._configure_registries({self}, cascade=cascade)
- def dispose(self, cascade=False):
+ def dispose(self, cascade: bool = False) -> None:
"""Dispose of all mappers in this :class:`_orm.registry`.
After invocation, all the classes that were mapped within this registry
mapperlib._dispose_registries({self}, cascade=cascade)
- def _dispose_manager_and_mapper(self, manager):
+ def _dispose_manager_and_mapper(self, manager: ClassManager[Any]) -> None:
if "mapper" in manager.__dict__:
mapper = manager.mapper
def generate_base(
self,
- mapper=None,
- cls=object,
- name="Base",
- metaclass=DeclarativeMeta,
- ):
+ mapper: Optional[Callable[..., Mapper[Any]]] = None,
+ cls: Type[Any] = object,
+ name: str = "Base",
+ metaclass: Type[Any] = DeclarativeMeta,
+ ) -> Any:
"""Generate a declarative base class.
Classes that inherit from the returned class object will be
if hasattr(cls, "__class_getitem__"):
- def __class_getitem__(cls, key):
+ def __class_getitem__(cls: Type[_T], key: str) -> Type[_T]:
# allow generic classes in py3.9+
return cls
return metaclass(name, bases, class_dict)
- def mapped(self, cls):
+ def mapped(self, cls: Type[_O]) -> Type[_O]:
"""Class decorator that will apply the Declarative mapping process
to a given class.
_as_declarative(self, cls, cls.__dict__)
return cls
- def as_declarative_base(self, **kw):
+ def as_declarative_base(self, **kw: Any) -> Callable[[Type[_T]], Type[_T]]:
"""
Class decorator which will invoke
:meth:`_orm.registry.generate_base`
"""
- def decorate(cls):
+ def decorate(cls: Type[_T]) -> Type[_T]:
kw["cls"] = cls
kw["name"] = cls.__name__
- return self.generate_base(**kw)
+ return self.generate_base(**kw) # type: ignore
return decorate
- def map_declaratively(self, cls):
+ def map_declaratively(self, cls: Type[_O]) -> Mapper[_O]:
"""Map a class declaratively.
In this form of mapping, the class is scanned for mapping information,
:meth:`_orm.registry.map_imperatively`
"""
- return _as_declarative(self, cls, cls.__dict__)
+ _as_declarative(self, cls, cls.__dict__)
+ return cls.__mapper__ # type: ignore
- def map_imperatively(self, class_, local_table=None, **kw):
+ def map_imperatively(
+ self,
+ class_: Type[_O],
+ local_table: Optional[FromClause] = None,
+ **kw: Any,
+ ) -> Mapper[_O]:
r"""Map a class imperatively.
In this form of mapping, the class is not scanned for any mapping
RegistryType = registry
-def as_declarative(**kw):
+def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]:
"""
Class decorator which will adapt a given class into a
:func:`_orm.declarative_base`.
@inspection._inspects(
DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept
)
-def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]:
- mp: Mapper[Any] = _inspect_mapped_class(cls)
+def _inspect_decl_meta(cls: Type[Any]) -> Optional[Mapper[Any]]:
+ mp: Optional[Mapper[Any]] = _inspect_mapped_class(cls)
if mp is None:
if _DeferredMapperConfig.has_cls(cls):
_DeferredMapperConfig.raise_unmapped_for_cls(cls)
- raise orm_exc.UnmappedClassError(
- cls,
- msg="Class %s has a deferred mapping on it. It is not yet "
- "usable as a mapped class." % orm_exc._safe_cls_name(cls),
- )
return mp
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""Internal implementation for declarative."""
from __future__ import annotations
import collections
from typing import Any
+from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import NoReturn
+from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import weakref
from . import attributes
from . import exc as orm_exc
from . import instrumentation
from . import mapperlib
+from ._typing import _O
+from ._typing import attr_is_internal_proxy
from .attributes import InstrumentedAttribute
from .attributes import QueryableAttribute
from .base import _is_mapped_class
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .mapper import Mapper as mapper
+from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
from .util import _is_mapped_annotation
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import Protocol
if TYPE_CHECKING:
+ from ._typing import _ClassDict
from ._typing import _RegistryType
+ from .decl_api import declared_attr
+ from .instrumentation import ClassManager
+ from ..sql.schema import MetaData
+ from ..sql.selectable import FromClause
+
+_T = TypeVar("_T", bound=Any)
+
+_MapperKwArgs = Mapping[str, Any]
+
+_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]]
-def _declared_mapping_info(cls):
+class _DeclMappedClassProtocol(Protocol[_O]):
+    """Structural :class:`Protocol` describing the attributes that the
+    declarative mapping process installs on a mapped class.
+
+    Used with ``cast()`` throughout this module so that a plain
+    ``Type[Any]`` can be accessed as a declaratively-mapped class
+    without per-attribute ``type: ignore`` comments.
+    """
+
+    # MetaData collection in which this class's Table is registered
+    metadata: MetaData
+    # the Mapper configured for this class
+    __mapper__: Mapper[_O]
+    # the Table (or other selectable) this class is mapped to
+    __table__: Table
+    # declaratively supplied table name
+    __tablename__: str
+    # keyword arguments forwarded to the Mapper
+    __mapper_args__: Mapping[str, Any]
+    # positional/keyword arguments for Table construction, if given
+    __table_args__: Optional[_TableArgsType]
+
+    def __declare_first__(self) -> None:
+        """Hook invoked via the mapper "before_configured" event."""
+        pass
+
+    def __declare_last__(self) -> None:
+        """Hook invoked via the mapper "after_configured" event."""
+        pass
+
+
+def _declared_mapping_info(
+ cls: Type[Any],
+) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]:
# deferred mapping
if _DeferredMapperConfig.has_cls(cls):
return _DeferredMapperConfig.config_for_cls(cls)
return None
-def _resolve_for_abstract_or_classical(cls):
+def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]:
if cls is object:
return None
+ sup: Optional[Type[Any]]
+
if cls.__dict__.get("__abstract__", False):
- for sup in cls.__bases__:
- sup = _resolve_for_abstract_or_classical(sup)
+ for base_ in cls.__bases__:
+ sup = _resolve_for_abstract_or_classical(base_)
if sup is not None:
return sup
else:
return cls
-def _get_immediate_cls_attr(cls, attrname, strict=False):
+def _get_immediate_cls_attr(
+ cls: Type[Any], attrname: str, strict: bool = False
+) -> Optional[Any]:
"""return an attribute of the class that is either present directly
on the class, e.g. not on a superclass, or is from a superclass but
this superclass is a non-mapped mixin, that is, not a descendant of
return getattr(cls, attrname)
for base in cls.__mro__[1:]:
- _is_classicial_inherits = _dive_for_cls_manager(base)
+ _is_classicial_inherits = _dive_for_cls_manager(base) is not None
if attrname in base.__dict__ and (
base is cls
return None
-def _dive_for_cls_manager(cls):
+def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]:
# because the class manager registration is pluggable,
# we need to do the search for every class in the hierarchy,
# rather than just a simple "cls._sa_class_manager"
- # python 2 old style class
- if not hasattr(cls, "__mro__"):
- return None
-
for base in cls.__mro__:
- manager = attributes.opt_manager_of_class(base)
+ manager: Optional[ClassManager[_O]] = attributes.opt_manager_of_class(
+ base
+ )
if manager:
return manager
return None
-def _as_declarative(registry, cls, dict_):
+def _as_declarative(
+ registry: _RegistryType, cls: Type[Any], dict_: _ClassDict
+) -> Optional[_MapperConfig]:
# declarative scans the class for attributes. no table or mapper
# args passed separately.
-
return _MapperConfig.setup_mapping(registry, cls, dict_, None, {})
-def _mapper(registry, cls, table, mapper_kw):
+def _mapper(
+ registry: _RegistryType,
+ cls: Type[_O],
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
+) -> Mapper[_O]:
_ImperativeMapperConfig(registry, cls, table, mapper_kw)
- return cls.__mapper__
+ return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__
@util.preload_module("sqlalchemy.orm.decl_api")
return isinstance(obj, (declared_attr, util.classproperty))
-def _check_declared_props_nocascade(obj, name, cls):
+def _check_declared_props_nocascade(
+ obj: Any, name: str, cls: Type[_O]
+) -> bool:
if _is_declarative_props(obj):
if getattr(obj, "_cascading", False):
util.warn(
"__weakref__",
)
+ cls: Type[Any]
+ classname: str
+ properties: util.OrderedDict[str, MapperProperty[Any]]
+ declared_attr_reg: Dict[declared_attr[Any], Any]
+
@classmethod
- def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
+ def setup_mapping(
+ cls,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ dict_: _ClassDict,
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
+ ) -> Optional[_MapperConfig]:
manager = attributes.opt_manager_of_class(cls)
if manager and manager.class_ is cls_:
raise exc.InvalidRequestError(
)
if cls_.__dict__.get("__abstract__", False):
- return
+ return None
defer_map = _get_immediate_cls_attr(
cls_, "_sa_decl_prepare_nocascade", strict=True
) or hasattr(cls_, "_sa_decl_prepare")
if defer_map:
- cfg_cls = _DeferredMapperConfig
+ return _DeferredMapperConfig(
+ registry, cls_, dict_, table, mapper_kw
+ )
else:
- cfg_cls = _ClassScanMapperConfig
-
- return cfg_cls(registry, cls_, dict_, table, mapper_kw)
+ return _ClassScanMapperConfig(
+ registry, cls_, dict_, table, mapper_kw
+ )
def __init__(
self,
registry: _RegistryType,
cls_: Type[Any],
- mapper_kw: Dict[str, Any],
+ mapper_kw: _MapperKwArgs,
):
self.cls = util.assert_arg_type(cls_, type, "cls_")
self.classname = cls_.__name__
"Mapper." % self.cls
)
- def set_cls_attribute(self, attrname, value):
+ def set_cls_attribute(self, attrname: str, value: _T) -> _T:
manager = instrumentation.manager_of_class(self.cls)
manager.install_member(attrname, value)
return value
- def _early_mapping(self, mapper_kw):
+ def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]:
+ raise NotImplementedError()
+
+ def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
self.map(mapper_kw)
def __init__(
self,
- registry,
- cls_,
- table,
- mapper_kw,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
):
super(_ImperativeMapperConfig, self).__init__(
registry, cls_, mapper_kw
self._early_mapping(mapper_kw)
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
mapper_cls = mapper
return self.set_cls_attribute(
mapper_cls(self.cls, self.local_table, **mapper_kw),
)
- def _setup_inheritance(self, mapper_kw):
+ def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
cls = self.cls
inherits = mapper_kw.get("inherits", None)
# since we search for classical mappings now, search for
# multiple mapped bases as well and raise an error.
inherits_search = []
- for c in cls.__bases__:
- c = _resolve_for_abstract_or_classical(c)
+ for base_ in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(base_)
if c is None:
continue
if _declared_mapping_info(
"inherits",
)
+ registry: _RegistryType
+ clsdict_view: _ClassDict
+ collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_attributes: Dict[str, Any]
+ local_table: Optional[FromClause]
+ persist_selectable: Optional[FromClause]
+ declared_columns: util.OrderedSet[Column[Any]]
+ column_copies: Dict[
+ Union[MappedColumn[Any], Column[Any]],
+ Union[MappedColumn[Any], Column[Any]],
+ ]
+ tablename: Optional[str]
+ mapper_args: Mapping[str, Any]
+ table_args: Optional[_TableArgsType]
+ mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
+ inherits: Optional[Type[Any]]
+
def __init__(
self,
- registry,
- cls_,
- dict_,
- table,
- mapper_kw,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ dict_: _ClassDict,
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
):
# grab class dict before the instrumentation manager has been added.
self.persist_selectable = None
self.collected_attributes = {}
- self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
+ self.collected_annotations = {}
self.declared_columns = util.OrderedSet()
self.column_copies = {}
self._early_mapping(mapper_kw)
- def _setup_declared_events(self):
+ def _setup_declared_events(self) -> None:
if _get_immediate_cls_attr(self.cls, "__declare_last__"):
@event.listens_for(mapper, "after_configured")
- def after_configured():
- self.cls.__declare_last__()
+ def after_configured() -> None:
+ cast(
+ "_DeclMappedClassProtocol[Any]", self.cls
+ ).__declare_last__()
if _get_immediate_cls_attr(self.cls, "__declare_first__"):
@event.listens_for(mapper, "before_configured")
- def before_configured():
- self.cls.__declare_first__()
-
- def _cls_attr_override_checker(self, cls):
+ def before_configured() -> None:
+ cast(
+ "_DeclMappedClassProtocol[Any]", self.cls
+ ).__declare_first__()
+
+ def _cls_attr_override_checker(
+ self, cls: Type[_O]
+ ) -> Callable[[str, Any], bool]:
"""Produce a function that checks if a class has overridden an
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__", None
+ cls, "__sa_dataclass_metadata_key__"
)
if sa_dataclass_metadata_key is None:
- def attribute_is_overridden(key, obj):
+ def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
else:
absent = object()
- def attribute_is_overridden(key, obj):
+ def attribute_is_overridden(key: str, obj: Any) -> bool:
if _is_declarative_props(obj):
obj = obj.fget
]
)
- def _cls_attr_resolver(self, cls):
+ def _cls_attr_resolver(
+ self, cls: Type[Any]
+ ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]:
"""produce a function to iterate the "attributes" of a class,
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__", None
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
)
cls_annotations = util.get_annotations(cls)
)
if sa_dataclass_metadata_key is None:
- def local_attributes_for_class():
+ def local_attributes_for_class() -> Iterable[
+ Tuple[str, Any, Any, bool]
+ ]:
return (
(
name,
field.name: field for field in util.local_dataclass_fields(cls)
}
- def local_attributes_for_class():
+ fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key
+
+ def local_attributes_for_class() -> Iterable[
+ Tuple[str, Any, Any, bool]
+ ]:
for name in names:
field = dataclass_fields.get(name, None)
if field and sa_dataclass_metadata_key in field.metadata:
yield field.name, _as_dc_declaredattr(
- field.metadata, sa_dataclass_metadata_key
+ field.metadata, fixed_sa_dataclass_metadata_key
), cls_annotations.get(field.name), True
else:
yield name, cls_vars.get(name), cls_annotations.get(
return local_attributes_for_class
- def _scan_attributes(self):
+ def _scan_attributes(self) -> None:
cls = self.cls
+ cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
clsdict_view = self.clsdict_view
collected_attributes = self.collected_attributes
column_copies = self.column_copies
mapper_args_fn = None
table_args = inherited_table_args = None
+
tablename = None
fixed_table = "__table__" in clsdict_view
# make a copy of it so a class-level dictionary
# is not overwritten when we update column-based
# arguments.
- def mapper_args_fn():
- return dict(cls.__mapper_args__)
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
elif name == "__tablename__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
)
if not tablename and (not class_mapped or check_decl):
- tablename = cls.__tablename__
+ tablename = cls_as_Decl.__tablename__
elif name == "__table_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
)
if not table_args and (not class_mapped or check_decl):
- table_args = cls.__table_args__
+ table_args = cls_as_Decl.__table_args__
if not isinstance(
table_args, (tuple, dict, type(None))
):
# or similar. note there is no known case that
# produces nested proxies, so we are only
# looking one level deep right now.
+
if (
isinstance(ret, InspectionAttr)
- and ret._is_internal_proxy
+ and attr_is_internal_proxy(ret)
and not isinstance(
ret.original_property, MapperProperty
)
collected_attributes[name] = column_copies[
obj
] = ret
+
if (
isinstance(ret, (Column, MapperProperty))
and ret.doc is None
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
- def _warn_for_decl_attributes(self, cls, key, c):
+ def _warn_for_decl_attributes(
+ self, cls: Type[Any], key: str, c: Any
+ ) -> None:
if isinstance(c, expression.ColumnClause):
util.warn(
f"Attribute '{key}' on class {cls} appears to "
)
def _produce_column_copies(
- self, attributes_for_class, attribute_is_overridden
- ):
+ self,
+ attributes_for_class: Callable[
+ [], Iterable[Tuple[str, Any, Any, bool]]
+ ],
+ attribute_is_overridden: Callable[[str, Any], bool],
+ ) -> None:
cls = self.cls
dict_ = self.clsdict_view
collected_attributes = self.collected_attributes
continue
elif name not in dict_ and not (
"__table__" in dict_
- and (obj.name or name) in dict_["__table__"].c
+ and (getattr(obj, "name", None) or name)
+ in dict_["__table__"].c
):
if obj.foreign_keys:
for fk in obj.foreign_keys:
setattr(cls, name, copy_)
- def _extract_mappable_attributes(self):
+ def _extract_mappable_attributes(self) -> None:
cls = self.cls
collected_attributes = self.collected_attributes
"declarative base class."
)
elif isinstance(value, Column):
- _undefer_column_name(k, self.column_copies.get(value, value))
+ _undefer_column_name(
+ k, self.column_copies.get(value, value) # type: ignore
+ )
elif isinstance(value, _IntrospectsAnnotations):
annotation, is_dataclass = self.collected_annotations.get(
- k, (None, None)
+ k, (None, False)
)
value.declarative_scan(
self.registry, cls, k, annotation, is_dataclass
)
our_stuff[k] = value
- def _extract_declared_columns(self):
+ def _extract_declared_columns(self) -> None:
our_stuff = self.properties
# extract columns from the class dict
% (self.classname, name, (", ".join(sorted(keys))))
)
- def _setup_table(self, table=None):
+ def _setup_table(self, table: Optional[FromClause] = None) -> None:
cls = self.cls
+ cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
tablename = self.tablename
table_args = self.table_args
clsdict_view = self.clsdict_view
if "__table__" not in clsdict_view and table is None:
if hasattr(cls, "__table_cls__"):
- table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+ table_cls = cast(
+ Type[Table],
+ util.unbound_method_to_callable(cls.__table_cls__), # type: ignore # noqa: E501
+ )
else:
table_cls = Table
if tablename is not None:
- args, table_kw = (), {}
+ args: Tuple[Any, ...] = ()
+ table_kw: Dict[str, Any] = {}
+
if table_args:
if isinstance(table_args, dict):
table_kw = table_args
)
else:
if table is None:
- table = cls.__table__
+ table = cls_as_Decl.__table__
if declared_columns:
for c in declared_columns:
if not table.c.contains_column(c):
"Can't add additional column %r when "
"specifying __table__" % c.key
)
+
self.local_table = table
- def _metadata_for_cls(self, manager):
+ def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData:
if hasattr(self.cls, "metadata"):
- return self.cls.metadata
+ return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata
else:
return manager.registry.metadata
- def _setup_inheritance(self, mapper_kw):
+ def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
table = self.local_table
cls = self.cls
table_args = self.table_args
# since we search for classical mappings now, search for
# multiple mapped bases as well and raise an error.
inherits_search = []
- for c in cls.__bases__:
- c = _resolve_for_abstract_or_classical(c)
+ for base_ in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(base_)
if c is None:
continue
if _declared_mapping_info(
"table-mapped class." % cls
)
elif self.inherits:
- inherited_mapper = _declared_mapping_info(self.inherits)
- inherited_table = inherited_mapper.local_table
- inherited_persist_selectable = inherited_mapper.persist_selectable
+ inherited_mapper_or_config = _declared_mapping_info(self.inherits)
+ assert inherited_mapper_or_config is not None
+ inherited_table = inherited_mapper_or_config.local_table
+ inherited_persist_selectable = (
+ inherited_mapper_or_config.persist_selectable
+ )
if table is None:
# single table inheritance.
"Can't place __table_args__ on an inherited class "
"with no table."
)
+
# add any columns declared here to the inherited table.
- for c in declared_columns:
- if c.name in inherited_table.c:
- if inherited_table.c[c.name] is c:
+ if declared_columns and not isinstance(inherited_table, Table):
+ raise exc.ArgumentError(
+ f"Can't declare columns on single-table-inherited "
+ f"subclass {self.cls}; superclass {self.inherits} "
+ "is not mapped to a Table"
+ )
+
+ for col in declared_columns:
+ assert inherited_table is not None
+ if col.name in inherited_table.c:
+ if inherited_table.c[col.name] is col:
continue
raise exc.ArgumentError(
"Column '%s' on class %s conflicts with "
"existing column '%s'"
- % (c, cls, inherited_table.c[c.name])
+ % (col, cls, inherited_table.c[col.name])
)
- if c.primary_key:
+ if col.primary_key:
raise exc.ArgumentError(
"Can't place primary key columns on an inherited "
"class with no table."
)
- inherited_table.append_column(c)
+
+ if TYPE_CHECKING:
+ assert isinstance(inherited_table, Table)
+
+ inherited_table.append_column(col)
if (
inherited_persist_selectable is not None
and inherited_persist_selectable is not inherited_table
):
- inherited_persist_selectable._refresh_for_new_column(c)
+ inherited_persist_selectable._refresh_for_new_column(
+ col
+ )
- def _prepare_mapper_arguments(self, mapper_kw):
+ def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None:
properties = self.properties
if self.mapper_args_fn:
# not mapped on the parent class, to avoid
# mapping columns specific to sibling/nephew classes
inherited_mapper = _declared_mapping_info(self.inherits)
+ assert isinstance(inherited_mapper, Mapper)
inherited_table = inherited_mapper.local_table
if "exclude_properties" not in mapper_args:
result_mapper_args["properties"] = properties
self.mapper_args = result_mapper_args
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
self._prepare_mapper_arguments(mapper_kw)
if hasattr(self.cls, "__mapper_cls__"):
- mapper_cls = util.unbound_method_to_callable(
- self.cls.__mapper_cls__
+ mapper_cls = cast(
+ "Type[Mapper[Any]]",
+ util.unbound_method_to_callable(
+ self.cls.__mapper_cls__ # type: ignore
+ ),
)
else:
mapper_cls = mapper
@util.preload_module("sqlalchemy.orm.decl_api")
-def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
+def _as_dc_declaredattr(
+ field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str
+) -> Any:
# wrap lambdas inside dataclass fields inside an ad-hoc declared_attr.
# we can't write it because field.metadata is immutable :( so we have
# to go through extra trouble to compare these
class _DeferredMapperConfig(_ClassScanMapperConfig):
- _configs = util.OrderedDict()
+ _cls: weakref.ref[Type[Any]]
+
+ _configs: util.OrderedDict[
+ weakref.ref[Type[Any]], _DeferredMapperConfig
+ ] = util.OrderedDict()
- def _early_mapping(self, mapper_kw):
+ def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
pass
- @property
- def cls(self):
- return self._cls()
+ # mypy disallows plain property override of variable
+ @property # type: ignore
+ def cls(self) -> Type[Any]: # type: ignore
+ return self._cls() # type: ignore
@cls.setter
- def cls(self, class_):
+ def cls(self, class_: Type[Any]) -> None:
self._cls = weakref.ref(class_, self._remove_config_cls)
self._configs[self._cls] = self
@classmethod
- def _remove_config_cls(cls, ref):
+ def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None:
cls._configs.pop(ref, None)
@classmethod
- def has_cls(cls, class_):
+ def has_cls(cls, class_: Type[Any]) -> bool:
# 2.6 fails on weakref if class_ is an old style class
return isinstance(class_, type) and weakref.ref(class_) in cls._configs
@classmethod
- def raise_unmapped_for_cls(cls, class_):
+ def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn:
if hasattr(class_, "_sa_raise_deferred_config"):
- class_._sa_raise_deferred_config()
+ class_._sa_raise_deferred_config() # type: ignore
raise orm_exc.UnmappedClassError(
class_,
- msg="Class %s has a deferred mapping on it. It is not yet "
- "usable as a mapped class." % orm_exc._safe_cls_name(class_),
+ msg=(
+ f"Class {orm_exc._safe_cls_name(class_)} has a deferred "
+ "mapping on it. It is not yet usable as a mapped class."
+ ),
)
@classmethod
- def config_for_cls(cls, class_):
+ def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig:
return cls._configs[weakref.ref(class_)]
@classmethod
- def classes_for_base(cls, base_cls, sort=True):
+ def classes_for_base(
+ cls, base_cls: Type[Any], sort: bool = True
+ ) -> List[_DeferredMapperConfig]:
classes_for_base = [
m
for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
- tuples = []
+ tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = []
for m_cls in all_m_by_cls:
tuples.extend(
(all_m_by_cls[base_cls], all_m_by_cls[m_cls])
)
return list(topological.sort(tuples, classes_for_base))
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
self._configs.pop(self._cls, None)
return super(_DeferredMapperConfig, self).map(mapper_kw)
-def _add_attribute(cls, key, value):
+def _add_attribute(
+ cls: Type[Any], key: str, value: MapperProperty[Any]
+) -> None:
"""add an attribute to an existing declarative class.
This runs through the logic to determine MapperProperty,
"""
if "__mapper__" in cls.__dict__:
+ mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls)
if isinstance(value, Column):
_undefer_column_name(key, value)
- cls.__table__.append_column(value, replace_existing=True)
- cls.__mapper__.add_property(key, value)
+ # TODO: raise if __table__ is not a Table
+ mapped_cls.__table__.append_column(value, replace_existing=True)
+ mapped_cls.__mapper__.add_property(key, value)
elif isinstance(value, _MapsColumns):
mp = value.mapper_property_to_assign
for col in value.columns_to_assign:
_undefer_column_name(key, col)
- cls.__table__.append_column(col, replace_existing=True)
+ # TODO: raise if __table__ is not a Table
+ mapped_cls.__table__.append_column(col, replace_existing=True)
if not mp:
- cls.__mapper__.add_property(key, col)
+ mapped_cls.__mapper__.add_property(key, col)
if mp:
- cls.__mapper__.add_property(key, mp)
+ mapped_cls.__mapper__.add_property(key, mp)
elif isinstance(value, MapperProperty):
- cls.__mapper__.add_property(key, value)
+ mapped_cls.__mapper__.add_property(key, value)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = Synonym(value.key)
- cls.__mapper__.add_property(key, value)
+ mapped_cls.__mapper__.add_property(key, value)
else:
type.__setattr__(cls, key, value)
- cls.__mapper__._expire_memoizations()
+ mapped_cls.__mapper__._expire_memoizations()
else:
type.__setattr__(cls, key, value)
-def _del_attribute(cls, key):
+def _del_attribute(cls: Type[Any], key: str) -> None:
if (
"__mapper__" in cls.__dict__
and key in cls.__dict__
- and not cls.__mapper__._dispose_called
+ and not cast(
+ "_DeclMappedClassProtocol[Any]", cls
+ ).__mapper__._dispose_called
):
value = cls.__dict__[key]
if isinstance(
)
else:
type.__delattr__(cls, key)
- cls.__mapper__._expire_memoizations()
+ cast(
+ "_DeclMappedClassProtocol[Any]", cls
+ ).__mapper__._expire_memoizations()
else:
type.__delattr__(cls, key)
-def _declarative_constructor(self, **kwargs):
+def _declarative_constructor(self: Any, **kwargs: Any) -> None:
"""A simple constructor that allows initialization from kwargs.
Sets attributes on the constructed instance using the names and
_declarative_constructor.__name__ = "__init__"
-def _undefer_column_name(key, column):
+def _undefer_column_name(key: str, column: Column[Any]) -> None:
if column.key is None:
column.key = key
if column.name is None:
from typing import Any
from typing import Callable
from typing import List
+from typing import NoReturn
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import attributes
from . import util as orm_util
+from .base import LoaderCallableStatus
from .base import Mapped
+from .base import PassiveFlag
+from .base import SQLORMOperations
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .. import sql
from .. import util
from ..sql import expression
-from ..sql import operators
+from ..sql.elements import BindParameter
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _InstanceDict
+ from ._typing import _RegistryType
+ from .attributes import History
from .attributes import InstrumentedAttribute
+ from .attributes import QueryableAttribute
+ from .context import ORMCompileState
+ from .mapper import Mapper
+ from .properties import ColumnProperty
from .properties import MappedColumn
+ from .state import InstanceState
+ from ..engine.base import Connection
+ from ..engine.row import Row
+ from ..sql._typing import _DMLColumnArgument
from ..sql._typing import _InfoType
+ from ..sql.elements import ClauseList
+ from ..sql.elements import ColumnElement
from ..sql.schema import Column
+ from ..sql.selectable import Select
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import CallableReference
+ from ..util.typing import DescriptorReference
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
class _CompositeClassProto(Protocol):
+ def __init__(self, *args: Any):
+ ...
+
def __composite_values__(self) -> Tuple[Any, ...]:
...
""":class:`.MapperProperty` which proxies access to a
user-defined descriptor."""
- doc = None
+ doc: Optional[str] = None
uses_objects = False
_links_to_entity = False
- def instrument_class(self, mapper):
+ descriptor: DescriptorReference[Any]
+
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ raise NotImplementedError()
+
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
prop = self
- class _ProxyImpl:
+ class _ProxyImpl(attributes.AttributeImpl):
accepts_scalar_loader = False
load_on_unexpire = True
collection = False
@property
- def uses_objects(self):
+ def uses_objects(self) -> bool: # type: ignore
return prop.uses_objects
- def __init__(self, key):
+ def __init__(self, key: str):
self.key = key
- if hasattr(prop, "get_history"):
-
- def get_history(
- self, state, dict_, passive=attributes.PASSIVE_OFF
- ):
- return prop.get_history(state, dict_, passive)
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ return prop.get_history(state, dict_, passive)
if self.descriptor is None:
desc = getattr(mapper.class_, self.key, None)
if self.descriptor is None:
- def fset(obj, value):
+ def fset(obj: Any, value: Any) -> None:
setattr(obj, self.name, value)
- def fdel(obj):
+ def fdel(obj: Any) -> None:
delattr(obj, self.name)
- def fget(obj):
+ def fget(obj: Any) -> Any:
return getattr(obj, self.name)
self.descriptor = property(fget=fget, fset=fset, fdel=fdel)
]
+_CC = TypeVar("_CC", bound=_CompositeClassProto)
+
+
class Composite(
- _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T]
+ _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC]
):
"""Defines a "composite" mapped attribute, representing a collection
of columns as one attribute.
"""
- composite_class: Union[
- Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]]
+ composite_class: Union[Type[_CC], Callable[..., _CC]]
+ attrs: Tuple[_CompositeAttrType[Any], ...]
+
+ _generated_composite_accessor: CallableReference[
+ Optional[Callable[[_CC], Tuple[Any, ...]]]
]
- attrs: Tuple[_CompositeAttrType, ...]
+
+ comparator_factory: Type[Comparator[_CC]]
def __init__(
self,
- class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None,
- *attrs: _CompositeAttrType,
+ class_: Union[
+ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+ ] = None,
+ *attrs: _CompositeAttrType[Any],
active_history: bool = False,
deferred: bool = False,
group: Optional[str] = None,
- comparator_factory: Optional[Type[Comparator]] = None,
+ comparator_factory: Optional[Type[Comparator[_CC]]] = None,
info: Optional[_InfoType] = None,
):
super().__init__()
# will initialize within declarative_scan
self.composite_class = None # type: ignore
else:
- self.composite_class = class_
+ self.composite_class = class_ # type: ignore
self.attrs = attrs
self.active_history = active_history
)
self._generated_composite_accessor = None
if info is not None:
- self.info = info
+ self.info.update(info)
util.set_creation_order(self)
self._create_descriptor()
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
super().instrument_class(mapper)
self._setup_event_handlers()
- def _composite_values_from_instance(
- self, value: _CompositeClassProto
- ) -> Tuple[Any, ...]:
+ def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]:
if self._generated_composite_accessor:
return self._generated_composite_accessor(value)
else:
else:
return accessor()
- def do_init(self):
+ def do_init(self) -> None:
"""Initialization which occurs after the :class:`.Composite`
has been associated with its parent mapper.
_COMPOSITE_FGET = object()
- def _create_descriptor(self):
+ def _create_descriptor(self) -> None:
"""Create the Python descriptor that will serve as
the access point on instances of the mapped class.
"""
- def fget(instance):
+ def fget(instance: Any) -> Any:
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
return dict_.get(self.key, None)
- def fset(instance, value):
+ def fset(instance: Any, value: Any) -> None:
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
attr = state.manager[self.key]
- previous = dict_.get(self.key, attributes.NO_VALUE)
+ previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE)
for fn in attr.dispatch.set:
value = fn(state, value, previous, attr.impl)
dict_[self.key] = value
):
setattr(instance, key, value)
- def fdel(instance):
+ def fdel(instance: Any) -> None:
state = attributes.instance_state(instance)
dict_ = attributes.instance_dict(instance)
- previous = dict_.pop(self.key, attributes.NO_VALUE)
+ previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE)
attr = state.manager[self.key]
attr.dispatch.remove(state, previous, attr.impl)
for key in self._attribute_keys:
@util.preload_module("sqlalchemy.orm.properties")
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
MappedColumn = util.preloaded.orm_properties.MappedColumn
argument = _extract_mapped_subtype(
@util.preload_module("sqlalchemy.orm.properties")
@util.preload_module("sqlalchemy.orm.decl_base")
- def _setup_for_dataclass(self, registry, cls, key):
+ def _setup_for_dataclass(
+ self, registry: _RegistryType, cls: Type[Any], key: str
+ ) -> None:
MappedColumn = util.preloaded.orm_properties.MappedColumn
decl_base = util.preloaded.orm_decl_base
self._generated_composite_accessor = getter
@util.memoized_property
- def _comparable_elements(self):
+ def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
return [getattr(self.parent.class_, prop.key) for prop in self.props]
@util.memoized_property
@util.preload_module("orm.properties")
- def props(self):
+ def props(self) -> Sequence[MapperProperty[Any]]:
props = []
MappedColumn = util.preloaded.orm_properties.MappedColumn
elif isinstance(attr, attributes.InstrumentedAttribute):
prop = attr.property
else:
+ prop = None
+
+ if not isinstance(prop, MapperProperty):
raise sa_exc.ArgumentError(
"Composite expects Column objects or mapped "
- "attributes/attribute names as arguments, got: %r"
- % (attr,)
+ f"attributes/attribute names as arguments, got: {attr!r}"
)
+
props.append(prop)
return props
- @property
+ @util.non_memoized_property
@util.preload_module("orm.properties")
- def columns(self):
+ def columns(self) -> Sequence[Column[Any]]:
MappedColumn = util.preloaded.orm_properties.MappedColumn
return [
a.column if isinstance(a, MappedColumn) else a
]
@property
- def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
+ def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]:
return self
@property
- def columns_to_assign(self) -> List[schema.Column]:
+ def columns_to_assign(self) -> List[schema.Column[Any]]:
return [c for c in self.columns if c.table is None]
- def _setup_arguments_on_columns(self):
+ @util.preload_module("orm.properties")
+ def _setup_arguments_on_columns(self) -> None:
"""Propagate configuration arguments made on this composite
to the target columns, for those that apply.
"""
+ ColumnProperty = util.preloaded.orm_properties.ColumnProperty
+
for prop in self.props:
- prop.active_history = self.active_history
+ if not isinstance(prop, ColumnProperty):
+ continue
+ else:
+ cprop = prop
+
+ cprop.active_history = self.active_history
if self.deferred:
- prop.deferred = self.deferred
- prop.strategy_key = (("deferred", True), ("instrument", True))
- prop.group = self.group
+ cprop.deferred = self.deferred
+ cprop.strategy_key = (("deferred", True), ("instrument", True))
+ cprop.group = self.group
- def _setup_event_handlers(self):
+ def _setup_event_handlers(self) -> None:
"""Establish events that populate/expire the composite attribute."""
- def load_handler(state, context):
+ def load_handler(
+ state: InstanceState[Any], context: ORMCompileState
+ ) -> None:
_load_refresh_handler(state, context, None, is_refresh=False)
- def refresh_handler(state, context, to_load):
+ def refresh_handler(
+ state: InstanceState[Any],
+ context: ORMCompileState,
+ to_load: Optional[Sequence[str]],
+ ) -> None:
# note this corresponds to sqlalchemy.ext.mutable load_attrs()
if not to_load or (
).intersection(to_load):
_load_refresh_handler(state, context, to_load, is_refresh=True)
- def _load_refresh_handler(state, context, to_load, is_refresh):
+ def _load_refresh_handler(
+ state: InstanceState[Any],
+ context: ORMCompileState,
+ to_load: Optional[Sequence[str]],
+ is_refresh: bool,
+ ) -> None:
dict_ = state.dict
# if context indicates we are coming from the
*[state.dict[key] for key in self._attribute_keys]
)
- def expire_handler(state, keys):
+ def expire_handler(
+ state: InstanceState[Any], keys: Optional[Sequence[str]]
+ ) -> None:
if keys is None or set(self._attribute_keys).intersection(keys):
state.dict.pop(self.key, None)
- def insert_update_handler(mapper, connection, state):
+ def insert_update_handler(
+ mapper: Mapper[Any],
+ connection: Connection,
+ state: InstanceState[Any],
+ ) -> None:
"""After an insert or update, some columns may be expired due
to server side defaults, or re-populated due to client side
defaults. Pop out the composite value here so that it
# TODO: need a deserialize hook here
@util.memoized_property
- def _attribute_keys(self):
+ def _attribute_keys(self) -> Sequence[str]:
return [prop.key for prop in self.props]
- def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
"""Provided for userland code that uses attributes.get_history()."""
- added = []
- deleted = []
+ added: List[Any] = []
+ deleted: List[Any] = []
has_history = False
for prop in self.props:
else:
return attributes.History((), [self.composite_class(*added)], ())
- def _comparator_factory(self, mapper):
+ def _comparator_factory(
+ self, mapper: Mapper[Any]
+ ) -> Composite.Comparator[_CC]:
return self.comparator_factory(self, mapper)
- class CompositeBundle(orm_util.Bundle):
- def __init__(self, property_, expr):
+ class CompositeBundle(orm_util.Bundle[_T]):
+ def __init__(
+ self,
+ property_: Composite[_T],
+ expr: ClauseList,
+ ):
self.property = property_
super().__init__(property_.key, *expr)
- def create_row_processor(self, query, procs, labels):
- def proc(row):
+ def create_row_processor(
+ self,
+ query: Select[Any],
+ procs: Sequence[Callable[[Row[Any]], Any]],
+ labels: Sequence[str],
+ ) -> Callable[[Row[Any]], Any]:
+ def proc(row: Row[Any]) -> Any:
return self.property.composite_class(
*[proc(row) for proc in procs]
)
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
+ prop: RODescriptorReference[Composite[_PT]]
+
@util.memoized_property
- def clauses(self):
+ def clauses(self) -> ClauseList:
return expression.ClauseList(
group=False, *self._comparable_elements
)
- def __clause_element__(self):
+ def __clause_element__(self) -> Composite.CompositeBundle[_PT]:
return self.expression
@util.memoized_property
- def expression(self):
+ def expression(self) -> Composite.CompositeBundle[_PT]:
clauses = self.clauses._annotate(
{
"parententity": self._parententity,
)
return Composite.CompositeBundle(self.prop, clauses)
- def _bulk_update_tuples(self, value):
- if isinstance(value, sql.elements.BindParameter):
+ def _bulk_update_tuples(
+ self, value: Any
+ ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
+ if isinstance(value, BindParameter):
value = value.value
+ values: Sequence[Any]
+
if value is None:
values = [None for key in self.prop._attribute_keys]
- elif isinstance(value, self.prop.composite_class):
+ elif isinstance(self.prop.composite_class, type) and isinstance(
+ value, self.prop.composite_class
+ ):
values = self.prop._composite_values_from_instance(value)
else:
raise sa_exc.ArgumentError(
% (self.prop, value)
)
- return zip(self._comparable_elements, values)
+ return list(zip(self._comparable_elements, values))
@util.memoized_property
- def _comparable_elements(self):
+ def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
if self._adapt_to_entity:
return [
getattr(self._adapt_to_entity.entity, prop.key)
else:
return self.prop._comparable_elements
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ values: Sequence[Any]
if other is None:
values = [None] * len(self.prop._comparable_elements)
else:
a == b for a, b in zip(self.prop._comparable_elements, values)
]
if self._adapt_to_entity:
+ assert self.adapter is not None
comparisons = [self.adapter(x) for x in comparisons]
return sql.and_(*comparisons)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
return sql.not_(self.__eq__(other))
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent.class_.__name__) + "." + self.key
"""
- def _comparator_factory(self, mapper):
+ def _comparator_factory(
+ self, mapper: Mapper[Any]
+ ) -> Type[PropComparator[_T]]:
+
comparator_callable = None
for m in self.parent.iterate_to_root():
p = m._props[self.key]
- if not isinstance(p, ConcreteInheritedProperty):
+ if getattr(p, "comparator_factory", None) is not None:
comparator_callable = p.comparator_factory
break
- return comparator_callable
+ assert comparator_callable is not None
+ return comparator_callable(p, mapper) # type: ignore
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- def warn():
+ def warn() -> NoReturn:
raise AttributeError(
"Concrete %s does not implement "
"attribute %r at the instance level. Add "
)
class NoninheritedConcreteProp:
- def __set__(s, obj, value):
+ def __set__(s: Any, obj: Any, value: Any) -> NoReturn:
warn()
- def __delete__(s, obj):
+ def __delete__(s: Any, obj: Any) -> NoReturn:
warn()
- def __get__(s, obj, owner):
+ def __get__(s: Any, obj: Any, owner: Any) -> Any:
if obj is None:
return self.descriptor
warn()
"""
+ comparator_factory: Optional[Type[PropComparator[_T]]]
+
def __init__(
self,
- name,
- map_column=None,
- descriptor=None,
- comparator_factory=None,
- doc=None,
- info=None,
+ name: str,
+ map_column: Optional[bool] = None,
+ descriptor: Optional[Any] = None,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
):
super().__init__()
self.map_column = map_column
self.descriptor = descriptor
self.comparator_factory = comparator_factory
- self.doc = doc or (descriptor and descriptor.__doc__) or None
+ if doc:
+ self.doc = doc
+ elif descriptor and descriptor.__doc__:
+ self.doc = descriptor.__doc__
+ else:
+ self.doc = None
if info:
- self.info = info
+ self.info.update(info)
util.set_creation_order(self)
- @property
- def uses_objects(self):
- return getattr(self.parent.class_, self.name).impl.uses_objects
+ if not TYPE_CHECKING:
+
+ @property
+ def uses_objects(self) -> bool:
+ return getattr(self.parent.class_, self.name).impl.uses_objects
# TODO: when initialized, check _proxied_object,
# emit a warning if its not a column-based property
@util.memoized_property
- def _proxied_object(self):
+ def _proxied_object(
+ self,
+ ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]:
attr = getattr(self.parent.class_, self.name)
if not hasattr(attr, "property") or not isinstance(
attr.property, MapperProperty
# hybrid or association proxy
if isinstance(attr, attributes.QueryableAttribute):
return attr.comparator
- elif isinstance(attr, operators.ColumnOperators):
+ elif isinstance(attr, SQLORMOperations):
+ # association proxy comes here
return attr
raise sa_exc.InvalidRequestError(
)
return attr.property
- def _comparator_factory(self, mapper):
+ def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]:
prop = self._proxied_object
if isinstance(prop, MapperProperty):
else:
return prop
- def get_history(self, *arg, **kw):
- attr = getattr(self.parent.class_, self.name)
- return attr.impl.get_history(*arg, **kw)
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name)
+ return attr.impl.get_history(state, dict_, passive=passive)
@util.preload_module("sqlalchemy.orm.properties")
- def set_parent(self, parent, init):
+ def set_parent(self, parent: Mapper[Any], init: bool) -> None:
properties = util.preloaded.orm_properties
if self.map_column:
"%r for column %r"
% (self.key, self.name, self.name, self.key)
)
- p = properties.ColumnProperty(
+ p: ColumnProperty[Any] = properties.ColumnProperty(
parent.persist_selectable.c[self.key]
)
parent._configure_property(self.name, p, init=init, setparent=True)
from __future__ import annotations
+from typing import Any
+from typing import Optional
+from typing import overload
+from typing import TYPE_CHECKING
+from typing import Union
+
from . import attributes
from . import exc as orm_exc
from . import interfaces
from . import strategies
from . import util as orm_util
from .base import object_mapper
+from .base import PassiveFlag
from .query import Query
from .session import object_session
from .. import exc
from .. import log
from .. import util
from ..engine import result
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+ from ._typing import _InstanceDict
+ from .attributes import _AdaptedCollectionProtocol
+ from .attributes import AttributeEventToken
+ from .attributes import CollectionAdapter
+ from .base import LoaderCallableStatus
+ from .state import InstanceState
@log.class_logger
@relationships.Relationship.strategy_for(lazy="dynamic")
-class DynaLoader(strategies.AbstractRelationshipLoader):
+class DynaLoader(strategies.AbstractRelationshipLoader, log.Identified):
def init_class_attribute(self, mapper):
self.is_class_level = True
if not self.uselist:
else:
return self.query_class(self, state)
+ @overload
def get_collection(
self,
- state,
- dict_,
- user_data=None,
- passive=attributes.PASSIVE_NO_INITIALIZE,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Literal[None] = ...,
+ passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: _AdaptedCollectionProtocol = ...,
+ passive: PassiveFlag = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = ...,
+ passive: PassiveFlag = ...,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ ...
+
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
if not passive & attributes.SQL_OK:
data = self._get_collection_history(state, passive).added_items
else:
def set(
self,
- state,
- dict_,
- value,
- initiator=None,
- passive=attributes.PASSIVE_OFF,
- check_old=None,
- pop=False,
- _adapt=True,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ _adapt: bool = True,
+ ) -> None:
if initiator and initiator.parent_token is self.parent_token:
return
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
"""ORM event interfaces.
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
"""SQLAlchemy ORM exceptions."""
from typing import Any
from typing import Optional
+from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from .. import exc as sa_exc
from .. import util
from ..exc import MultipleResultsFound # noqa
from ..exc import NoResultFound # noqa
+if TYPE_CHECKING:
+ from .interfaces import LoaderStrategy
+ from .interfaces import MapperProperty
+ from .state import InstanceState
+
+_T = TypeVar("_T", bound=Any)
NO_STATE = (AttributeError, KeyError)
"""Exception types that may be raised by instrumentation implementations."""
)
UnmappedError.__init__(self, msg)
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return self.__class__, (None, self.args[0])
class UnmappedClassError(UnmappedError):
"""An mapping operation was requested for an unknown class."""
- def __init__(self, cls: Type[object], msg: Optional[str] = None):
+ def __init__(self, cls: Type[_T], msg: Optional[str] = None):
if not msg:
msg = _default_unmapped(cls)
UnmappedError.__init__(self, msg)
"""
@util.preload_module("sqlalchemy.orm.base")
- def __init__(self, state, msg=None):
+ def __init__(self, state: InstanceState[Any], msg: Optional[str] = None):
base = util.preloaded.orm_base
if not msg:
sa_exc.InvalidRequestError.__init__(self, msg)
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return self.__class__, (None, self.args[0])
def __init__(
self,
- applied_to_property_type,
- requesting_property,
- applies_to,
- actual_strategy_type,
- strategy_key,
+ applied_to_property_type: Type[Any],
+ requesting_property: MapperProperty[Any],
+ applies_to: Optional[Type[MapperProperty[Any]]],
+ actual_strategy_type: Optional[Type[LoaderStrategy]],
+ strategy_key: Tuple[Any, ...],
):
if actual_strategy_type is None:
sa_exc.InvalidRequestError.__init__(
% (strategy_key, requesting_property),
)
else:
+ assert applies_to is not None
sa_exc.InvalidRequestError.__init__(
self,
'Can\'t apply "%s" strategy to property "%s", '
)
-def _safe_cls_name(cls):
+def _safe_cls_name(cls: Type[Any]) -> str:
+ cls_name: Optional[str]
try:
cls_name = ".".join((cls.__module__, cls.__name__))
except AttributeError:
@util.preload_module("sqlalchemy.orm.base")
-def _default_unmapped(cls) -> Optional[str]:
+def _default_unmapped(cls: Type[Any]) -> Optional[str]:
base = util.preloaded.orm_base
try:
from __future__ import annotations
from typing import Any
+from typing import cast
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import NoReturn
from typing import Optional
from typing import Set
+from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
import weakref
) -> Optional[_O]:
raise NotImplementedError()
- def keys(self):
+ def keys(self) -> Iterable[_IdentityKeyType[Any]]:
return self._dict.keys()
def values(self) -> Iterable[object]:
class WeakInstanceDict(IdentityMap):
- _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]]
+ _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]]
def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
- state = self._dict[key]
+ state = cast("InstanceState[_O]", self._dict[key])
o = state.obj()
if o is None:
raise KeyError(key)
def contains_state(self, state: InstanceState[Any]) -> bool:
if state.key in self._dict:
+ if TYPE_CHECKING:
+ assert state.key is not None
try:
return self._dict[state.key] is state
except KeyError:
def replace(
self, state: InstanceState[Any]
) -> Optional[InstanceState[Any]]:
+ assert state.key is not None
if state.key in self._dict:
try:
- existing = self._dict[state.key]
+ existing = existing_non_none = self._dict[state.key]
except KeyError:
# catch gc removed the key after we just checked for it
existing = None
else:
- if existing is not state:
- self._manage_removed_state(existing)
+ if existing_non_none is not state:
+ self._manage_removed_state(existing_non_none)
else:
return None
else:
def add(self, state: InstanceState[Any]) -> bool:
key = state.key
+ assert key is not None
# inline of self.__contains__
if key in self._dict:
try:
if key not in self._dict:
return default
try:
- state = self._dict[key]
+ state = cast("InstanceState[_O]", self._dict[key])
except KeyError:
# catch gc removed the key after we just checked for it
return default
return default
return o
- def items(self) -> List[InstanceState[Any]]:
+ def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]:
values = self.all_states()
result = []
for state in values:
value = state.obj()
+ key = state.key
+ assert key is not None
if value is not None:
- result.append((state.key, value))
+ result.append((key, value))
return result
def values(self) -> List[object]:
def _fast_discard(self, state: InstanceState[Any]) -> None:
# used by InstanceState for state being
# GC'ed, inlines _managed_removed_state
+ key = state.key
+ assert key is not None
try:
- st = self._dict[state.key]
+ st = self._dict[key]
except KeyError:
# catch gc removed the key after we just checked for it
pass
else:
if st is state:
- self._dict.pop(state.key, None)
+ self._dict.pop(key, None)
def discard(self, state: InstanceState[Any]) -> None:
self.safe_discard(state)
def safe_discard(self, state: InstanceState[Any]) -> None:
- if state.key in self._dict:
+ key = state.key
+ if key in self._dict:
+ assert key is not None
try:
- st = self._dict[state.key]
+ st = self._dict[key]
except KeyError:
# catch gc removed the key after we just checked for it
pass
else:
if st is state:
- self._dict.pop(state.key, None)
+ self._dict.pop(key, None)
self._manage_removed_state(state)
if TYPE_CHECKING:
from ._typing import _RegistryType
from .attributes import AttributeImpl
- from .attributes import InstrumentedAttribute
+ from .attributes import QueryableAttribute
from .collections import _AdaptedCollectionProtocol
from .collections import _CollectionFactoryType
from .decl_base import _MapperConfig
class ClassManager(
HasMemoized,
- Dict[str, "InstrumentedAttribute[Any]"],
+ Dict[str, "QueryableAttribute[Any]"],
Generic[_O],
EventTarget,
):
factory: Optional[_ManagerFactory]
declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
- registry: Optional[_RegistryType] = None
+
+ registry: _RegistryType
+
+ if not TYPE_CHECKING:
+ # starts as None during setup
+ registry = None
+
+ class_: Type[_O]
_bases: List[ClassManager[Any]]
else:
return default
- def _attr_has_impl(self, key):
+ def _attr_has_impl(self, key: str) -> bool:
"""Return True if the given attribute is fully initialized.
i.e. has an impl.
def dict_getter(self):
return _default_dict_getter
- def instrument_attribute(self, key, inst, propagated=False):
+ def instrument_attribute(
+ self,
+ key: str,
+ inst: QueryableAttribute[Any],
+ propagated: bool = False,
+ ) -> None:
if propagated:
if key in self.local_attrs:
return # don't override local attr with inherited attr
delattr(self.class_, self.MANAGER_ATTR)
def install_descriptor(
- self, key: str, inst: InstrumentedAttribute[Any]
+ self, key: str, inst: QueryableAttribute[Any]
) -> None:
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError(
# InstanceState management
def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O:
- instance = self.class_.__new__(self.class_)
+ # here, we would prefer _O to be bound to "object"
+ # so that mypy sees that __new__ is present. currently
+ # it's bound to Any as there were other problems not having
+ # it that way but these can be revisited
+ instance = self.class_.__new__(self.class_) # type: ignore
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
"""
from typing import Set
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .base import MANYTOONE as MANYTOONE # noqa: F401
from .base import NotExtension as NotExtension # noqa: F401
from .base import ONETOMANY as ONETOMANY # noqa: F401
+from .base import RelationshipDirection as RelationshipDirection # noqa: F401
from .base import SQLORMOperations
from .. import ColumnElement
from .. import inspection
from ..sql.cache_key import HasCacheKey
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
-from ..util.typing import DescriptorReference
+from ..util.typing import RODescriptorReference
from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
from .loading import _PopulatorDict
from .mapper import Mapper
from .path_registry import AbstractEntityRegistry
- from .path_registry import PathRegistry
from .query import Query
from .session import Session
from .state import InstanceState
from .strategy_options import _LoadElement
from .util import AliasedInsp
- from .util import CascadeOptions
from .util import ORMAdapter
from ..engine.result import Result
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _DMLColumnArgument
from ..sql._typing import _InfoType
from ..sql.operators import OperatorType
- from ..sql.util import ColumnAdapter
from ..sql.visitors import _TraverseInternalsType
+ from ..util.typing import _AnnotationScanType
+
+_StrategyKey = Tuple[Any, ...]
_T = TypeVar("_T", bound=Any)
)
-class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]):
+class ORMColumnsClauseRole(
+ roles.ColumnsClauseRole, roles.TypedColumnsClauseRole[_T]
+):
__slots__ = ()
_role_name = "ORM mapped entity, aliased entity, or Column expression"
registry: RegistryType,
cls: Type[Any],
key: str,
- annotation: Optional[Type[Any]],
- is_dataclass_field: Optional[bool],
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
) -> None:
"""Perform class-specific initializaton at early declarative scanning
time.
"parent",
"key",
"info",
+ "doc",
)
_cache_key_traversal: _TraverseInternalsType = [
("key", visitors.ExtendedInternalTraversal.dp_string),
]
- cascade: Optional[CascadeOptions] = None
- """The set of 'cascade' attribute names.
-
- This collection is checked before the 'cascade_iterator' method is called.
-
- The collection typically only applies to a Relationship.
-
- """
+ if not TYPE_CHECKING:
+ cascade = None
is_property = True
"""Part of the InspectionAttr interface; states this object is a
"""
+ doc: Optional[str]
+ """optional documentation string"""
+
def _memoized_attr_info(self) -> _InfoType:
"""Info dictionary associated with the object, allowing user-defined
data to be associated with this :class:`.InspectionAttr`.
self,
context: ORMCompileState,
query_entity: _MapperEntity,
- path: PathRegistry,
- adapter: Optional[ColumnAdapter],
+ path: AbstractEntityRegistry,
+ adapter: Optional[ORMAdapter],
**kwargs: Any,
) -> None:
"""Called by Query for the purposes of constructing a SQL statement.
self,
context: ORMCompileState,
query_entity: _MapperEntity,
- path: PathRegistry,
+ path: AbstractEntityRegistry,
mapper: Mapper[Any],
result: Result[Any],
- adapter: Optional[ColumnAdapter],
+ adapter: Optional[ORMAdapter],
populators: _PopulatorDict,
) -> None:
"""Produce row processing functions and append to the given
dest_state: InstanceState[Any],
dest_dict: _InstanceDict,
load: bool,
- _recursive: Set[InstanceState[Any]],
+ _recursive: Dict[Any, object],
_resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
) -> None:
"""Merge the attribute represented by this ``MapperProperty``
_parententity: _InternalEntityType[Any]
_adapt_to_entity: Optional[AliasedInsp[Any]]
- prop: DescriptorReference[MapperProperty[_T]]
+ prop: RODescriptorReference[MapperProperty[_T]]
def __init__(
self,
self._adapt_to_entity = adapt_to_entity
@util.non_memoized_property
- def property(self) -> Optional[MapperProperty[_T]]:
+ def property(self) -> MapperProperty[_T]:
"""Return the :class:`.MapperProperty` associated with this
:class:`.PropComparator`.
return self.prop.comparator._criterion_exists(criterion, **kwargs)
@util.ro_non_memoized_property
- def adapter(self) -> Optional[_ORMAdapterProto[_T]]:
+ def adapter(self) -> Optional[_ORMAdapterProto]:
"""Produce a callable that adapts column expressions
to suit an aliased version of this comparator.
if self._adapt_to_entity is None:
return None
else:
- return self._adapt_to_entity._adapt_element
+ return self._adapt_to_entity._orm_adapt_element
@util.ro_non_memoized_property
def info(self) -> _InfoType:
) -> ColumnElement[Any]:
...
- def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
+ def of_type(self, class_: _EntityType[_T]) -> PropComparator[_T]:
r"""Redefine this object in terms of a polymorphic subclass,
:func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased`
construct.
inherit_cache = True
strategy_wildcard_key: ClassVar[str]
- strategy_key: Tuple[Any, ...]
+ strategy_key: _StrategyKey
- _strategies: Dict[Tuple[Any, ...], LoaderStrategy]
+ _strategies: Dict[_StrategyKey, LoaderStrategy]
def _memoized_attr__wildcard_token(self) -> Tuple[str]:
return (
return load
- def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy:
+ def _get_strategy(self, key: _StrategyKey) -> LoaderStrategy:
try:
return self._strategies[key]
except KeyError:
self._strategies[key] = strategy = cls(self, key)
return strategy
- def setup(self, context, query_entity, path, adapter, **kwargs):
+ def setup(
+ self,
+ context: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: AbstractEntityRegistry,
+ adapter: Optional[ORMAdapter],
+ **kwargs: Any,
+ ) -> None:
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
strat = self._get_strategy(loader.strategy)
)
def create_row_processor(
- self, context, query_entity, path, mapper, result, adapter, populators
- ):
+ self,
+ context: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: AbstractEntityRegistry,
+ mapper: Mapper[Any],
+ result: Result[Any],
+ adapter: Optional[ORMAdapter],
+ populators: _PopulatorDict,
+ ) -> None:
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
strat = self._get_strategy(loader.strategy)
populators,
)
- def do_init(self):
+ def do_init(self) -> None:
self._strategies = {}
self.strategy = self._get_strategy(self.strategy_key)
- def post_instrument_class(self, mapper):
+ def post_instrument_class(self, mapper: Mapper[Any]) -> None:
if (
not self.parent.non_primary
and not mapper.class_manager._attr_has_impl(self.key)
self.strategy.init_class_attribute(mapper)
_all_strategies: collections.defaultdict[
- Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]]
+ Type[MapperProperty[Any]], Dict[_StrategyKey, Type[LoaderStrategy]]
] = collections.defaultdict(dict)
@classmethod
for prop_cls in cls.__mro__:
if prop_cls in cls._all_strategies:
+ if TYPE_CHECKING:
+ assert issubclass(prop_cls, MapperProperty)
strategies = cls._all_strategies[prop_cls]
try:
return strategies[key]
_is_compile_state = True
- def process_compile_state(self, compile_state):
- """Apply a modification to a given :class:`.CompileState`.
+ def process_compile_state(self, compile_state: ORMCompileState) -> None:
+ """Apply a modification to a given :class:`.ORMCompileState`.
This method is part of the implementation of a particular
:class:`.CompileStateOption` and is only invoked internally
"""
def process_compile_state_replaced_entities(
- self, compile_state, mapper_entities
- ):
- """Apply a modification to a given :class:`.CompileState`,
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ ) -> None:
+ """Apply a modification to a given :class:`.ORMCompileState`,
given entities that were replaced by with_only_columns() or
with_entities().
__slots__ = ()
def process_compile_state_replaced_entities(
- self, compile_state, mapper_entities
- ):
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ ) -> None:
self.process_compile_state(compile_state)
_is_criteria_option = True
- def get_global_criteria(self, attributes):
+ def get_global_criteria(self, attributes: Dict[str, Any]) -> None:
"""update additional entity criteria options in the given
attributes dictionary.
"""
- def __init__(self, payload=None):
+ def __init__(self, payload: Optional[Any] = None):
self.payload = payload
"strategy_opts",
)
- _strategy_keys: ClassVar[List[Tuple[Any, ...]]]
+ _strategy_keys: ClassVar[List[_StrategyKey]]
def __init__(
- self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...]
+ self, parent: MapperProperty[Any], strategy_key: _StrategyKey
):
self.parent_property = parent
self.is_class_level = False
"""
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent_property)
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
from .base import LoaderCallableStatus
+ from .context import QueryContext
from .interfaces import ORMOption
from .mapper import Mapper
+ from .query import Query
from .session import Session
from .state import InstanceState
+ from ..engine.cursor import CursorResult
from ..engine.interfaces import _ExecuteOptions
+ from ..engine.result import Result
from ..sql import Select
_T = TypeVar("_T", bound=Any)
_PopulatorDict = Dict[str, List[Tuple[str, Any]]]
-def instances(cursor, context):
+def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]:
"""Return a :class:`.Result` given an ORM query context.
:param cursor: a :class:`.CursorResult`, generated by a statement
unique_filters = [
_no_unique
if context.yield_per
- else _not_hashable(ent.column.type)
+ else _not_hashable(ent.column.type) # type: ignore
if (not ent.use_id_for_hash and ent._non_hashable_value)
else id
if ent.use_id_for_hash
labels, extra, _unique_filters=unique_filters
)
- def chunks(size):
+ def chunks(size): # type: ignore
while True:
yield_per = size
"is superseded by the :func:`_orm.merge_frozen_result` function.",
)
@util.preload_module("sqlalchemy.orm.context")
-def merge_result(query, iterator, load=True):
+def merge_result(
+ query: Query[Any],
+ iterator: Union[FrozenResult, Iterable[Sequence[Any]], Iterable[object]],
+ load: bool = True,
+) -> Union[FrozenResult, Iterable[Any]]:
"""Merge a result into the given :class:`.Query` object's Session.
See :meth:`_orm.Query.merge_result` for top-level documentation on this
result.append(keyed_tuple(newrow))
if frozen_result:
- return frozen_result.with_data(result)
+ return frozen_result.with_new_rows(result)
else:
return iter(result)
finally:
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
from __future__ import annotations
from ..sql import util as sql_util
from ..sql import visitors
from ..sql.cache_key import MemoizedHasCacheKey
+from ..sql.elements import KeyedColumnElement
from ..sql.schema import Table
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import HasMemoized
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
- from ..sql.elements import KeyedColumnElement
from ..sql.schema import Column
from ..sql.selectable import FromClause
from ..sql.util import ColumnAdapter
dispatch: dispatcher[Mapper[_O]]
_dispose_called = False
+ _configure_failed: Any = False
_ready_for_configure = False
@util.deprecated_params(
self.batch = batch
self.eager_defaults = eager_defaults
self.column_prefix = column_prefix
- self.polymorphic_on = (
- coercions.expect(
+
+ # interim - polymorphic_on is further refined in
+ # _configure_polymorphic_setter
+ self.polymorphic_on = ( # type: ignore
+ coercions.expect( # type: ignore
roles.ColumnArgumentOrKeyRole,
polymorphic_on,
argname="polymorphic_on",
)
@util.preload_module("sqlalchemy.orm.descriptor_props")
- def _configure_property(self, key, prop, init=True, setparent=True):
+ def _configure_property(
+ self,
+ key: str,
+ prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]],
+ init: bool = True,
+ setparent: bool = True,
+ ) -> MapperProperty[Any]:
descriptor_props = util.preloaded.orm_descriptor_props
- self._log("_configure_property(%s, %s)", key, prop.__class__.__name__)
+ self._log(
+ "_configure_property(%s, %s)", key, prop_arg.__class__.__name__
+ )
- if not isinstance(prop, MapperProperty):
- prop = self._property_from_column(key, prop)
+ if not isinstance(prop_arg, MapperProperty):
+ prop = self._property_from_column(key, prop_arg)
+ else:
+ prop = prop_arg
if isinstance(prop, properties.ColumnProperty):
col = self.persist_selectable.corresponding_column(prop.columns[0])
if self.configured:
self._expire_memoizations()
+ return prop
+
@util.preload_module("sqlalchemy.orm.descriptor_props")
- def _property_from_column(self, key, prop):
+ def _property_from_column(
+ self,
+ key: str,
+ prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]],
+ ) -> MapperProperty[Any]:
"""generate/update a :class:`.ColumnProperty` given a
:class:`_schema.Column` object."""
descriptor_props = util.preloaded.orm_descriptor_props
# we were passed a Column or a list of Columns;
# generate a properties.ColumnProperty
- columns = util.to_list(prop)
+ columns = util.to_list(prop_arg)
column = columns[0]
- assert isinstance(column, expression.ColumnElement)
- prop = self._props.get(key, None)
+ prop = self._props.get(key)
if isinstance(prop, properties.ColumnProperty):
if (
"columns get mapped." % (key, self, column.key, prop)
)
- def _check_configure(self):
+ def _check_configure(self) -> None:
if self.registry._new_mappers:
_configure_registries({self.registry}, cascade=True)
- def _post_configure_properties(self):
+ def _post_configure_properties(self) -> None:
"""Call the ``init()`` method on all ``MapperProperties``
attached to this mapper.
for key, value in dict_of_properties.items():
self.add_property(key, value)
- def add_property(self, key, prop):
+ def add_property(
+ self, key: str, prop: Union[Column[Any], MapperProperty[Any]]
+ ) -> None:
"""Add an individual MapperProperty to this mapper.
If the mapper has not been configured yet, just adds the
the given MapperProperty is configured immediately.
"""
+ prop = self._configure_property(key, prop, init=self.configured)
+ assert isinstance(prop, MapperProperty)
self._init_properties[key] = prop
- self._configure_property(key, prop, init=self.configured)
- def _expire_memoizations(self):
+ def _expire_memoizations(self) -> None:
for mapper in self.iterate_to_root():
mapper._reset_memoizations()
@property
- def _log_desc(self):
+ def _log_desc(self) -> str:
return (
"("
+ self.class_.__name__
+ ")"
)
- def _log(self, msg, *args):
+ def _log(self, msg: str, *args: Any) -> None:
self.logger.info("%s " + msg, *((self._log_desc,) + args))
- def _log_debug(self, msg, *args):
+ def _log_debug(self, msg: str, *args: Any) -> None:
self.logger.debug("%s " + msg, *((self._log_desc,) + args))
- def __repr__(self):
+ def __repr__(self) -> str:
return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__)
- def __str__(self):
+ def __str__(self) -> str:
return "Mapper[%s%s(%s)]" % (
self.class_.__name__,
self.non_primary and " (non-primary)" or "",
"Mapper '%s' has no property '%s'" % (self, key)
) from err
- def get_property_by_column(self, column):
+ def get_property_by_column(
+ self, column: ColumnElement[_T]
+ ) -> MapperProperty[_T]:
"""Given a :class:`_schema.Column` object, return the
:class:`.MapperProperty` which maps this column."""
return result
- def _is_userland_descriptor(self, assigned_name, obj):
+ def _is_userland_descriptor(self, assigned_name: str, obj: Any) -> bool:
if isinstance(
obj,
(
_configure_registries(_all_registries(), cascade=True)
-def _configure_registries(registries, cascade):
+def _configure_registries(
+ registries: Set[_RegistryType], cascade: bool
+) -> None:
for reg in registries:
if reg._new_mappers:
break
@util.preload_module("sqlalchemy.orm.decl_api")
-def _do_configure_registries(registries, cascade):
+def _do_configure_registries(
+ registries: Set[_RegistryType], cascade: bool
+) -> None:
registry = util.preloaded.orm_decl_api.registry
@util.preload_module("sqlalchemy.orm.decl_api")
-def _dispose_registries(registries, cascade):
+def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None:
registry = util.preloaded.orm_decl_api.registry
from ..sql.cache_key import _CacheKeyTraversalType
from ..sql.elements import BindParameter
from ..sql.visitors import anon_map
+ from ..util.typing import _LiteralStar
from ..util.typing import TypeGuard
def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]:
return PathRegistry.deserialize(path)
-_WILDCARD_TOKEN = "*"
+_WILDCARD_TOKEN: _LiteralStar = "*"
_DEFAULT_TOKEN = "_sa_default"
is_token = False
is_root = False
has_entity = False
+ is_property = False
is_entity = False
path: _PathRepresentation
def __hash__(self) -> int:
return id(self)
- def __getitem__(self, key: Any) -> PathRegistry:
+ @overload
+ def __getitem__(self, entity: str) -> TokenRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: int) -> _PathElementType:
+ ...
+
+ @overload
+ def __getitem__(self, entity: slice) -> _PathRepresentation:
+ ...
+
+ @overload
+ def __getitem__(
+ self, entity: _InternalEntityType[Any]
+ ) -> AbstractEntityRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
+ ...
+
+ def __getitem__(
+ self,
+ entity: Union[
+ str, int, slice, _InternalEntityType[Any], MapperProperty[Any]
+ ],
+ ) -> Union[
+ TokenRegistry,
+ _PathElementType,
+ _PathRepresentation,
+ PropRegistry,
+ AbstractEntityRegistry,
+ ]:
raise NotImplementedError()
# TODO: what are we using this for?
is_root = True
is_unnatural = False
- @overload
- def __getitem__(self, entity: str) -> TokenRegistry:
- ...
-
- @overload
- def __getitem__(
- self, entity: _InternalEntityType[Any]
- ) -> AbstractEntityRegistry:
- ...
-
- def __getitem__(
- self, entity: Union[str, _InternalEntityType[Any]]
+ def _getitem(
+ self, entity: Any
) -> Union[TokenRegistry, AbstractEntityRegistry]:
if entity in PathToken._intern:
if TYPE_CHECKING:
f"invalid argument for RootRegistry.__getitem__: {entity}"
)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
PathRegistry.root = RootRegistry()
else:
yield self
- def __getitem__(self, entity: Any) -> Any:
+ def _getitem(self, entity: Any) -> Any:
try:
return self.path[entity]
except TypeError as err:
raise IndexError(f"{entity}") from err
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class PropRegistry(PathRegistry):
__slots__ = (
"is_unnatural",
)
inherit_cache = True
+ is_property = True
prop: MapperProperty[Any]
mapper: Optional[Mapper[Any]]
assert self.entity is not None
return self[self.entity]
- @overload
- def __getitem__(self, entity: slice) -> _PathRepresentation:
- ...
-
- @overload
- def __getitem__(self, entity: int) -> _PathElementType:
- ...
-
- @overload
- def __getitem__(
- self, entity: _InternalEntityType[Any]
- ) -> AbstractEntityRegistry:
- ...
-
- def __getitem__(
+ def _getitem(
self, entity: Union[int, slice, _InternalEntityType[Any]]
) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]:
if isinstance(entity, (int, slice)):
else:
return SlotsEntityRegistry(self, entity)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class AbstractEntityRegistry(CreatesToken):
__slots__ = (
# self.natural_path = parent.natural_path + (entity, )
self.natural_path = self.path
+ @property
+ def root_entity(self) -> _InternalEntityType[Any]:
+ return cast("_InternalEntityType[Any]", self.path[0])
+
@property
def entity_path(self) -> PathRegistry:
return self
def __bool__(self) -> bool:
return True
- @overload
- def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
- ...
-
- @overload
- def __getitem__(self, entity: str) -> TokenRegistry:
- ...
-
- @overload
- def __getitem__(self, entity: int) -> _PathElementType:
- ...
-
- @overload
- def __getitem__(self, entity: slice) -> _PathRepresentation:
- ...
-
- def __getitem__(
+ def _getitem(
self, entity: Any
) -> Union[_PathElementType, _PathRepresentation, PathRegistry]:
if isinstance(entity, (int, slice)):
else:
return PropRegistry(self, entity)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class SlotsEntityRegistry(AbstractEntityRegistry):
# for aliased class, return lightweight, no-cycles created
def pop(self, key: Any, default: Any) -> Any:
return self._cache.pop(key, default)
- def __getitem__(self, entity: Any) -> Any:
+ def _getitem(self, entity: Any) -> Any:
if isinstance(entity, (int, slice)):
return self.path[entity]
elif isinstance(entity, PathToken):
return TokenRegistry(self, entity)
else:
return self._cache[entity]
+
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
+
+if TYPE_CHECKING:
+
+ def path_is_entity(
+ path: PathRegistry,
+ ) -> TypeGuard[AbstractEntityRegistry]:
+ ...
+
+ def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]:
+ ...
+
+else:
+ path_is_entity = operator.attrgetter("is_entity")
+ path_is_property = operator.attrgetter("is_property")
from typing import Any
from typing import cast
+from typing import Dict
from typing import List
from typing import Optional
+from typing import Sequence
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from . import attributes
from . import strategy_options
-from .base import SQLCoreOperations
from .descriptor_props import Composite
from .descriptor_props import ConcreteInheritedProperty
from .descriptor_props import Synonym
from ..sql import coercions
from ..sql import roles
from ..sql import sqltypes
+from ..sql.elements import SQLCoreOperations
from ..sql.schema import Column
from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
from ..util.typing import NoneType
+from ..util.typing import Self
if TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
from ._typing import _ORMColumnExprArgument
+ from ._typing import _RegistryType
+ from .mapper import Mapper
+ from .session import Session
+ from .state import _InstallLoaderCallableProto
+ from .state import InstanceState
from ..sql._typing import _InfoType
- from ..sql.elements import KeyedColumnElement
+ from ..sql.elements import ColumnElement
+ from ..sql.elements import NamedColumn
+ from ..sql.operators import OperatorType
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
+_NC = TypeVar("_NC", bound="NamedColumn[Any]")
__all__ = [
"ColumnProperty",
inherit_cache = True
_links_to_entity = False
- columns: List[KeyedColumnElement[Any]]
- _orig_columns: List[KeyedColumnElement[Any]]
+ columns: List[NamedColumn[Any]]
+ _orig_columns: List[NamedColumn[Any]]
_is_polymorphic_discriminator: bool
+ _mapped_by_synonym: Optional[str]
+
+ comparator_factory: Type[PropComparator[_T]]
+
__slots__ = (
"_orig_columns",
"columns",
"descriptor",
"active_history",
"expire_on_flush",
- "doc",
"_creation_order",
"_is_polymorphic_discriminator",
"_mapped_by_synonym",
group: Optional[str] = None,
deferred: bool = False,
raiseload: bool = False,
- comparator_factory: Optional[Type[PropComparator]] = None,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
descriptor: Optional[Any] = None,
active_history: bool = False,
expire_on_flush: bool = True,
self.expire_on_flush = expire_on_flush
if info is not None:
- self.info = info
+ self.info.update(info)
if doc is not None:
self.doc = doc
self.strategy_key += (("raiseload", True),)
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
column = self.columns[0]
if column.key is None:
column.key = key
return self
@property
- def columns_to_assign(self) -> List[Column]:
+ def columns_to_assign(self) -> List[Column[Any]]:
+ # mypy doesn't care about the isinstance here
return [
- c
+ c # type: ignore
for c in self.columns
if isinstance(c, Column) and c.table is None
]
- def _memoized_attr__renders_in_subqueries(self):
+ def _memoized_attr__renders_in_subqueries(self) -> bool:
return ("deferred", True) not in self.strategy_key or (
- self not in self.parent._readonly_props
+ self not in self.parent._readonly_props # type: ignore
)
@util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
- def _memoized_attr__deferred_column_loader(self):
+ def _memoized_attr__deferred_column_loader(
+ self,
+ ) -> _InstallLoaderCallableProto[Any]:
state = util.preloaded.orm_state
strategies = util.preloaded.orm_strategies
return state.InstanceState._instance_level_callable_processor(
)
@util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
- def _memoized_attr__raise_column_loader(self):
+ def _memoized_attr__raise_column_loader(
+ self,
+ ) -> _InstallLoaderCallableProto[Any]:
state = util.preloaded.orm_state
strategies = util.preloaded.orm_strategies
return state.InstanceState._instance_level_callable_processor(
self.key,
)
- def __clause_element__(self):
+ def __clause_element__(self) -> roles.ColumnsClauseRole:
"""Allow the ColumnProperty to work in expression before it is turned
into an instrumented attribute.
"""
return self.expression
@property
- def expression(self):
+ def expression(self) -> roles.ColumnsClauseRole:
"""Return the primary column or expression for this ColumnProperty.
E.g.::
"""
return self.columns[0]
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
if not self.instrument:
return
doc=self.doc,
)
- def do_init(self):
+ def do_init(self) -> None:
super().do_init()
if len(self.columns) > 1 and set(self.parent.primary_key).issuperset(
% (self.parent, self.columns[1], self.columns[0], self.key)
)
- def copy(self):
+ def copy(self) -> ColumnProperty[_T]:
return ColumnProperty(
+ *self.columns,
deferred=self.deferred,
group=self.group,
active_history=self.active_history,
- *self.columns,
- )
-
- def _getcommitted(
- self, state, dict_, column, passive=attributes.PASSIVE_OFF
- ):
- return state.get_impl(self.key).get_committed_value(
- state, dict_, passive=passive
)
def merge(
self,
- session,
- source_state,
- source_dict,
- dest_state,
- dest_dict,
- load,
- _recursive,
- _resolve_conflict_map,
- ):
+ session: Session,
+ source_state: InstanceState[Any],
+ source_dict: _InstanceDict,
+ dest_state: InstanceState[Any],
+ dest_dict: _InstanceDict,
+ load: bool,
+ _recursive: Dict[Any, object],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> None:
if not self.instrument:
return
elif self.key in source_dict:
"""
- __slots__ = "__clause_element__", "info", "expressions"
+ if not TYPE_CHECKING:
+ # prevent pylance from being clever about slots
+ __slots__ = "__clause_element__", "info", "expressions"
+
+ prop: RODescriptorReference[ColumnProperty[_PT]]
- def _orm_annotate_column(self, column):
+ def _orm_annotate_column(self, column: _NC) -> _NC:
"""annotate and possibly adapt a column to be returned
as the mapped-attribute exposed version of the column.
"""
pe = self._parententity
- annotations = {
+ annotations: Dict[str, Any] = {
"entity_namespace": pe,
"parententity": pe,
"parentmapper": pe,
{"compile_state_plugin": "orm", "plugin_subject": pe}
)
- def _memoized_method___clause_element__(self):
+ if TYPE_CHECKING:
+
+ def __clause_element__(self) -> NamedColumn[_PT]:
+ ...
+
+ def _memoized_method___clause_element__(
+ self,
+ ) -> NamedColumn[_PT]:
if self.adapter:
return self.adapter(self.prop.columns[0], self.prop.key)
else:
return self._orm_annotate_column(self.prop.columns[0])
- def _memoized_attr_info(self):
+ def _memoized_attr_info(self) -> _InfoType:
"""The .info dictionary for this attribute."""
ce = self.__clause_element__()
try:
- return ce.info
+ return ce.info # type: ignore
except AttributeError:
return self.prop.info
- def _memoized_attr_expressions(self):
+ def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]:
"""The full sequence of columns referenced by this
attribute, adjusted for any aliasing in progress.
self._orm_annotate_column(col) for col in self.prop.columns
]
- def _fallback_getattr(self, key):
+ def _fallback_getattr(self, key: str) -> Any:
"""proxy attribute access down to the mapped column.
this allows user-defined comparison methods to be accessed.
"""
return getattr(self.__clause_element__(), key)
- def operate(self, op, *other, **kwargs):
- return op(self.__clause_element__(), *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501
- def reverse_operate(self, op, other, **kwargs):
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
col = self.__clause_element__()
- return op(col._bind_param(op, other), col, **kwargs)
+ return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501
- def __str__(self):
+ def __str__(self) -> str:
if not self.parent or not self.key:
return object.__repr__(self)
return str(self.parent.class_.__name__) + "." + self.key
column: Column[_T]
foreign_keys: Optional[Set[ForeignKey]]
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg: Any, **kw: Any):
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
self.foreign_keys = self.column.foreign_keys
)
util.set_creation_order(self)
- def _copy(self, **kw):
- new = self.__class__.__new__(self.__class__)
+ def _copy(self: Self, **kw: Any) -> Self:
+ new = cast(Self, self.__class__.__new__(self.__class__))
new.column = self.column._copy(**kw)
new.deferred = self.deferred
new.foreign_keys = new.column.foreign_keys
return None
@property
- def columns_to_assign(self) -> List[Column]:
+ def columns_to_assign(self) -> List[Column[Any]]:
return [self.column]
- def __clause_element__(self):
+ def __clause_element__(self) -> Column[_T]:
return self.column
- def operate(self, op, *other, **kwargs):
- return op(self.__clause_element__(), *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501
- def reverse_operate(self, op, other, **kwargs):
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
col = self.__clause_element__()
- return op(col._bind_param(op, other), col, **kwargs)
+ return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
column = self.column
if column.key is None:
column.key = key
@util.preload_module("sqlalchemy.orm.decl_base")
def declarative_scan_for_composite(
- self, registry, cls, key, param_name, param_annotation
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ param_name: str,
+ param_annotation: _AnnotationScanType,
+ ) -> None:
decl_base = util.preloaded.orm_decl_base
decl_base._undefer_column_name(param_name, self.column)
self._init_column_for_annotation(cls, registry, param_annotation)
- def _init_column_for_annotation(self, cls, registry, argument):
+ def _init_column_for_annotation(
+ self,
+ cls: Type[Any],
+ registry: _RegistryType,
+ argument: _AnnotationScanType,
+ ) -> None:
sqltype = self.column.type
nullable = False
if hasattr(argument, "__origin__"):
- nullable = NoneType in argument.__args__
+ nullable = NoneType in argument.__args__ # type: ignore
if not self._has_nullable:
self.column.nullable = nullable
if sqltype._isnull and not self.column.foreign_keys:
- sqltype = None
+ new_sqltype = None
our_type = de_optionalize_union_types(argument)
if is_fwd_ref(our_type):
our_type = de_stringify_annotation(cls, our_type)
if registry.type_annotation_map:
- sqltype = registry.type_annotation_map.get(our_type)
- if sqltype is None:
- sqltype = sqltypes._type_map_get(our_type)
+ new_sqltype = registry.type_annotation_map.get(our_type)
+ if new_sqltype is None:
+ new_sqltype = sqltypes._type_map_get(our_type) # type: ignore
- if sqltype is None:
+ if new_sqltype is None:
raise sa_exc.ArgumentError(
f"Could not locate SQLAlchemy Core "
f"type for Python type: {our_type}"
)
- self.column.type = sqltype
+ self.column.type = new_sqltype # type: ignore
import collections.abc as collections_abc
import operator
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
+from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from . import exc as orm_exc
+from . import attributes
from . import interfaces
from . import loading
from . import util as orm_util
from .context import _determine_last_joined_entity
from .context import _legacy_filter_by_entity_zero
from .context import FromStatement
-from .context import LABEL_STYLE_LEGACY_ORM
from .context import ORMCompileState
from .context import QueryContext
from .interfaces import ORMColumnDescription
from .. import util
from ..engine import Result
from ..engine import Row
+from ..event import dispatcher
+from ..event import EventTarget
from ..sql import coercions
from ..sql import expression
from ..sql import roles
from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import _entity_namespace_key
from ..sql.base import _generative
+from ..sql.base import _NoArg
from ..sql.base import Executable
from ..sql.base import Generative
+from ..sql.elements import BooleanClauseList
from ..sql.expression import Exists
from ..sql.selectable import _MemoizedSelectEntities
from ..sql.selectable import _SelectFromElements
from ..sql.selectable import HasPrefixes
from ..sql.selectable import HasSuffixes
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectLabelStyle
from ..util.typing import Literal
+from ..util.typing import Self
if TYPE_CHECKING:
from ._typing import _EntityType
+ from ._typing import _ExternalEntityType
+ from ._typing import _InternalEntityType
+ from .mapper import Mapper
+ from .path_registry import PathRegistry
+ from .session import _PKIdentityArgument
from .session import Session
+ from .state import InstanceState
+ from ..engine.cursor import CursorResult
+ from ..engine.interfaces import _ImmutableExecuteOptions
+ from ..engine.result import FrozenResult
from ..engine.result import ScalarResult
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _ColumnsClauseArgument
+ from ..sql._typing import _DMLColumnArgument
+ from ..sql._typing import _JoinTargetArgument
from ..sql._typing import _MAYBE_ENTITY
from ..sql._typing import _no_kw
from ..sql._typing import _NOT_ENTITY
+ from ..sql._typing import _OnClauseArgument
from ..sql._typing import _PropagateAttrsType
from ..sql._typing import _T0
from ..sql._typing import _T1
from ..sql._typing import _T6
from ..sql._typing import _T7
from ..sql._typing import _TypedColumnClauseArgument as _TCCA
- from ..sql.roles import TypedColumnsClauseRole
+ from ..sql.base import CacheableOptions
+ from ..sql.base import ExecutableOption
+ from ..sql.elements import ColumnElement
+ from ..sql.elements import Label
+ from ..sql.selectable import _JoinTargetElement
from ..sql.selectable import _SetupJoinsElement
from ..sql.selectable import Alias
+ from ..sql.selectable import CTE
from ..sql.selectable import ExecutableReturnsRows
+ from ..sql.selectable import FromClause
from ..sql.selectable import ScalarSelect
from ..sql.selectable import Subquery
+
__all__ = ["Query", "QueryContext"]
_T = TypeVar("_T", bound=Any)
-SelfQuery = TypeVar("SelfQuery", bound="Query")
+SelfQuery = TypeVar("SelfQuery", bound="Query[Any]")
@inspection._self_inspects
HasPrefixes,
HasSuffixes,
HasHints,
+ EventTarget,
log.Identified,
Generative,
Executable,
"""
# elements that are in Core and can be cached in the same way
- _where_criteria = ()
- _having_criteria = ()
+ _where_criteria: Tuple[ColumnElement[Any], ...] = ()
+ _having_criteria: Tuple[ColumnElement[Any], ...] = ()
- _order_by_clauses = ()
- _group_by_clauses = ()
- _limit_clause = None
- _offset_clause = None
+ _order_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _group_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _limit_clause: Optional[ColumnElement[Any]] = None
+ _offset_clause: Optional[ColumnElement[Any]] = None
- _distinct = False
- _distinct_on = ()
+ _distinct: bool = False
+ _distinct_on: Tuple[ColumnElement[Any], ...] = ()
- _for_update_arg = None
- _correlate = ()
- _auto_correlate = True
- _from_obj = ()
+ _for_update_arg: Optional[ForUpdateArg] = None
+ _correlate: Tuple[FromClause, ...] = ()
+ _auto_correlate: bool = True
+ _from_obj: Tuple[FromClause, ...] = ()
_setup_joins: Tuple[_SetupJoinsElement, ...] = ()
- _label_style = LABEL_STYLE_LEGACY_ORM
+ _label_style: SelectLabelStyle = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
_memoized_select_entities = ()
- _compile_options = ORMCompileState.default_compile_options
+ _compile_options: Union[
+ Type[CacheableOptions], CacheableOptions
+ ] = ORMCompileState.default_compile_options
+ _with_options: Tuple[ExecutableOption, ...]
load_options = QueryContext.default_load_options + {
"_legacy_uniquing": True
}
- _params = util.EMPTY_DICT
+ _params: util.immutabledict[str, Any] = util.EMPTY_DICT
# local Query builder state, not needed for
# compilation or execution
_enable_assertions = True
- _statement = None
+ _statement: Optional[ExecutableReturnsRows] = None
+
+ session: Session
+
+ dispatch: dispatcher[Query[_T]]
# mirrors that of ClauseElement, used to propagate the "orm"
# plugin as well as the "subject" of the plugin, e.g. the mapper
"""
- self.session = session
+ # session is usually present. There's one case in subqueryloader
+ # where it stores a Query without a Session, and there are also tests
+ # for the query(Entity).with_session(session) API, which likely appears
+ # in some old recipes; however, these are legacy uses, as select() can
+ # now be used instead.
+ self.session = session # type: ignore
self._set_entities(entities)
- def _set_propagate_attrs(self, values):
- self._propagate_attrs = util.immutabledict(values)
+ def _set_propagate_attrs(
+ self: SelfQuery, values: Mapping[str, Any]
+ ) -> SelfQuery:
+ self._propagate_attrs = util.immutabledict(values) # type: ignore
return self
- def _set_entities(self, entities):
+ def _set_entities(
+ self, entities: Iterable[_ColumnsClauseArgument[Any]]
+ ) -> None:
self._raw_columns = [
coercions.expect(
roles.ColumnsClauseRole,
for ent in util.to_list(entities)
]
- @overload
- def tuples(self: Query[Row[_TP]]) -> Query[_TP]:
- ...
-
- @overload
def tuples(self: Query[_O]) -> Query[Tuple[_O]]:
- ...
-
- def tuples(self) -> Query[Any]:
"""return a tuple-typed form of this :class:`.Query`.
This method invokes the :meth:`.Query.only_return_tuples`
.. versionadded:: 2.0
"""
- return self.only_return_tuples(True)
+ return self.only_return_tuples(True) # type: ignore
- def _entity_from_pre_ent_zero(self):
+ def _entity_from_pre_ent_zero(self) -> Optional[_InternalEntityType[Any]]:
if not self._raw_columns:
return None
ent = self._raw_columns[0]
if "parententity" in ent._annotations:
- return ent._annotations["parententity"]
- elif isinstance(ent, ORMColumnsClauseRole):
- return ent.entity
+ return ent._annotations["parententity"] # type: ignore
elif "bundle" in ent._annotations:
- return ent._annotations["bundle"]
+ return ent._annotations["bundle"] # type: ignore
else:
# label, other SQL expression
for element in visitors.iterate(ent):
if "parententity" in element._annotations:
- return element._annotations["parententity"]
+ return element._annotations["parententity"] # type: ignore # noqa: E501
else:
return None
- def _only_full_mapper_zero(self, methname):
+ def _only_full_mapper_zero(self, methname: str) -> Mapper[Any]:
if (
len(self._raw_columns) != 1
or "parententity" not in self._raw_columns[0]._annotations
"a single mapped class." % methname
)
- return self._raw_columns[0]._annotations["parententity"]
+ return self._raw_columns[0]._annotations["parententity"] # type: ignore # noqa: E501
- def _set_select_from(self, obj, set_base_alias):
+ def _set_select_from(
+ self, obj: Iterable[_FromClauseArgument], set_base_alias: bool
+ ) -> None:
fa = [
coercions.expect(
roles.StrictFromClauseRole,
self._from_obj = tuple(fa)
@_generative
- def _set_lazyload_from(self: SelfQuery, state) -> SelfQuery:
+ def _set_lazyload_from(
+ self: SelfQuery, state: InstanceState[Any]
+ ) -> SelfQuery:
self.load_options += {"_lazy_loaded_from": state}
return self
- def _get_condition(self):
- return self._no_criterion_condition(
- "get", order_by=False, distinct=False
- )
+ def _get_condition(self) -> None:
+ """used by legacy BakedQuery"""
+ self._no_criterion_condition("get", order_by=False, distinct=False)
- def _get_existing_condition(self):
+ def _get_existing_condition(self) -> None:
self._no_criterion_assertion("get", order_by=False, distinct=False)
- def _no_criterion_assertion(self, meth, order_by=True, distinct=True):
+ def _no_criterion_assertion(
+ self, meth: str, order_by: bool = True, distinct: bool = True
+ ) -> None:
if not self._enable_assertions:
return
if (
"Query with existing criterion. " % meth
)
- def _no_criterion_condition(self, meth, order_by=True, distinct=True):
+ def _no_criterion_condition(
+ self, meth: str, order_by: bool = True, distinct: bool = True
+ ) -> None:
self._no_criterion_assertion(meth, order_by, distinct)
self._from_obj = self._setup_joins = ()
self._order_by_clauses = self._group_by_clauses = ()
- def _no_clauseelement_condition(self, meth):
+ def _no_clauseelement_condition(self, meth: str) -> None:
if not self._enable_assertions:
return
if self._order_by_clauses:
)
self._no_criterion_condition(meth)
- def _no_statement_condition(self, meth):
+ def _no_statement_condition(self, meth: str) -> None:
if not self._enable_assertions:
return
if self._statement is not None:
% meth
)
- def _no_limit_offset(self, meth):
+ def _no_limit_offset(self, meth: str) -> None:
if not self._enable_assertions:
return
if self._limit_clause is not None or self._offset_clause is not None:
)
@property
- def _has_row_limiting_clause(self):
+ def _has_row_limiting_clause(self) -> bool:
return (
self._limit_clause is not None or self._offset_clause is not None
)
def _get_options(
- self,
- populate_existing=None,
- version_check=None,
- only_load_props=None,
- refresh_state=None,
- identity_token=None,
- ):
- load_options = {}
- compile_options = {}
+ self: SelfQuery,
+ populate_existing: Optional[bool] = None,
+ version_check: Optional[bool] = None,
+ only_load_props: Optional[Sequence[str]] = None,
+ refresh_state: Optional[InstanceState[Any]] = None,
+ identity_token: Optional[Any] = None,
+ ) -> SelfQuery:
+ load_options: Dict[str, Any] = {}
+ compile_options: Dict[str, Any] = {}
if version_check:
load_options["_version_check"] = version_check
return self
- def _clone(self):
- return self._generate()
+ def _clone(self: Self, **kw: Any) -> Self:
+ return self._generate() # type: ignore
+
+ def _get_select_statement_only(self) -> Select[_T]:
+ if self._statement is not None:
+ raise sa_exc.InvalidRequestError(
+ "Can't call this method on a Query that uses from_statement()"
+ )
+ return cast("Select[_T]", self.statement)
@property
- def statement(self):
+ def statement(self) -> Union[Select[_T], FromStatement[_T]]:
"""The full SELECT statement represented by this Query.
The statement by default will not have disambiguating labels
return stmt
- def _final_statement(self, legacy_query_style=True):
+ def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]:
"""Return the 'final' SELECT statement for this :class:`.Query`.
+ This is used by the testing suite only and is fairly inefficient.
+
This is the Core-only select() that will be rendered by a complete
compilation of this query, and is what .statement used to return
in 1.3.
- This method creates a complete compile state so is fairly expensive.
"""
return q._compile_state(
use_legacy_query_style=legacy_query_style
- ).statement
+ ).statement # type: ignore
- def _statement_20(self, for_statement=False, use_legacy_query_style=True):
+ def _statement_20(
+ self, for_statement: bool = False, use_legacy_query_style: bool = True
+ ) -> Union[Select[_T], FromStatement[_T]]:
# TODO: this event needs to be deprecated, as it currently applies
# only to ORM query and occurs at this spot that is now more
# or less an artificial spot
new_query = fn(self)
if new_query is not None and new_query is not self:
self = new_query
- if not fn._bake_ok:
+ if not fn._bake_ok: # type: ignore
self._compile_options += {"_bake_ok": False}
compile_options = self._compile_options
"_use_legacy_query_style": use_legacy_query_style,
}
+ stmt: Union[Select[_T], FromStatement[_T]]
+
if self._statement is not None:
stmt = FromStatement(self._raw_columns, self._statement)
stmt.__dict__.update(
def subquery(
self,
- name=None,
- with_labels=False,
- reduce_columns=False,
- ):
+ name: Optional[str] = None,
+ with_labels: bool = False,
+ reduce_columns: bool = False,
+ ) -> Subquery:
"""Return the full SELECT statement represented by
this :class:`_query.Query`, embedded within an
:class:`_expression.Alias`.
if with_labels:
q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
- q = q.statement
+ stmt = q._get_select_statement_only()
+
+ if TYPE_CHECKING:
+ assert isinstance(stmt, Select)
if reduce_columns:
- q = q.reduce_columns()
- return q.alias(name=name)
+ stmt = stmt.reduce_columns()
+ return stmt.subquery(name=name)
- def cte(self, name=None, recursive=False, nesting=False):
+ def cte(
+ self,
+ name: Optional[str] = None,
+ recursive: bool = False,
+ nesting: bool = False,
+ ) -> CTE:
r"""Return the full SELECT statement represented by this
:class:`_query.Query` represented as a common table expression (CTE).
:meth:`_expression.HasCTE.cte`
"""
- return self.enable_eagerloads(False).statement.cte(
- name=name, recursive=recursive, nesting=nesting
+ return (
+ self.enable_eagerloads(False)
+ ._get_select_statement_only()
+ .cte(name=name, recursive=recursive, nesting=nesting)
)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[Any]:
"""Return the full SELECT statement represented by this
:class:`_query.Query`, converted
to a scalar subquery with a label of the given name.
"""
- return self.enable_eagerloads(False).statement.label(name)
+ return (
+ self.enable_eagerloads(False)
+ ._get_select_statement_only()
+ .label(name)
+ )
@overload
def as_scalar(
"""
- return self.enable_eagerloads(False).statement.scalar_subquery()
+ return (
+ self.enable_eagerloads(False)
+ ._get_select_statement_only()
+ .scalar_subquery()
+ )
@property
- def selectable(self):
+ def selectable(self) -> Union[Select[_T], FromStatement[_T]]:
"""Return the :class:`_expression.Select` object emitted by this
:class:`_query.Query`.
"""
return self.__clause_element__()
- def __clause_element__(self):
+ def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]:
return (
self._with_compile_options(
_enable_eagerloads=False, _render_for_subquery=True
return self
@property
- def is_single_entity(self):
+ def is_single_entity(self) -> bool:
"""Indicates if this :class:`_query.Query`
returns tuples or single entities.
)
@_generative
- def enable_eagerloads(self: SelfQuery, value) -> SelfQuery:
+ def enable_eagerloads(self: SelfQuery, value: bool) -> SelfQuery:
"""Control whether or not eager joins and subqueries are
rendered.
return self
@_generative
- def _with_compile_options(self: SelfQuery, **opt) -> SelfQuery:
+ def _with_compile_options(self: SelfQuery, **opt: Any) -> SelfQuery:
self._compile_options += opt
return self
alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
"instead.",
)
- def with_labels(self):
- return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ def with_labels(self: SelfQuery) -> SelfQuery:
+ return self.set_label_style(
+ SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL
+ )
apply_labels = with_labels
@property
- def get_label_style(self):
+ def get_label_style(self) -> SelectLabelStyle:
"""
Retrieve the current label style.
"""
return self._label_style
- def set_label_style(self, style):
+ def set_label_style(self: SelfQuery, style: SelectLabelStyle) -> SelfQuery:
"""Apply column labels to the return value of Query.statement.
Indicates that this Query's `statement` accessor should return
return self
@_generative
- def enable_assertions(self: SelfQuery, value) -> SelfQuery:
+ def enable_assertions(self: SelfQuery, value: bool) -> SelfQuery:
"""Control whether assertions are generated.
When set to False, the returned Query will
return self
@property
- def whereclause(self):
+ def whereclause(self) -> Optional[ColumnElement[bool]]:
"""A readonly attribute which returns the current WHERE criterion for
this Query.
criterion has been established.
"""
- return sql.elements.BooleanClauseList._construct_for_whereclause(
+ return BooleanClauseList._construct_for_whereclause(
self._where_criteria
)
@_generative
- def _with_current_path(self: SelfQuery, path) -> SelfQuery:
+ def _with_current_path(self: SelfQuery, path: PathRegistry) -> SelfQuery:
"""indicate that this query applies to objects loaded
within a certain path.
return self
@_generative
- def yield_per(self: SelfQuery, count) -> SelfQuery:
+ def yield_per(self: SelfQuery, count: int) -> SelfQuery:
r"""Yield only ``count`` rows at a time.
The purpose of this method is when fetching very large result sets
":meth:`_orm.Query.get`",
alternative="The method is now available as :meth:`_orm.Session.get`",
)
- def get(self, ident):
+ def get(self, ident: _PKIdentityArgument) -> Optional[Any]:
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
# it
return self._get_impl(ident, loading.load_on_pk_identity)
- def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
+ def _get_impl(
+ self,
+ primary_key_identity: _PKIdentityArgument,
+ db_load_fn: Callable[..., Any],
+ identity_token: Optional[Any] = None,
+ ) -> Optional[Any]:
mapper = self._only_full_mapper_zero("get")
return self.session._get_impl(
mapper,
)
@property
- def lazy_loaded_from(self):
+ def lazy_loaded_from(self) -> Optional[InstanceState[Any]]:
"""An :class:`.InstanceState` that is using this :class:`_query.Query`
for a lazy load operation.
:attr:`.ORMExecuteState.lazy_loaded_from`
"""
- return self.load_options._lazy_loaded_from
+ return self.load_options._lazy_loaded_from # type: ignore
@property
- def _current_path(self):
- return self._compile_options._current_path
+ def _current_path(self) -> PathRegistry:
+ return self._compile_options._current_path # type: ignore
@_generative
- def correlate(self: SelfQuery, *fromclauses) -> SelfQuery:
+ def correlate(
+ self: SelfQuery,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfQuery:
"""Return a :class:`.Query` construct which will correlate the given
FROM clauses to that of an enclosing :class:`.Query` or
:func:`~.expression.select`.
if fromclauses and fromclauses[0] in {None, False}:
self._correlate = ()
else:
- self._correlate = set(self._correlate).union(
+ self._correlate = self._correlate + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
return self
@_generative
- def autoflush(self: SelfQuery, setting) -> SelfQuery:
+ def autoflush(self: SelfQuery, setting: bool) -> SelfQuery:
"""Return a Query with a specific 'autoflush' setting.
As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method
return self
@_generative
- def _with_invoke_all_eagers(self: SelfQuery, value) -> SelfQuery:
+ def _with_invoke_all_eagers(self: SelfQuery, value: bool) -> SelfQuery:
"""Set the 'invoke all eagers' flag which causes joined- and
subquery loaders to traverse into already-loaded related objects
and collections.
alternative="Use the :func:`_orm.with_parent` standalone construct.",
)
@util.preload_module("sqlalchemy.orm.relationships")
- def with_parent(self, instance, property=None, from_entity=None): # noqa
+ def with_parent(
+ self: SelfQuery,
+ instance: object,
+ property: Optional[ # noqa: A002
+ attributes.QueryableAttribute[Any]
+ ] = None,
+ from_entity: Optional[_ExternalEntityType[Any]] = None,
+ ) -> SelfQuery:
"""Add filtering criterion that relates the given instance
to a child object or collection, using its attribute state
as well as an established :func:`_orm.relationship()`
An instance which has some :func:`_orm.relationship`.
:param property:
- String property name, or class-bound attribute, which indicates
+ Class-bound attribute which indicates
what relationship from the instance should be used to reconcile the
parent/child relationship.
for prop in mapper.iterate_properties:
if (
isinstance(prop, relationships.Relationship)
- and prop.mapper is entity_zero.mapper
+ and prop.mapper is entity_zero.mapper # type: ignore
):
- property = prop # noqa
+ property = prop # type: ignore # noqa: A001
break
else:
raise sa_exc.InvalidRequestError(
"Could not locate a property which relates instances "
"of class '%s' to instances of class '%s'"
% (
- entity_zero.mapper.class_.__name__,
+ entity_zero.mapper.class_.__name__, # type: ignore
instance.__class__.__name__,
)
)
- return self.filter(with_parent(instance, property, entity_zero.entity))
+ return self.filter(
+ with_parent(
+ instance,
+ property, # type: ignore
+ entity_zero.entity, # type: ignore
+ )
+ )
@_generative
def add_entity(
return self
@_generative
- def with_session(self: SelfQuery, session) -> SelfQuery:
+ def with_session(self: SelfQuery, session: Session) -> SelfQuery:
"""Return a :class:`_query.Query` that will use the given
:class:`.Session`.
self.session = session
return self
- def _legacy_from_self(self, *entities):
+ def _legacy_from_self(
+ self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
+ ) -> SelfQuery:
# used for query.count() as well as for the same
# function in BakedQuery, as well as some old tests in test_baked.py.
return q
@_generative
- def _set_enable_single_crit(self: SelfQuery, val) -> SelfQuery:
+ def _set_enable_single_crit(self: SelfQuery, val: bool) -> SelfQuery:
self._compile_options += {"_enable_single_crit": val}
return self
@_generative
def _from_selectable(
- self: SelfQuery, fromclause, set_entity_from=True
+ self: SelfQuery, fromclause: FromClause, set_entity_from: bool = True
) -> SelfQuery:
for attr in (
"_where_criteria",
"is deprecated and will be removed in a "
"future release. Please use :meth:`_query.Query.with_entities`",
)
- def values(self, *columns):
+ def values(self, *columns: _ColumnsClauseArgument[Any]) -> Iterable[Any]:
"""Return an iterator yielding result tuples corresponding
to the given list of columns
q._set_entities(columns)
if not q.load_options._yield_per:
q.load_options += {"_yield_per": 10}
- return iter(q)
+ return iter(q) # type: ignore
_values = values
"future release. Please use :meth:`_query.Query.with_entities` "
"in combination with :meth:`_query.Query.scalar`",
)
- def value(self, column):
+ def value(self, column: _ColumnExpressionArgument[Any]) -> Any:
"""Return a scalar result corresponding to the given
column expression.
"""
try:
- return next(self.values(column))[0]
+ return next(self.values(column))[0] # type: ignore
except StopIteration:
return None
@overload
- def with_entities(
- self, _entity: _EntityType[_O], **kwargs: Any
- ) -> Query[_O]:
+ def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]:
...
@overload
def with_entities(
- self, _colexpr: TypedColumnsClauseRole[_T]
+ self,
+ _colexpr: roles.TypedColumnsClauseRole[_T],
) -> RowReturningQuery[Tuple[_T]]:
...
@overload
def with_entities(
- self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
- ) -> SelfQuery:
+ self, *entities: _ColumnsClauseArgument[Any]
+ ) -> Query[Any]:
...
@_generative
def with_entities(
- self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any
- ) -> SelfQuery:
+ self, *entities: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> Query[Any]:
r"""Return a new :class:`_query.Query`
replacing the SELECT list with the
given entities.
"""
if __kw:
raise _no_kw()
- _MemoizedSelectEntities._generate_for_statement(self)
+
+ # Query has all the same fields as Select for this operation
+ # this could in theory be based on a protocol but not sure if it's
+ # worth it
+ _MemoizedSelectEntities._generate_for_statement(self) # type: ignore
self._set_entities(entities)
return self
@_generative
- def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]:
+ def add_columns(
+ self, *column: _ColumnExpressionArgument[Any]
+ ) -> Query[Any]:
"""Add one or more column expressions to the list
of result columns to be returned."""
"is deprecated and will be removed in a "
"future release. Please use :meth:`_query.Query.add_columns`",
)
- def add_column(self, column) -> Query[Any]:
+ def add_column(self, column: _ColumnExpressionArgument[Any]) -> Query[Any]:
"""Add a column expression to the list of result columns to be
returned.
return self.add_columns(column)
@_generative
- def options(self: SelfQuery, *args) -> SelfQuery:
+ def options(self: SelfQuery, *args: ExecutableOption) -> SelfQuery:
"""Return a new :class:`_query.Query` object,
applying the given list of
mapper options.
opts = tuple(util.flatten_iterator(args))
if self._compile_options._current_path:
+ # opting for lower method overhead for the checks
for opt in opts:
- if opt._is_legacy_option:
- opt.process_query_conditionally(self)
+ if not opt._is_core and opt._is_legacy_option: # type: ignore
+ opt.process_query_conditionally(self) # type: ignore
else:
for opt in opts:
- if opt._is_legacy_option:
- opt.process_query(self)
+ if not opt._is_core and opt._is_legacy_option: # type: ignore
+ opt.process_query(self) # type: ignore
self._with_options += opts
return self
- def with_transformation(self, fn):
+ def with_transformation(
+ self, fn: Callable[[Query[Any]], Query[Any]]
+ ) -> Query[Any]:
"""Return a new :class:`_query.Query` object transformed by
the given function.
"""
return fn(self)
- def get_execution_options(self):
+ def get_execution_options(self) -> _ImmutableExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
return self._execution_options
@_generative
- def execution_options(self: SelfQuery, **kwargs) -> SelfQuery:
+ def execution_options(self: SelfQuery, **kwargs: Any) -> SelfQuery:
"""Set non-SQL options which take effect during execution.
Options allowed here include all of those accepted by
@_generative
def with_for_update(
self: SelfQuery,
- read=False,
- nowait=False,
- of=None,
- skip_locked=False,
- key_share=False,
+ *,
+ nowait: bool = False,
+ read: bool = False,
+ of: Optional[
+ Union[
+ _ColumnExpressionArgument[Any],
+ Sequence[_ColumnExpressionArgument[Any]],
+ ]
+ ] = None,
+ skip_locked: bool = False,
+ key_share: bool = False,
) -> SelfQuery:
"""return a new :class:`_query.Query`
with the specified options for the
return self
@_generative
- def params(self: SelfQuery, *args, **kwargs) -> SelfQuery:
+ def params(
+ self: SelfQuery, __params: Optional[Dict[str, Any]] = None, **kw: Any
+ ) -> SelfQuery:
r"""Add values for bind parameters which may have been
specified in filter().
contain unicode keys in which case \**kwargs cannot be used.
"""
- if len(args) == 1:
- kwargs.update(args[0])
- elif len(args) > 0:
- raise sa_exc.ArgumentError(
- "params() takes zero or one positional argument, "
- "which is a dictionary."
- )
- self._params = self._params.union(kwargs)
+ if __params:
+ kw.update(__params)
+ self._params = self._params.union(kw)
return self
- def where(self: SelfQuery, *criterion) -> SelfQuery:
+ def where(
+ self: SelfQuery, *criterion: _ColumnExpressionArgument[bool]
+ ) -> SelfQuery:
"""A synonym for :meth:`.Query.filter`.
.. versionadded:: 1.4
:meth:`_query.Query.filter_by` - filter on keyword expressions.
"""
- for criterion in list(criterion):
- criterion = coercions.expect(
- roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+ for crit in list(criterion):
+ crit = coercions.expect(
+ roles.WhereHavingRole, crit, apply_propagate_attrs=self
)
- self._where_criteria += (criterion,)
+ self._where_criteria += (crit,)
return self
@util.memoized_property
- def _last_joined_entity(self):
+ def _last_joined_entity(
+ self,
+ ) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]:
if self._setup_joins:
return _determine_last_joined_entity(
self._setup_joins,
else:
return None
- def _filter_by_zero(self):
+ def _filter_by_zero(self) -> Any:
"""for the filter_by() method, return the target entity for which
we will attempt to derive an expression from based on string name.
"""
from_entity = self._filter_by_zero()
- if from_entity is None:
- raise sa_exc.InvalidRequestError(
- "Can't use filter_by when the first entity '%s' of a query "
- "is not a mapped class. Please use the filter method instead, "
- "or change the order of the entities in the query"
- % self._query_entity_zero()
- )
clauses = [
_entity_namespace_key(from_entity, key) == value
return self.filter(*clauses)
@_generative
- @_assertions(_no_statement_condition, _no_limit_offset)
def order_by(
- self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfQuery,
+ __first: Union[
+ Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfQuery:
"""Apply one or more ORDER BY criteria to the query and return
the newly resulting :class:`_query.Query`.
"""
- if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ for assertion in (self._no_statement_condition, self._no_limit_offset):
+ assertion("order_by")
+
+ if not clauses and (__first is None or __first is False):
self._order_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
criterion = tuple(
coercions.expect(roles.OrderByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
self._order_by_clauses += criterion
+
return self
@_generative
- @_assertions(_no_statement_condition, _no_limit_offset)
def group_by(
- self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfQuery,
+ __first: Union[
+ Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfQuery:
"""Apply one or more GROUP BY criterion to the query and return
the newly resulting :class:`_query.Query`.
"""
- if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ for assertion in (self._no_statement_condition, self._no_limit_offset):
+ assertion("group_by")
+
+ if not clauses and (__first is None or __first is False):
self._group_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
criterion = tuple(
coercions.expect(roles.GroupByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
self._group_by_clauses += criterion
return self
self._having_criteria += (having_criteria,)
return self
- def _set_op(self, expr_fn, *q):
- return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
+ def _set_op(self: SelfQuery, expr_fn: Any, *q: Query[Any]) -> SelfQuery:
+ list_of_queries = (self,) + q
+ return self._from_selectable(expr_fn(*(list_of_queries)).subquery())
def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce a UNION of this Query against one or more queries.
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
def join(
- self: SelfQuery, target, onclause=None, *, isouter=False, full=False
+ self: SelfQuery,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ isouter: bool = False,
+ full: bool = False,
) -> SelfQuery:
r"""Create a SQL JOIN against this :class:`_query.Query`
object's criterion
"""
- target = coercions.expect(
+ join_target = coercions.expect(
roles.JoinTargetRole,
target,
apply_propagate_attrs=self,
legacy=True,
)
if onclause is not None:
- onclause = coercions.expect(
+ onclause_element = coercions.expect(
roles.OnClauseRole, onclause, legacy=True
)
+ else:
+ onclause_element = None
+
self._setup_joins += (
(
- target,
- onclause,
+ join_target,
+ onclause_element,
None,
{
"isouter": isouter,
self.__dict__.pop("_last_joined_entity", None)
return self
- def outerjoin(self, target, onclause=None, *, full=False):
+ def outerjoin(
+ self: SelfQuery,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ full: bool = False,
+ ) -> SelfQuery:
"""Create a left outer join against this ``Query`` object's criterion
and apply generatively, returning the newly resulting ``Query``.
self._set_select_from(from_obj, False)
return self
- def __getitem__(self, item):
+ def __getitem__(self, item: Any) -> Any:
return orm_util._getitem(
self,
item,
@_generative
@_assertions(_no_statement_condition)
- def slice(self: SelfQuery, start, stop) -> SelfQuery:
+ def slice(
+ self: SelfQuery,
+ start: int,
+ stop: int,
+ ) -> SelfQuery:
"""Computes the "slice" of the :class:`_query.Query` represented by
the given indices and returns the resulting :class:`_query.Query`.
@_generative
@_assertions(_no_statement_condition)
- def limit(self: SelfQuery, limit) -> SelfQuery:
+ def limit(
+ self: SelfQuery, limit: Union[int, _ColumnExpressionArgument[int]]
+ ) -> SelfQuery:
"""Apply a ``LIMIT`` to the query and return the newly resulting
``Query``.
@_generative
@_assertions(_no_statement_condition)
- def offset(self: SelfQuery, offset) -> SelfQuery:
+ def offset(
+ self: SelfQuery, offset: Union[int, _ColumnExpressionArgument[int]]
+ ) -> SelfQuery:
"""Apply an ``OFFSET`` to the query and return the newly resulting
``Query``.
@_generative
@_assertions(_no_statement_condition)
- def distinct(self: SelfQuery, *expr) -> SelfQuery:
+ def distinct(
+ self: SelfQuery, *expr: _ColumnExpressionArgument[Any]
+ ) -> SelfQuery:
r"""Apply a ``DISTINCT`` to the query and return the newly resulting
``Query``.
:ref:`faq_query_deduplicating`
"""
- return self._iter().all()
+ return self._iter().all() # type: ignore
@_generative
@_assertions(_no_clauseelement_condition)
"""
# replicates limit(1) behavior
if self._statement is not None:
- return self._iter().first()
+ return self._iter().first() # type: ignore
else:
- return self.limit(1)._iter().first()
+ return self.limit(1)._iter().first() # type: ignore
def one_or_none(self) -> Optional[_T]:
"""Return at most one result or raise an exception.
:meth:`_query.Query.one`
"""
- return self._iter().one_or_none()
+ return self._iter().one_or_none() # type: ignore
def one(self) -> _T:
"""Return exactly one result or raise an exception.
if not isinstance(ret, collections_abc.Sequence):
return ret
return ret[0]
- except orm_exc.NoResultFound:
+ except sa_exc.NoResultFound:
return None
def __iter__(self) -> Iterable[_T]:
- return self._iter().__iter__()
+ return self._iter().__iter__() # type: ignore
def _iter(self) -> Union[ScalarResult[_T], Result[_T]]:
# new style execution.
params = self._params
statement = self._statement_20()
- result = self.session.execute(
+ result: Union[ScalarResult[_T], Result[_T]] = self.session.execute(
statement,
params,
execution_options={"_sa_orm_load_options": self.load_options},
# legacy: automatically set scalars, unique
if result._attributes.get("is_single_entity", False):
- result = result.scalars()
+ result = cast("Result[_T]", result).scalars()
if (
result._attributes.get("filtered", False)
return str(statement.compile(bind))
- def _get_bind_args(self, statement, fn, **kw):
+ def _get_bind_args(self, statement: Any, fn: Any, **kw: Any) -> Any:
return fn(clause=statement, **kw)
@property
return _column_descriptions(self, legacy=True)
- def instances(self, result_proxy: Result, context=None) -> Any:
+ def instances(
+ self,
+ result_proxy: CursorResult[Any],
+ context: Optional[QueryContext] = None,
+ ) -> Any:
"""Return an ORM result given a :class:`_engine.CursorResult` and
:class:`.QueryContext`.
# legacy: automatically set scalars, unique
if result._attributes.get("is_single_entity", False):
- result = result.scalars()
+ result = result.scalars() # type: ignore
if result._attributes.get("filtered", False):
result = result.unique()
":func:`_orm.merge_frozen_result` function.",
enable_warnings=False, # warnings occur via loading.merge_result
)
- def merge_result(self, iterator, load=True):
+ def merge_result(
+ self,
+ iterator: Union[
+ FrozenResult[Any], Iterable[Sequence[Any]], Iterable[object]
+ ],
+ load: bool = True,
+ ) -> Union[FrozenResult[Any], Iterable[Any]]:
"""Merge a result into this :class:`_query.Query` object's Session.
Given an iterator returned by a :class:`_query.Query`
self.enable_eagerloads(False)
.add_columns(sql.literal_column("1"))
.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
- .statement.with_only_columns(1)
+ ._get_select_statement_only()
+ .with_only_columns(1)
)
ezero = self._entity_from_pre_ent_zero()
return sql.exists(inner)
- def count(self):
+ def count(self) -> int:
r"""Return a count of rows this the SQL formed by this :class:`Query`
would return.
"""
col = sql.func.count(sql.literal_column("*"))
- return self._legacy_from_self(col).enable_eagerloads(False).scalar()
+ return ( # type: ignore
+ self._legacy_from_self(col).enable_eagerloads(False).scalar()
+ )
- def delete(self, synchronize_session="evaluate"):
+ def delete(self, synchronize_session: str = "evaluate") -> int:
r"""Perform a DELETE with an arbitrary WHERE clause.
Deletes rows matched by this query from the database.
self = bulk_del.query
- delete_ = sql.delete(*self._raw_columns)
+ delete_ = sql.delete(*self._raw_columns) # type: ignore
delete_._where_criteria = self._where_criteria
- result = self.session.execute(
- delete_,
- self._params,
- execution_options={"synchronize_session": synchronize_session},
+ result: CursorResult[Any] = cast(
+ "CursorResult[Any]",
+ self.session.execute(
+ delete_,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ ),
)
- bulk_del.result = result
+ bulk_del.result = result # type: ignore
self.session.dispatch.after_bulk_delete(bulk_del)
result.close()
return result.rowcount
- def update(self, values, synchronize_session="evaluate", update_args=None):
+ def update(
+ self,
+ values: Dict[_DMLColumnArgument, Any],
+ synchronize_session: str = "evaluate",
+ update_args: Optional[Dict[Any, Any]] = None,
+ ) -> int:
r"""Perform an UPDATE with an arbitrary WHERE clause.
Updates rows matched by this query in the database.
bulk_ud.query = new_query
self = bulk_ud.query
- upd = sql.update(*self._raw_columns)
+ upd = sql.update(*self._raw_columns) # type: ignore
ppo = update_args.pop("preserve_parameter_order", False)
if ppo:
- upd = upd.ordered_values(*values)
+ upd = upd.ordered_values(*values) # type: ignore
else:
upd = upd.values(values)
if update_args:
upd = upd.with_dialect_options(**update_args)
upd._where_criteria = self._where_criteria
- result = self.session.execute(
- upd,
- self._params,
- execution_options={"synchronize_session": synchronize_session},
+ result: CursorResult[Any] = cast(
+ "CursorResult[Any]",
+ self.session.execute(
+ upd,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ ),
)
- bulk_ud.result = result
+ bulk_ud.result = result # type: ignore
self.session.dispatch.after_bulk_update(bulk_ud)
result.close()
return result.rowcount
- def _compile_state(self, for_statement=False, **kw):
+ def _compile_state(
+ self, for_statement: bool = False, **kw: Any
+ ) -> ORMCompileState:
"""Create an out-of-compiler ORMCompileState object.
The ORMCompileState object is normally created directly as a result
# ORMSelectCompileState. We could also base this on
# query._statement is not None as we have the ORM Query here
# however this is the more general path.
- compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
- stmt, "orm"
+ compile_state_cls = cast(
+ ORMCompileState,
+ ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"),
)
return compile_state_cls.create_for_statement(stmt, None)
- def _compile_context(self, for_statement=False):
+ def _compile_context(self, for_statement: bool = False) -> QueryContext:
compile_state = self._compile_state(for_statement=for_statement)
context = QueryContext(
compile_state,
"""
- def process_compile_state(self, compile_state: ORMCompileState):
+ def process_compile_state(self, compile_state: ORMCompileState) -> None:
pass
"""
- def __init__(self, query):
+ def __init__(self, query: Query[Any]):
self.query = query.enable_eagerloads(False)
self._validate_query_state()
self.mapper = self.query._entity_from_pre_ent_zero()
- def _validate_query_state(self):
+ def _validate_query_state(self) -> None:
for attr, methname, notset, op in (
("_limit_clause", "limit()", None, operator.is_),
("_offset_clause", "offset()", None, operator.is_),
)
@property
- def session(self):
+ def session(self) -> Session:
return self.query.session
class BulkUpdate(BulkUD):
"""BulkUD which handles UPDATEs."""
- def __init__(self, query, values, update_kwargs):
+ def __init__(
+ self,
+ query: Query[Any],
+ values: Dict[_DMLColumnArgument, Any],
+ update_kwargs: Optional[Dict[Any, Any]],
+ ):
super(BulkUpdate, self).__init__(query)
self.values = values
self.update_kwargs = update_kwargs
class RowReturningQuery(Query[Row[_TP]]):
- pass
+ if TYPE_CHECKING:
+
+ def tuples(self) -> Query[_TP]: # type: ignore
+ ...
import collections
from collections import abc
+import dataclasses
import re
import typing
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Collection
from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
from . import attributes
from . import strategy_options
+from ._typing import insp_is_aliased_class
+from ._typing import is_has_collection_adapter
from .base import _is_mapped_class
from .base import class_mapper
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
from .base import state_str
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
from .interfaces import MANYTOONE
from .interfaces import ONETOMANY
from .interfaces import PropComparator
+from .interfaces import RelationshipDirection
from .interfaces import StrategizedProperty
from .util import _extract_mapped_subtype
from .util import _orm_annotate
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _HasClauseElement
from ..sql.elements import ColumnClause
+from ..sql.elements import ColumnElement
from ..sql.util import _deep_deannotate
from ..sql.util import _shallow_annotate
from ..sql.util import adapt_criterion_to_null
if typing.TYPE_CHECKING:
from ._typing import _EntityType
+ from ._typing import _ExternalEntityType
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
from ._typing import _InternalEntityType
+ from ._typing import _O
+ from ._typing import _RegistryType
+ from .clsregistry import _class_resolver
+ from .clsregistry import _ModNS
+ from .dependency import DependencyProcessor
from .mapper import Mapper
+ from .query import Query
+ from .session import Session
+ from .state import InstanceState
+ from .strategies import LazyLoader
from .util import AliasedClass
from .util import AliasedInsp
- from ..sql.elements import ColumnElement
+ from ..sql._typing import _CoreAdapterProto
+ from ..sql._typing import _EquivalentColumnMap
+ from ..sql._typing import _InfoType
+ from ..sql.annotation import _AnnotationDict
+ from ..sql.elements import BinaryExpression
+ from ..sql.elements import BindParameter
+ from ..sql.elements import ClauseElement
+ from ..sql.schema import Table
+ from ..sql.selectable import FromClause
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
+_T1 = TypeVar("_T1", bound=Any)
+_T2 = TypeVar("_T2", bound=Any)
+
_PT = TypeVar("_PT", bound=Any)
+_PT2 = TypeVar("_PT2", bound=Any)
+
_RelationshipArgumentType = Union[
str,
str, _ColumnExpressionArgument[bool]
]
_ORMOrderByArgument = Union[
- Literal[False], str, _ColumnExpressionArgument[Any]
+ Literal[False],
+ str,
+ _ColumnExpressionArgument[Any],
+ Iterable[Union[str, _ColumnExpressionArgument[Any]]],
]
_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
_ORMColCollectionArgument = Union[
]
-def remote(expr):
+_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any])
+
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
+_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+
+def remote(expr: _CEA) -> _CEA:
"""Annotate a portion of a primaryjoin expression
with a 'remote' annotation.
:func:`.foreign`
"""
- return _annotate_columns(
+ return _annotate_columns( # type: ignore
coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True}
)
-def foreign(expr):
+def foreign(expr: _CEA) -> _CEA:
"""Annotate a portion of a primaryjoin expression
with a 'foreign' annotation.
"""
- return _annotate_columns(
+ return _annotate_columns( # type: ignore
coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True}
)
+@dataclasses.dataclass
+class _RelationshipArg(Generic[_T1, _T2]):
+ """Store a user-defined parameter value that must be resolved and
+ parsed later at mapper configuration time.
+
+ """
+
+ __slots__ = "name", "argument", "resolved"
+ name: str
+ argument: _T1
+ resolved: Optional[_T2]
+
+ def _is_populated(self) -> bool:
+ return self.argument is not None
+
+ def _resolve_against_registry(
+ self, clsregistry_resolver: Callable[[str, bool], _class_resolver]
+ ) -> None:
+ attr_value = self.argument
+
+ if isinstance(attr_value, str):
+ self.resolved = clsregistry_resolver(
+ attr_value, self.name == "secondary"
+ )()
+ elif callable(attr_value) and not _is_mapped_class(attr_value):
+ self.resolved = attr_value()
+ else:
+ self.resolved = attr_value
+
+
+class _RelationshipArgs(NamedTuple):
+ """Store user-passed parameters that are resolved at mapper
+ configuration time.
+
+ """
+
+ secondary: _RelationshipArg[
+ Optional[Union[FromClause, str]],
+ Optional[FromClause],
+ ]
+ primaryjoin: _RelationshipArg[
+ Optional[_RelationshipJoinConditionArgument],
+ Optional[ColumnElement[Any]],
+ ]
+ secondaryjoin: _RelationshipArg[
+ Optional[_RelationshipJoinConditionArgument],
+ Optional[ColumnElement[Any]],
+ ]
+ order_by: _RelationshipArg[
+ _ORMOrderByArgument,
+ Union[Literal[None, False], Tuple[ColumnElement[Any], ...]],
+ ]
+ foreign_keys: _RelationshipArg[
+ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+ ]
+ remote_side: _RelationshipArg[
+ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+ ]
+
+
@log.class_logger
class Relationship(
_IntrospectsAnnotations, StrategizedProperty[_T], log.Identified
_links_to_entity = True
_is_relationship = True
+ _overlaps: Sequence[str]
+
+ _lazy_strategy: LazyLoader
+
_persistence_only = dict(
passive_deletes=False,
passive_updates=True,
cascade_backrefs=False,
)
- _dependency_processor = None
+ _dependency_processor: Optional[DependencyProcessor] = None
+
+ primaryjoin: ColumnElement[bool]
+ secondaryjoin: Optional[ColumnElement[bool]]
+ secondary: Optional[FromClause]
+ _join_condition: JoinCondition
+ order_by: Union[Literal[False], Tuple[ColumnElement[Any], ...]]
+
+ _user_defined_foreign_keys: Set[ColumnElement[Any]]
+ _calculated_foreign_keys: Set[ColumnElement[Any]]
+
+ remote_side: Set[ColumnElement[Any]]
+ local_columns: Set[ColumnElement[Any]]
+
+ synchronize_pairs: _ColumnPairs
+ secondary_synchronize_pairs: Optional[_ColumnPairs]
+
+ local_remote_pairs: Optional[_ColumnPairs]
+
+ direction: RelationshipDirection
+
+ _init_args: _RelationshipArgs
def __init__(
self,
argument: Optional[_RelationshipArgumentType[_T]] = None,
- secondary=None,
+ secondary: Optional[Union[FromClause, str]] = None,
*,
- uselist=None,
- collection_class=None,
- primaryjoin=None,
- secondaryjoin=None,
- back_populates=None,
- order_by=False,
- backref=None,
- cascade_backrefs=False,
- overlaps=None,
- post_update=False,
- cascade="save-update, merge",
- viewonly=False,
+ uselist: Optional[bool] = None,
+ collection_class: Optional[
+ Union[Type[Collection[Any]], Callable[[], Collection[Any]]]
+ ] = None,
+ primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ back_populates: Optional[str] = None,
+ order_by: _ORMOrderByArgument = False,
+ backref: Optional[_ORMBackrefArgument] = None,
+ overlaps: Optional[str] = None,
+ post_update: bool = False,
+ cascade: str = "save-update, merge",
+ viewonly: bool = False,
lazy: _LazyLoadArgumentType = "select",
- passive_deletes=False,
- passive_updates=True,
- active_history=False,
- enable_typechecks=True,
- foreign_keys=None,
- remote_side=None,
- join_depth=None,
- comparator_factory=None,
- single_parent=False,
- innerjoin=False,
- distinct_target_key=None,
- load_on_pending=False,
- query_class=None,
- info=None,
- omit_join=None,
- sync_backref=None,
- doc=None,
- bake_queries=True,
- _local_remote_pairs=None,
- _legacy_inactive_history_style=False,
+ passive_deletes: Union[Literal["all"], bool] = False,
+ passive_updates: bool = True,
+ active_history: bool = False,
+ enable_typechecks: bool = True,
+ foreign_keys: Optional[_ORMColCollectionArgument] = None,
+ remote_side: Optional[_ORMColCollectionArgument] = None,
+ join_depth: Optional[int] = None,
+ comparator_factory: Optional[
+ Type[Relationship.Comparator[Any]]
+ ] = None,
+ single_parent: bool = False,
+ innerjoin: bool = False,
+ distinct_target_key: Optional[bool] = None,
+ load_on_pending: bool = False,
+ query_class: Optional[Type[Query[Any]]] = None,
+ info: Optional[_InfoType] = None,
+ omit_join: Literal[None, False] = None,
+ sync_backref: Optional[bool] = None,
+ doc: Optional[str] = None,
+ bake_queries: Literal[True] = True,
+ cascade_backrefs: Literal[False] = False,
+ _local_remote_pairs: Optional[_ColumnPairs] = None,
+ _legacy_inactive_history_style: bool = False,
):
super(Relationship, self).__init__()
self.uselist = uselist
self.argument = argument
- self.secondary = secondary
- self.primaryjoin = primaryjoin
- self.secondaryjoin = secondaryjoin
+
+ self._init_args = _RelationshipArgs(
+ _RelationshipArg("secondary", secondary, None),
+ _RelationshipArg("primaryjoin", primaryjoin, None),
+ _RelationshipArg("secondaryjoin", secondaryjoin, None),
+ _RelationshipArg("order_by", order_by, None),
+ _RelationshipArg("foreign_keys", foreign_keys, None),
+ _RelationshipArg("remote_side", remote_side, None),
+ )
+
self.post_update = post_update
- self.direction = None
self.viewonly = viewonly
if viewonly:
self._warn_for_persistence_only_flags(
self.sync_backref = sync_backref
self.lazy = lazy
self.single_parent = single_parent
- self._user_defined_foreign_keys = foreign_keys
self.collection_class = collection_class
self.passive_deletes = passive_deletes
)
self.passive_updates = passive_updates
- self.remote_side = remote_side
self.enable_typechecks = enable_typechecks
self.query_class = query_class
self.innerjoin = innerjoin
self.local_remote_pairs = _local_remote_pairs
self.load_on_pending = load_on_pending
self.comparator_factory = comparator_factory or Relationship.Comparator
- self.comparator = self.comparator_factory(self, None)
util.set_creation_order(self)
if info is not None:
- self.info = info
+ self.info.update(info)
self.strategy_key = (("lazy", self.lazy),)
- self._reverse_property = set()
+ self._reverse_property: Set[Relationship[Any]] = set()
+
if overlaps:
- self._overlaps = set(re.split(r"\s*,\s*", overlaps))
+ self._overlaps = set(re.split(r"\s*,\s*", overlaps)) # type: ignore # noqa: E501
else:
self._overlaps = ()
- self.cascade = cascade
-
- self.order_by = order_by
+ # mypy ignoring the @property setter
+ self.cascade = cascade # type: ignore
self.back_populates = back_populates
else:
self.backref = backref
- def _warn_for_persistence_only_flags(self, **kw):
+ def _warn_for_persistence_only_flags(self, **kw: Any) -> None:
for k, v in kw.items():
if v != self._persistence_only[k]:
# we are warning here rather than warn deprecated as this is a
"in a future release." % (k,)
)
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
attributes.register_descriptor(
mapper.class_,
self.key,
"_extra_criteria",
)
+ prop: RODescriptorReference[Relationship[_PT]]
+ _of_type: Optional[_EntityType[_PT]]
+
def __init__(
self,
- prop,
- parentmapper,
- adapt_to_entity=None,
- of_type=None,
- extra_criteria=(),
+ prop: Relationship[_PT],
+ parentmapper: _InternalEntityType[Any],
+ adapt_to_entity: Optional[AliasedInsp[Any]] = None,
+ of_type: Optional[_EntityType[_PT]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
):
"""Construction of :class:`.Relationship.Comparator`
is internal to the ORM's attribute mechanics.
self._of_type = None
self._extra_criteria = extra_criteria
- def adapt_to_entity(self, adapt_to_entity):
+ def adapt_to_entity(
+ self, adapt_to_entity: AliasedInsp[Any]
+ ) -> Relationship.Comparator[Any]:
return self.__class__(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=adapt_to_entity,
of_type=self._of_type,
)
- entity: _InternalEntityType
+ entity: _InternalEntityType[_PT]
"""The target entity referred to by this
:class:`.Relationship.Comparator`.
"""
- mapper: Mapper[Any]
+ mapper: Mapper[_PT]
"""The target :class:`_orm.Mapper` referred to by this
:class:`.Relationship.Comparator`.
"""
- def _memoized_attr_entity(self) -> _InternalEntityType:
+ def _memoized_attr_entity(self) -> _InternalEntityType[_PT]:
if self._of_type:
- return inspect(self._of_type)
+ return inspect(self._of_type) # type: ignore
else:
return self.prop.entity
- def _memoized_attr_mapper(self) -> Mapper[Any]:
+ def _memoized_attr_mapper(self) -> Mapper[_PT]:
return self.entity.mapper
- def _source_selectable(self):
+ def _source_selectable(self) -> FromClause:
if self._adapt_to_entity:
return self._adapt_to_entity.selectable
else:
return self.property.parent._with_polymorphic_selectable
- def __clause_element__(self):
+ def __clause_element__(self) -> ColumnElement[bool]:
adapt_from = self._source_selectable()
if self._of_type:
of_type_entity = inspect(self._of_type)
dest,
secondary,
target_adapter,
- ) = self.property._create_joins(
+ ) = self.prop._create_joins(
source_selectable=adapt_from,
source_polymorphic=True,
of_type_entity=of_type_entity,
else:
return pj
- def of_type(self, cls):
+ def of_type(self, class_: _EntityType[_PT]) -> PropComparator[_PT]:
r"""Redefine this object in terms of a polymorphic subclass.
See :meth:`.PropComparator.of_type` for an example.
"""
return Relationship.Comparator(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
- of_type=cls,
+ of_type=class_,
extra_criteria=self._extra_criteria,
)
def and_(
self, *criteria: _ColumnExpressionArgument[bool]
- ) -> PropComparator[bool]:
+ ) -> PropComparator[Any]:
"""Add AND criteria.
See :meth:`.PropComparator.and_` for an example.
)
return Relationship.Comparator(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
of_type=self._of_type,
extra_criteria=self._extra_criteria + exprs,
)
- def in_(self, other):
+ def in_(self, other: Any) -> NoReturn:
"""Produce an IN clause - this is not implemented
for :func:`_orm.relationship`-based attributes at this time.
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
"""Implement the ``==`` operator.
In a many-to-one context, such as::
or many-to-many context produce a NOT EXISTS clause.
"""
- if isinstance(other, (util.NoneType, expression.Null)):
+ if other is None or isinstance(other, expression.Null):
if self.property.direction in [ONETOMANY, MANYTOMANY]:
return ~self._criterion_exists()
else:
criterion: Optional[_ColumnExpressionArgument[bool]] = None,
**kwargs: Any,
) -> Exists:
+
+ where_criteria = (
+ coercions.expect(roles.WhereHavingRole, criterion)
+ if criterion is not None
+ else None
+ )
+
if getattr(self, "_of_type", None):
- info = inspect(self._of_type)
+ info: Optional[_InternalEntityType[Any]] = inspect(
+ self._of_type
+ )
+ assert info is not None
target_mapper, to_selectable, is_aliased_class = (
info.mapper,
info.selectable,
single_crit = target_mapper._single_table_criterion
if single_crit is not None:
- if criterion is not None:
- criterion = single_crit & criterion
+ if where_criteria is not None:
+ where_criteria = single_crit & where_criteria
else:
- criterion = single_crit
+ where_criteria = single_crit
else:
is_aliased_class = False
to_selectable = None
for k in kwargs:
crit = getattr(self.property.mapper.class_, k) == kwargs[k]
- if criterion is None:
- criterion = crit
+ if where_criteria is None:
+ where_criteria = crit
else:
- criterion = criterion & crit
+ where_criteria = where_criteria & crit
# annotate the *local* side of the join condition, in the case
# of pj + sj this is the full primaryjoin, in the case of just
j = _orm_annotate(pj, exclude=self.property.remote_side)
if (
- criterion is not None
+ where_criteria is not None
and target_adapter
and not is_aliased_class
):
# limit this adapter to annotated only?
- criterion = target_adapter.traverse(criterion)
+ where_criteria = target_adapter.traverse(where_criteria)
# only have the "joined left side" of what we
# return be subject to Query adaption. The right
# side of it is used for an exists() subquery and
# should not correlate or otherwise reach out
# to anything in the enclosing query.
- if criterion is not None:
- criterion = criterion._annotate(
+ if where_criteria is not None:
+ where_criteria = where_criteria._annotate(
{"no_replacement_traverse": True}
)
- crit = j & sql.True_._ifnone(criterion)
+ crit = j & sql.True_._ifnone(where_criteria)
if secondary is not None:
ex = (
)
return ex
- def any(self, criterion=None, **kwargs):
+ def any(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce an expression that tests a collection against
particular criterion, using EXISTS.
return self._criterion_exists(criterion, **kwargs)
- def has(self, criterion=None, **kwargs):
+ def has(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce an expression that tests a scalar reference against
particular criterion, using EXISTS.
)
return self._criterion_exists(criterion, **kwargs)
- def contains(self, other, **kwargs):
+ def contains(
+ self, other: _ColumnExpressionArgument[Any], **kwargs: Any
+ ) -> ColumnElement[bool]:
"""Return a simple expression that tests a collection for
containment of a particular item.
kwargs may be ignored by this operator but are required for API
conformance.
"""
- if not self.property.uselist:
+ if not self.prop.uselist:
raise sa_exc.InvalidRequestError(
"'contains' not implemented for scalar "
"attributes. Use =="
)
- clause = self.property._optimized_compare(
+
+ clause = self.prop._optimized_compare(
other, adapt_source=self.adapter
)
- if self.property.secondaryjoin is not None:
+ if self.prop.secondaryjoin is not None:
clause.negation_clause = self.__negated_contains_or_equals(
other
)
return clause
- def __negated_contains_or_equals(self, other):
- if self.property.direction == MANYTOONE:
+ def __negated_contains_or_equals(
+ self, other: Any
+ ) -> ColumnElement[bool]:
+ if self.prop.direction == MANYTOONE:
state = attributes.instance_state(other)
- def state_bindparam(local_col, state, remote_col):
+ def state_bindparam(
+ local_col: ColumnElement[Any],
+ state: InstanceState[Any],
+ remote_col: ColumnElement[Any],
+ ) -> BindParameter[Any]:
dict_ = state.dict
return sql.bindparam(
local_col.key,
type_=local_col.type,
unique=True,
- callable_=self.property._get_attr_w_warn_on_none(
- self.property.mapper, state, dict_, remote_col
+ callable_=self.prop._get_attr_w_warn_on_none(
+ self.prop.mapper, state, dict_, remote_col
),
)
- def adapt(col):
+ def adapt(col: _CE) -> _CE:
if self.adapter:
return self.adapter(col)
else:
return ~self._criterion_exists(criterion)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
"""Implement the ``!=`` operator.
In a many-to-one context, such as::
or many-to-many context produce an EXISTS clause.
"""
- if isinstance(other, (util.NoneType, expression.Null)):
+ if other is None or isinstance(other, expression.Null):
if self.property.direction == MANYTOONE:
return _orm_annotate(
~self.property._optimized_compare(
else:
return _orm_annotate(self.__negated_contains_or_equals(other))
- def _memoized_attr_property(self):
+ def _memoized_attr_property(self) -> Relationship[_PT]:
self.prop.parent._check_configure()
return self.prop
- comparator: Comparator[_T]
-
def _with_parent(
self,
instance: object,
from_entity: Optional[_EntityType[Any]] = None,
) -> ColumnElement[bool]:
assert instance is not None
- adapt_source = None
+ adapt_source: Optional[_CoreAdapterProto] = None
if from_entity is not None:
- insp = inspect(from_entity)
- if insp.is_aliased_class:
+ insp: Optional[_InternalEntityType[Any]] = inspect(from_entity)
+ assert insp is not None
+ if insp_is_aliased_class(insp):
adapt_source = insp._adapter.adapt_clause
return self._optimized_compare(
instance,
def _optimized_compare(
self,
- state,
- value_is_parent=False,
- adapt_source=None,
- alias_secondary=True,
- ):
+ state: Any,
+ value_is_parent: bool = False,
+ adapt_source: Optional[_CoreAdapterProto] = None,
+ alias_secondary: bool = True,
+ ) -> ColumnElement[bool]:
if state is not None:
try:
state = inspect(state)
dict_ = attributes.instance_dict(state.obj())
- def visit_bindparam(bindparam):
+ def visit_bindparam(bindparam: BindParameter[Any]) -> None:
if bindparam._identifying_key in bind_to_col:
bindparam.callable = self._get_attr_w_warn_on_none(
mapper,
criterion = adapt_source(criterion)
return criterion
- def _get_attr_w_warn_on_none(self, mapper, state, dict_, column):
+ def _get_attr_w_warn_on_none(
+ self,
+ mapper: Mapper[Any],
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ column: ColumnElement[Any],
+ ) -> Callable[[], Any]:
"""Create the callable that is used in a many-to-one expression.
E.g.::
# this feature was added explicitly for use in this method.
state._track_last_known_value(prop.key)
- def _go():
- last_known = to_return = state._last_known_values[prop.key]
- existing_is_available = last_known is not attributes.NO_VALUE
+ lkv_fixed = state._last_known_values
+
+ def _go() -> Any:
+ assert lkv_fixed is not None
+ last_known = to_return = lkv_fixed[prop.key]
+ existing_is_available = (
+ last_known is not LoaderCallableStatus.NO_VALUE
+ )
# we support that the value may have changed. so here we
# try to get the most recent value including re-fetching.
state,
dict_,
column,
- passive=attributes.PASSIVE_OFF
+ passive=PassiveFlag.PASSIVE_OFF
if state.persistent
- else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK,
+ else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK,
)
- if current_value is attributes.NEVER_SET:
+ if current_value is LoaderCallableStatus.NEVER_SET:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
"%s; no value has been set for this column"
% (column, state_str(state))
)
- elif current_value is attributes.PASSIVE_NO_RESULT:
+ elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
return _go
- def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
+ def _lazy_none_clause(
+ self,
+ reverse_direction: bool = False,
+ adapt_source: Optional[_CoreAdapterProto] = None,
+ ) -> ColumnElement[bool]:
if not reverse_direction:
criterion, bind_to_col = (
self._lazy_strategy._lazywhere,
criterion = adapt_source(criterion)
return criterion
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent.class_.__name__) + "." + self.key
def merge(
self,
- session,
- source_state,
- source_dict,
- dest_state,
- dest_dict,
- load,
- _recursive,
- _resolve_conflict_map,
- ):
+ session: Session,
+ source_state: InstanceState[Any],
+ source_dict: _InstanceDict,
+ dest_state: InstanceState[Any],
+ dest_dict: _InstanceDict,
+ load: bool,
+ _recursive: Dict[Any, object],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> None:
if load:
for r in self._reverse_property:
if self.uselist:
impl = source_state.get_impl(self.key)
+
+ assert is_has_collection_adapter(impl)
instances_iterable = impl.get_collection(source_state, source_dict)
# if this is a CollectionAttributeImpl, then empty should
for c in dest_list:
coll.append_without_event(c)
else:
- dest_state.get_impl(self.key).set(
- dest_state, dest_dict, dest_list, _adapt=False
- )
+ dest_impl = dest_state.get_impl(self.key)
+ assert is_has_collection_adapter(dest_impl)
+ dest_impl.set(dest_state, dest_dict, dest_list, _adapt=False)
else:
current = source_dict[self.key]
if current is not None:
)
def _value_as_iterable(
- self, state, dict_, key, passive=attributes.PASSIVE_OFF
- ):
+ self,
+ state: InstanceState[_O],
+ dict_: _InstanceDict,
+ key: str,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> Sequence[Tuple[InstanceState[_O], _O]]:
"""Return a list of tuples (state, obj) for the given
key.
impl = state.manager[key].impl
x = impl.get(state, dict_, passive=passive)
- if x is attributes.PASSIVE_NO_RESULT or x is None:
+ if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None:
return []
- elif hasattr(impl, "get_collection"):
+ elif is_has_collection_adapter(impl):
return [
(attributes.instance_state(o), o)
for o in impl.get_collection(state, dict_, x, passive=passive)
return [(attributes.instance_state(x), x)]
def cascade_iterator(
- self, type_, state, dict_, visited_states, halt_on=None
- ):
+ self,
+ type_: str,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ visited_states: Set[InstanceState[Any]],
+ halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+ ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]:
# assert type_ in self._cascade
# only actively lazy load on the 'delete' cascade
if type_ != "delete" or self.passive_deletes:
- passive = attributes.PASSIVE_NO_INITIALIZE
+ passive = PassiveFlag.PASSIVE_NO_INITIALIZE
else:
- passive = attributes.PASSIVE_OFF
+ passive = PassiveFlag.PASSIVE_OFF
if type_ == "save-update":
tuples = state.manager[self.key].impl.get_all_pending(state, dict_)
-
else:
tuples = self._value_as_iterable(
state, dict_, self.key, passive=passive
# see [ticket:2229]
continue
+ assert instance_state is not None
instance_dict = attributes.instance_dict(c)
if halt_on and halt_on(instance_state):
yield c, instance_mapper, instance_state, instance_dict
@property
- def _effective_sync_backref(self):
+ def _effective_sync_backref(self) -> bool:
if self.viewonly:
return False
else:
return self.sync_backref is not False
@staticmethod
- def _check_sync_backref(rel_a, rel_b):
+ def _check_sync_backref(
+ rel_a: Relationship[Any], rel_b: Relationship[Any]
+ ) -> None:
if rel_a.viewonly and rel_b.sync_backref:
raise sa_exc.InvalidRequestError(
"Relationship %s cannot specify sync_backref=True since %s "
):
rel_b.sync_backref = False
- def _add_reverse_property(self, key):
+ def _add_reverse_property(self, key: str) -> None:
other = self.mapper.get_property(key, _configure_mappers=False)
if not isinstance(other, Relationship):
raise sa_exc.InvalidRequestError(
)
if (
- self.direction in (ONETOMANY, MANYTOONE)
+ other._configure_started
+ and self.direction in (ONETOMANY, MANYTOONE)
and self.direction == other.direction
):
raise sa_exc.ArgumentError(
)
@util.memoized_property
- def entity(self) -> Union["Mapper", "AliasedInsp"]:
+ def entity(self) -> _InternalEntityType[_T]:
"""Return the target mapped entity, which is an inspect() of the
class or aliased class that is referred towards.
"""
return self.entity.mapper
- def do_init(self):
+ def do_init(self) -> None:
self._check_conflicts()
self._process_dependent_arguments()
self._setup_entity()
self._generate_backref()
self._join_condition._warn_for_conflicting_sync_targets()
super(Relationship, self).do_init()
- self._lazy_strategy = self._get_strategy((("lazy", "select"),))
+ self._lazy_strategy = cast(
+ "LazyLoader", self._get_strategy((("lazy", "select"),))
+ )
def _setup_registry_dependencies(self) -> None:
    """Record that the parent mapper's registry depends on the
    target entity's registry, so configuration ordering is correct."""
    parent_registry = self.parent.mapper.registry
    parent_registry._set_depends_on(self.entity.mapper.registry)
- def _process_dependent_arguments(self):
+ def _process_dependent_arguments(self) -> None:
"""Convert incoming configuration arguments to their
proper form.
# accept callables for other attributes which may require
# deferred initialization. This technique is used
# by declarative "string configs" and some recipes.
+ init_args = self._init_args
+
for attr in (
"order_by",
"primaryjoin",
"secondaryjoin",
"secondary",
- "_user_defined_foreign_keys",
+ "foreign_keys",
"remote_side",
):
- attr_value = getattr(self, attr)
-
- if isinstance(attr_value, str):
- setattr(
- self,
- attr,
- self._clsregistry_resolve_arg(
- attr_value, favor_tables=attr == "secondary"
- )(),
- )
- elif callable(attr_value) and not _is_mapped_class(attr_value):
- setattr(self, attr, attr_value())
+
+ rel_arg = getattr(init_args, attr)
+
+ rel_arg._resolve_against_registry(self._clsregistry_resolvers[1])
# remove "annotations" which are present if mapped class
# descriptors are used to create the join expression.
for attr in "primaryjoin", "secondaryjoin":
- val = getattr(self, attr)
+ rel_arg = getattr(init_args, attr)
+ val = rel_arg.resolved
if val is not None:
- setattr(
- self,
- attr,
- _orm_deannotate(
- coercions.expect(
- roles.ColumnArgumentRole, val, argname=attr
- )
- ),
+ rel_arg.resolved = _orm_deannotate(
+ coercions.expect(
+ roles.ColumnArgumentRole, val, argname=attr
+ )
)
- if self.secondary is not None and _is_mapped_class(self.secondary):
+ secondary = init_args.secondary.resolved
+ if secondary is not None and _is_mapped_class(secondary):
raise sa_exc.ArgumentError(
"secondary argument %s passed to to relationship() %s must "
"be a Table object or other FROM clause; can't send a mapped "
"class directly as rows in 'secondary' are persisted "
"independently of a class that is mapped "
- "to that same table." % (self.secondary, self)
+ "to that same table." % (secondary, self)
)
# ensure expressions in self.order_by, foreign_keys,
# remote_side are all columns, not strings.
- if self.order_by is not False and self.order_by is not None:
+ if (
+ init_args.order_by.resolved is not False
+ and init_args.order_by.resolved is not None
+ ):
self.order_by = tuple(
coercions.expect(
roles.ColumnArgumentRole, x, argname="order_by"
)
- for x in util.to_list(self.order_by)
+ for x in util.to_list(init_args.order_by.resolved)
)
+ else:
+ self.order_by = False
self._user_defined_foreign_keys = util.column_set(
coercions.expect(
roles.ColumnArgumentRole, x, argname="foreign_keys"
)
- for x in util.to_column_set(self._user_defined_foreign_keys)
+ for x in util.to_column_set(init_args.foreign_keys.resolved)
)
self.remote_side = util.column_set(
coercions.expect(
roles.ColumnArgumentRole, x, argname="remote_side"
)
- for x in util.to_column_set(self.remote_side)
+ for x in util.to_column_set(init_args.remote_side.resolved)
)
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
argument = _extract_mapped_subtype(
annotation,
cls,
if hasattr(argument, "__origin__"):
- collection_class = argument.__origin__
+ collection_class = argument.__origin__ # type: ignore
if issubclass(collection_class, abc.Collection):
if self.collection_class is None:
self.collection_class = collection_class
else:
self.uselist = False
- if argument.__args__:
- if issubclass(argument.__origin__, typing.Mapping):
- type_arg = argument.__args__[1]
+ if argument.__args__: # type: ignore
+ if issubclass(
+ argument.__origin__, typing.Mapping # type: ignore
+ ):
+ type_arg = argument.__args__[1] # type: ignore
else:
- type_arg = argument.__args__[0]
+ type_arg = argument.__args__[0] # type: ignore
if hasattr(type_arg, "__forward_arg__"):
str_argument = type_arg.__forward_arg__
argument = str_argument
f"Generic alias {argument} requires an argument"
)
elif hasattr(argument, "__forward_arg__"):
- argument = argument.__forward_arg__
+ argument = argument.__forward_arg__ # type: ignore
self.argument = argument
@util.preload_module("sqlalchemy.orm.mapper")
- def _setup_entity(self, __argument=None):
+ def _setup_entity(self, __argument: Any = None) -> None:
if "entity" in self.__dict__:
return
else:
argument = self.argument
+ resolved_argument: _ExternalEntityType[Any]
+
if isinstance(argument, str):
- argument = self._clsregistry_resolve_name(argument)()
+ # we might want to cleanup clsregistry API to make this
+ # more straightforward
+ resolved_argument = cast(
+ "_ExternalEntityType[Any]",
+ self._clsregistry_resolve_name(argument)(),
+ )
elif callable(argument) and not isinstance(
argument, (type, mapperlib.Mapper)
):
- argument = argument()
+ resolved_argument = argument()
else:
- argument = argument
+ resolved_argument = argument
- if isinstance(argument, type):
- entity = class_mapper(argument, configure=False)
+ entity: _InternalEntityType[Any]
+
+ if isinstance(resolved_argument, type):
+ entity = class_mapper(resolved_argument, configure=False)
else:
try:
- entity = inspect(argument)
+ entity = inspect(resolved_argument)
except sa_exc.NoInspectionAvailable:
- entity = None
+ entity = None # type: ignore
if not hasattr(entity, "mapper"):
raise sa_exc.ArgumentError(
"relationship '%s' expects "
"a class or a mapper argument (received: %s)"
- % (self.key, type(argument))
+ % (self.key, type(resolved_argument))
)
self.entity = entity # type: ignore
self.target = self.entity.persist_selectable
- def _setup_join_conditions(self):
+ def _setup_join_conditions(self) -> None:
self._join_condition = jc = JoinCondition(
parent_persist_selectable=self.parent.persist_selectable,
child_persist_selectable=self.entity.persist_selectable,
parent_local_selectable=self.parent.local_table,
child_local_selectable=self.entity.local_table,
- primaryjoin=self.primaryjoin,
- secondary=self.secondary,
- secondaryjoin=self.secondaryjoin,
+ primaryjoin=self._init_args.primaryjoin.resolved,
+ secondary=self._init_args.secondary.resolved,
+ secondaryjoin=self._init_args.secondaryjoin.resolved,
parent_equivalents=self.parent._equivalent_columns,
child_equivalents=self.mapper._equivalent_columns,
consider_as_foreign_keys=self._user_defined_foreign_keys,
)
self.primaryjoin = jc.primaryjoin
self.secondaryjoin = jc.secondaryjoin
+ self.secondary = jc.secondary
self.direction = jc.direction
self.local_remote_pairs = jc.local_remote_pairs
self.remote_side = jc.remote_columns
self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
@property
def _clsregistry_resolve_arg(
    self,
) -> Callable[[str, bool], _class_resolver]:
    """Resolver applied to string-based relationship() arguments."""
    _, arg_resolver = self._clsregistry_resolvers
    return arg_resolver
@property
def _clsregistry_resolve_name(
    self,
) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]:
    """Resolver applied to string-based target class names."""
    name_resolver, _ = self._clsregistry_resolvers
    return name_resolver
@util.memoized_property
@util.preload_module("sqlalchemy.orm.clsregistry")
def _clsregistry_resolvers(
    self,
) -> Tuple[
    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
    Callable[[str, bool], _class_resolver],
]:
    """Memoized (name resolver, argument resolver) pair produced by
    the clsregistry module for this relationship's parent class."""
    make_resolvers = util.preloaded.orm_clsregistry._resolver
    return make_resolvers(self.parent.class_, self)
- def _check_conflicts(self):
+ def _check_conflicts(self) -> None:
"""Test that this relationship is legal, warn about
inheritance conflicts."""
if self.parent.non_primary and not class_mapper(
return self._cascade
@cascade.setter
def cascade(self, cascade: Union[str, CascadeOptions]) -> None:
    """Assign cascade behavior; normalization and validation are
    delegated to ``_set_cascade()``."""
    self._set_cascade(cascade)
- def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]):
+ def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None:
cascade = CascadeOptions(cascade_arg)
if self.viewonly:
if self._dependency_processor:
self._dependency_processor.cascade = cascade
- def _check_cascade_settings(self, cascade):
+ def _check_cascade_settings(self, cascade: CascadeOptions) -> None:
if (
cascade.delete_orphan
and not self.single_parent
(self.key, self.parent.class_)
)
- def _persists_for(self, mapper):
+ def _persists_for(self, mapper: Mapper[Any]) -> bool:
"""Return True if this property will persist values on behalf
of the given mapper.
and mapper.relationships[self.key] is self
)
- def _columns_are_mapped(self, *cols):
+ def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool:
"""Return True if all columns in the given collection are
mapped by the tables referenced by this :class:`.Relationship`.
"""
+
+ secondary = self._init_args.secondary.resolved
for c in cols:
- if (
- self.secondary is not None
- and self.secondary.c.contains_column(c)
- ):
+ if secondary is not None and secondary.c.contains_column(c):
continue
if not self.parent.persist_selectable.c.contains_column(
c
return False
return True
- def _generate_backref(self):
+ def _generate_backref(self) -> None:
"""Interpret the 'backref' instruction to create a
:func:`_orm.relationship` complementary to this one."""
if self.parent.non_primary:
return
if self.backref is not None and not self.back_populates:
+ kwargs: Dict[str, Any]
if isinstance(self.backref, str):
backref_key, kwargs = self.backref, {}
else:
self._add_reverse_property(self.back_populates)
@util.preload_module("sqlalchemy.orm.dependency")
- def _post_init(self):
+ def _post_init(self) -> None:
dependency = util.preloaded.orm_dependency
if self.uselist is None:
)(self)
@util.memoized_property
- def _use_get(self):
+ def _use_get(self) -> bool:
"""memoize the 'use_get' attribute of this RelationshipLoader's
lazyloader."""
return strategy.use_get
@util.memoized_property
- def _is_self_referential(self):
+ def _is_self_referential(self) -> bool:
return self.mapper.common_parent(self.parent)
def _create_joins(
self,
- source_polymorphic=False,
- source_selectable=None,
- dest_selectable=None,
- of_type_entity=None,
- alias_secondary=False,
- extra_criteria=(),
- ):
+ source_polymorphic: bool = False,
+ source_selectable: Optional[FromClause] = None,
+ dest_selectable: Optional[FromClause] = None,
+ of_type_entity: Optional[_InternalEntityType[Any]] = None,
+ alias_secondary: bool = False,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+ ) -> Tuple[
+ ColumnElement[bool],
+ Optional[ColumnElement[bool]],
+ FromClause,
+ FromClause,
+ Optional[FromClause],
+ Optional[ClauseAdapter],
+ ]:
aliased = False
)
def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE:
    """Apply *annotations* to every ColumnClause reachable from *element*,
    returning the (possibly cloned) annotated element."""

    def _visit(node: _CE) -> _CE:
        # annotate ColumnClause nodes; recurse into all internals
        if isinstance(node, expression.ColumnClause):
            node = node._annotate(annotations.copy())  # type: ignore
        node._copy_internals(clone=_visit)
        return node

    if element is not None:
        element = _visit(element)
    _visit = None  # type: ignore # break reference cycle for gc
    return element
class JoinCondition:
+
+ primaryjoin_initial: Optional[ColumnElement[bool]]
+ primaryjoin: ColumnElement[bool]
+ secondaryjoin: Optional[ColumnElement[bool]]
+ secondary: Optional[FromClause]
+ prop: Relationship[Any]
+
+ synchronize_pairs: _ColumnPairs
+ secondary_synchronize_pairs: _ColumnPairs
+ direction: RelationshipDirection
+
+ parent_persist_selectable: FromClause
+ child_persist_selectable: FromClause
+ parent_local_selectable: FromClause
+ child_local_selectable: FromClause
+
+ _local_remote_pairs: Optional[_ColumnPairs]
+
def __init__(
self,
- parent_persist_selectable,
- child_persist_selectable,
- parent_local_selectable,
- child_local_selectable,
- primaryjoin=None,
- secondary=None,
- secondaryjoin=None,
- parent_equivalents=None,
- child_equivalents=None,
- consider_as_foreign_keys=None,
- local_remote_pairs=None,
- remote_side=None,
- self_referential=False,
- prop=None,
- support_sync=True,
- can_be_synced_fn=lambda *c: True,
+ parent_persist_selectable: FromClause,
+ child_persist_selectable: FromClause,
+ parent_local_selectable: FromClause,
+ child_local_selectable: FromClause,
+ primaryjoin: Optional[ColumnElement[bool]] = None,
+ secondary: Optional[FromClause] = None,
+ secondaryjoin: Optional[ColumnElement[bool]] = None,
+ parent_equivalents: Optional[_EquivalentColumnMap] = None,
+ child_equivalents: Optional[_EquivalentColumnMap] = None,
+ consider_as_foreign_keys: Any = None,
+ local_remote_pairs: Optional[_ColumnPairs] = None,
+ remote_side: Any = None,
+ self_referential: Any = False,
+ prop: Optional[Relationship[Any]] = None,
+ support_sync: bool = True,
+ can_be_synced_fn: Callable[..., bool] = lambda *c: True,
):
self.parent_persist_selectable = parent_persist_selectable
self.parent_local_selectable = parent_local_selectable
self.child_local_selectable = child_local_selectable
self.parent_equivalents = parent_equivalents
self.child_equivalents = child_equivalents
- self.primaryjoin = primaryjoin
+ self.primaryjoin_initial = primaryjoin
self.secondaryjoin = secondaryjoin
self.secondary = secondary
self.consider_as_foreign_keys = consider_as_foreign_keys
self.self_referential = self_referential
self.support_sync = support_sync
self.can_be_synced_fn = can_be_synced_fn
+
self._determine_joins()
+ assert self.primaryjoin is not None
+
self._sanitize_joins()
self._annotate_fks()
self._annotate_remote()
self._check_remote_side()
self._log_joins()
- def _log_joins(self):
+ def _log_joins(self) -> None:
if self.prop is None:
return
log = self.prop.logger
)
log.info("%s relationship direction %s", self.prop, self.direction)
- def _sanitize_joins(self):
+ def _sanitize_joins(self) -> None:
"""remove the parententity annotation from our join conditions which
can leak in here based on some declarative patterns and maybe others.
self.secondaryjoin, values=("parententity", "proxy_key")
)
- def _determine_joins(self):
+ def _determine_joins(self) -> None:
"""Determine the 'primaryjoin' and 'secondaryjoin' attributes,
if not passed to the constructor already.
a_subset=self.child_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
- if self.primaryjoin is None:
+ if self.primaryjoin_initial is None:
self.primaryjoin = join_condition(
self.parent_persist_selectable,
self.secondary,
a_subset=self.parent_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
+ else:
+ self.primaryjoin = self.primaryjoin_initial
else:
- if self.primaryjoin is None:
+ if self.primaryjoin_initial is None:
self.primaryjoin = join_condition(
self.parent_persist_selectable,
self.child_persist_selectable,
a_subset=self.parent_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
+ else:
+ self.primaryjoin = self.primaryjoin_initial
except sa_exc.NoForeignKeysError as nfe:
if self.secondary is not None:
raise sa_exc.NoForeignKeysError(
) from afe
@property
def primaryjoin_minus_local(self) -> ColumnElement[bool]:
    """The primaryjoin with 'local'/'remote' annotations stripped."""
    stripped = _deep_deannotate(self.primaryjoin, values=("local", "remote"))
    return stripped
@property
def secondaryjoin_minus_local(self) -> ColumnElement[bool]:
    """The secondaryjoin with 'local'/'remote' annotations stripped.

    Only valid when a secondaryjoin is present.
    """
    secondaryjoin = self.secondaryjoin
    assert secondaryjoin is not None
    return _deep_deannotate(secondaryjoin, values=("local", "remote"))
@util.memoized_property
- def primaryjoin_reverse_remote(self):
+ def primaryjoin_reverse_remote(self) -> ColumnElement[bool]:
"""Return the primaryjoin condition suitable for the
"reverse" direction.
"""
if self._has_remote_annotations:
- def replace(element):
+ def replace(element: _CE, **kw: Any) -> Optional[_CE]:
if "remote" in element._annotations:
v = dict(element._annotations)
del v["remote"]
v["remote"] = True
return element._with_annotations(v)
+ return None
+
return visitors.replacement_traverse(self.primaryjoin, {}, replace)
else:
if self._has_foreign_annotations:
else:
return _deep_deannotate(self.primaryjoin)
def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool:
    """Return True if any element within *clause* carries *annotation*.

    :param clause: clause element whose sub-elements are scanned.
    :param annotation: annotation key to look for.
    """
    # any() short-circuits on the first hit, exactly like the
    # original explicit for-loop with early return.
    return any(
        annotation in col._annotations
        for col in visitors.iterate(clause, {})
    )
@util.memoized_property
def _has_foreign_annotations(self) -> bool:
    """True if any column in the primaryjoin is annotated 'foreign'."""
    return self._has_annotation(self.primaryjoin, "foreign")
@util.memoized_property
def _has_remote_annotations(self) -> bool:
    """True if any column in the primaryjoin is annotated 'remote'."""
    return self._has_annotation(self.primaryjoin, "remote")
- def _annotate_fks(self):
+ def _annotate_fks(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'foreign' annotations marking columns
considered as foreign.
else:
self._annotate_present_fks()
- def _annotate_from_fk_list(self):
- def check_fk(col):
- if col in self.consider_as_foreign_keys:
- return col._annotate({"foreign": True})
+ def _annotate_from_fk_list(self) -> None:
+ def check_fk(element: _CE, **kw: Any) -> Optional[_CE]:
+ if element in self.consider_as_foreign_keys:
+ return element._annotate({"foreign": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, check_fk
self.secondaryjoin, {}, check_fk
)
- def _annotate_present_fks(self):
+ def _annotate_present_fks(self) -> None:
if self.secondary is not None:
secondarycols = util.column_set(self.secondary.c)
else:
secondarycols = set()
- def is_foreign(a, b):
+ def is_foreign(
+ a: ColumnElement[Any], b: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
if isinstance(a, schema.Column) and isinstance(b, schema.Column):
if a.references(b):
return a
elif b in secondarycols and a not in secondarycols:
return b
- def visit_binary(binary):
+ return None
+
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
if not isinstance(
binary.left, sql.ColumnElement
) or not isinstance(binary.right, sql.ColumnElement):
self.secondaryjoin, {}, {"binary": visit_binary}
)
- def _refers_to_parent_table(self):
+ def _refers_to_parent_table(self) -> bool:
"""Return True if the join condition contains column
comparisons where both columns are in both tables.
"""
pt = self.parent_persist_selectable
mt = self.child_persist_selectable
- result = [False]
+ result = False
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
+ nonlocal result
c, f = binary.left, binary.right
if (
isinstance(c, expression.ColumnClause)
and mt.is_derived_from(c.table)
and mt.is_derived_from(f.table)
):
- result[0] = True
+ result = True
visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary})
- return result[0]
+ return result
def _tables_overlap(self) -> bool:
    """Return True if the parent and child persist selectables share
    any underlying tables."""
    parent = self.parent_persist_selectable
    child = self.child_persist_selectable
    return selectables_overlap(parent, child)
- def _annotate_remote(self):
+ def _annotate_remote(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'remote' annotations marking columns
considered as part of the 'remote' side.
else:
self._annotate_remote_distinct_selectables()
- def _annotate_remote_secondary(self):
+ def _annotate_remote_secondary(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when 'secondary' is present.
"""
- def repl(element):
- if self.secondary.c.contains_column(element):
+ assert self.secondary is not None
+ fixed_secondary = self.secondary
+
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
+ if fixed_secondary.c.contains_column(element):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
+
+ assert self.secondaryjoin is not None
self.secondaryjoin = visitors.replacement_traverse(
self.secondaryjoin, {}, repl
)
- def _annotate_selfref(self, fn, remote_side_given):
+ def _annotate_selfref(
+ self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool
+ ) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the relationship is detected as self-referential.
"""
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
equated = binary.left.compare(binary.right)
if isinstance(binary.left, expression.ColumnClause) and isinstance(
binary.right, expression.ColumnClause
self.primaryjoin, {}, {"binary": visit_binary}
)
- def _annotate_remote_from_args(self):
+ def _annotate_remote_from_args(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the 'remote_side' or '_local_remote_pairs'
arguments are used.
self._annotate_selfref(lambda col: col in remote_side, True)
else:
- def repl(element):
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
# use set() to avoid generating ``__eq__()`` expressions
# against each element
if element in set(remote_side):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
- def _annotate_remote_with_overlap(self):
+ def _annotate_remote_with_overlap(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the parent/child tables have some set of
tables in common, though is not a fully self-referential
"""
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
binary.left, binary.right = proc_left_right(
binary.left, binary.right
)
self.prop is not None and self.prop.mapper is not self.prop.parent
)
- def proc_left_right(left, right):
+ def proc_left_right(
+ left: ColumnElement[Any], right: ColumnElement[Any]
+ ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]:
if isinstance(left, expression.ColumnClause) and isinstance(
right, expression.ColumnClause
):
self.primaryjoin, {}, {"binary": visit_binary}
)
- def _annotate_remote_distinct_selectables(self):
+ def _annotate_remote_distinct_selectables(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the parent/child tables are entirely
separate.
"""
- def repl(element):
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
if self.child_persist_selectable.c.contains_column(element) and (
not self.parent_local_selectable.c.contains_column(element)
or self.child_local_selectable.c.contains_column(element)
):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
def _warn_non_column_elements(self) -> None:
    """Warn that the primaryjoin contains expressions that are not
    simple Column elements, so 'remote' sides cannot be inferred."""
    message = (
        "Non-simple column elements in primary "
        "join condition for property %s - consider using "
        "remote() annotations to mark the remote side." % self.prop
    )
    util.warn(message)
- def _annotate_local(self):
+ def _annotate_local(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'local' annotations.
else:
local_side = util.column_set(self.parent_persist_selectable.c)
- def locals_(elem):
- if "remote" not in elem._annotations and elem in local_side:
- return elem._annotate({"local": True})
+ def locals_(element: _CE, **kw: Any) -> Optional[_CE]:
+ if "remote" not in element._annotations and element in local_side:
+ return element._annotate({"local": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, locals_
)
- def _annotate_parentmapper(self):
+ def _annotate_parentmapper(self) -> None:
if self.prop is None:
return
- def parentmappers_(elem):
- if "remote" in elem._annotations:
- return elem._annotate({"parentmapper": self.prop.mapper})
- elif "local" in elem._annotations:
- return elem._annotate({"parentmapper": self.prop.parent})
+ def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]:
+ if "remote" in element._annotations:
+ return element._annotate({"parentmapper": self.prop.mapper})
+ elif "local" in element._annotations:
+ return element._annotate({"parentmapper": self.prop.parent})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, parentmappers_
)
- def _check_remote_side(self):
+ def _check_remote_side(self) -> None:
if not self.local_remote_pairs:
raise sa_exc.ArgumentError(
"Relationship %s could "
"the relationship." % (self.prop,)
)
- def _check_foreign_cols(self, join_condition, primary):
+ def _check_foreign_cols(
+ self, join_condition: ColumnElement[bool], primary: bool
+ ) -> None:
"""Check the foreign key columns collected and emit error
messages."""
)
raise sa_exc.ArgumentError(err)
- def _determine_direction(self):
+ def _determine_direction(self) -> None:
"""Determine if this relationship is one to many, many to one,
many to many.
"nor the child's mapped tables" % self.prop
)
- def _deannotate_pairs(self, collection):
+ def _deannotate_pairs(
+ self, collection: _ColumnPairIterable
+ ) -> _MutableColumnPairs:
"""provide deannotation for the various lists of
pairs, so that using them in hashes doesn't incur
high-overhead __eq__() comparisons against
"""
return [(x._deannotate(), y._deannotate()) for x, y in collection]
- def _setup_pairs(self):
- sync_pairs = []
- lrp = util.OrderedSet([])
- secondary_sync_pairs = []
-
- def go(joincond, collection):
- def visit_binary(binary, left, right):
+ def _setup_pairs(self) -> None:
+ sync_pairs: _MutableColumnPairs = []
+ lrp: util.OrderedSet[
+ Tuple[ColumnElement[Any], ColumnElement[Any]]
+ ] = util.OrderedSet([])
+ secondary_sync_pairs: _MutableColumnPairs = []
+
+ def go(
+ joincond: ColumnElement[bool],
+ collection: _MutableColumnPairs,
+ ) -> None:
+ def visit_binary(
+ binary: BinaryExpression[Any],
+ left: ColumnElement[Any],
+ right: ColumnElement[Any],
+ ) -> None:
if (
"remote" in right._annotations
and "remote" not in left._annotations
secondary_sync_pairs
)
- _track_overlapping_sync_targets = weakref.WeakKeyDictionary()
+ _track_overlapping_sync_targets: weakref.WeakKeyDictionary[
+ ColumnElement[Any],
+ weakref.WeakKeyDictionary[Relationship[Any], ColumnElement[Any]],
+ ] = weakref.WeakKeyDictionary()
- def _warn_for_conflicting_sync_targets(self):
+ def _warn_for_conflicting_sync_targets(self) -> None:
if not self.support_sync:
return
self._track_overlapping_sync_targets[to_][self.prop] = from_
@util.memoized_property
- def remote_columns(self):
+ def remote_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("remote")
@util.memoized_property
- def local_columns(self):
+ def local_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("local")
@util.memoized_property
- def foreign_key_columns(self):
+ def foreign_key_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("foreign")
- def _gather_join_annotations(self, annotation):
+ def _gather_join_annotations(
+ self, annotation: str
+ ) -> Set[ColumnElement[Any]]:
s = set(
self._gather_columns_with_annotation(self.primaryjoin, annotation)
)
)
return {x._deannotate() for x in s}
def _gather_columns_with_annotation(
    self, clause: ColumnElement[Any], *annotation: Iterable[str]
) -> Set[ColumnElement[Any]]:
    """Return the set of columns within *clause* that carry all of the
    given annotation keys.

    :param clause: column expression to scan.
    :param annotation: annotation keys, all of which must be present.
    """
    annotation_set = set(annotation)
    # set comprehension directly; the intermediate list the original
    # built inside set([...]) was redundant (flake8-comprehensions C403)
    return {
        cast(ColumnElement[Any], col)
        for col in visitors.iterate(clause, {})
        if annotation_set.issubset(col._annotations)
    }
def join_targets(
self,
- source_selectable,
- dest_selectable,
- aliased,
- single_crit=None,
- extra_criteria=(),
- ):
+ source_selectable: Optional[FromClause],
+ dest_selectable: FromClause,
+ aliased: bool,
+ single_crit: Optional[ColumnElement[bool]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+ ) -> Tuple[
+ ColumnElement[bool],
+ Optional[ColumnElement[bool]],
+ Optional[FromClause],
+ Optional[ClauseAdapter],
+ FromClause,
+ ]:
"""Given a source and destination selectable, create a
join between them.
dest_selectable,
)
- def create_lazy_clause(self, reverse_direction=False):
- binds = util.column_dict()
- equated_columns = util.column_dict()
+ def create_lazy_clause(
+ self, reverse_direction: bool = False
+ ) -> Tuple[
+ ColumnElement[bool],
+ Dict[str, ColumnElement[Any]],
+ Dict[ColumnElement[Any], ColumnElement[Any]],
+ ]:
+ binds: Dict[ColumnElement[Any], BindParameter[Any]] = {}
+ equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {}
has_secondary = self.secondaryjoin is not None
for l, r in self.local_remote_pairs:
equated_columns[l] = r
- def col_to_bind(col):
+ def col_to_bind(
+ element: ColumnElement[Any], **kw: Any
+ ) -> Optional[BindParameter[Any]]:
if (
- (not reverse_direction and "local" in col._annotations)
+ (not reverse_direction and "local" in element._annotations)
or reverse_direction
and (
- (has_secondary and col in lookup)
- or (not has_secondary and "remote" in col._annotations)
+ (has_secondary and element in lookup)
+ or (not has_secondary and "remote" in element._annotations)
)
):
- if col not in binds:
- binds[col] = sql.bindparam(
- None, None, type_=col.type, unique=True
+ if element not in binds:
+ binds[element] = sql.bindparam(
+ None, None, type_=element.type, unique=True
)
- return binds[col]
+ return binds[element]
return None
lazywhere = self.primaryjoin
__slots__ = ("name",)
def __init__(self, name: str):
    """Store the annotation key this predicate tests for."""
    self.name = name
def __call__(self, c: ClauseElement) -> bool:
    """Return True if clause element *c* carries the stored key."""
    annotations = c._annotations
    return self.name in annotations
from ..sql._typing import _T7
from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
+ from ..sql.base import ExecutableOption
from ..sql.elements import ClauseElement
from ..sql.roles import TypedColumnsClauseRole
from ..sql.selectable import TypedReturnsRows
self.session.dispatch.after_transaction_create(self.session, self)
def _raise_for_prerequisite_state(
- self, operation_name: str, state: SessionTransactionState
+ self, operation_name: str, state: _StateChangeState
) -> NoReturn:
if state is SessionTransactionState.DEACTIVE:
if self._rollback_exception:
primary_key_identity: _PKIdentityArgument,
db_load_fn: Callable[..., _O],
*,
- options: Optional[Sequence[ORMOption]] = None,
+ options: Optional[Sequence[ExecutableOption]] = None,
populate_existing: bool = False,
with_for_update: Optional[ForUpdateArg] = None,
identity_token: Optional[Any] = None,
*,
options: Optional[Sequence[ORMOption]] = None,
load: bool,
- _recursive: Dict[InstanceState[Any], object],
+ _recursive: Dict[Any, object],
_resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
) -> _O:
mapper: Mapper[_O] = _state_mapper(state)
...
class _InstallLoaderCallableProto(Protocol[_O]):
    """Callback protocol used at result loading time.

    Implementations install a _LoaderCallable upon a specific
    InstanceState, to be used to populate an attribute when that
    attribute is first accessed.

    Concrete examples are per-instance deferred column loaders and
    relationship lazy loaders.
    """

    def __call__(
        self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
    ) -> None:
        ...
+
+
@inspection._self_inspects
class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
"""tracks state information at the instance level.
@classmethod
def _instance_level_callable_processor(
cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
- ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]:
+ ) -> _InstallLoaderCallableProto[_O]:
impl = manager[key].impl
if is_collection_impl(impl):
fixed_impl = impl
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""State tracking utilities used by :class:`_orm.Session`.
"""
from enum import Enum
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Iterator
+from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TypeVar
_next_state: _StateChangeState = _StateChangeStates.ANY
_state: _StateChangeState = _StateChangeStates.NO_CHANGE
- _current_fn: Optional[Callable] = None
+ _current_fn: Optional[Callable[..., Any]] = None
- def _raise_for_prerequisite_state(self, operation_name, state):
+ def _raise_for_prerequisite_state(
+ self, operation_name: str, state: _StateChangeState
+ ) -> NoReturn:
raise sa_exc.IllegalStateChangeError(
f"Can't run operation '{operation_name}()' when Session "
f"is in state {state!r}"
prerequisite_states is not _StateChangeStates.ANY
)
+ prerequisite_state_collection = cast(
+ "Tuple[_StateChangeState, ...]", prerequisite_states
+ )
expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE
@util.decorator
- def _go(fn, self, *arg, **kw):
+ def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
current_state = self._state
if (
has_prerequisite_states
- and current_state not in prerequisite_states
+ and current_state not in prerequisite_state_collection
):
self._raise_for_prerequisite_state(fn.__name__, current_state)
return _go
@contextlib.contextmanager
- def _expect_state(self, expected: _StateChangeState):
+ def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
"""called within a method that changes states.
method must also use the ``@declare_states()`` decorator.
import collections
import itertools
+from typing import Any
+from typing import Dict
+from typing import Tuple
+from typing import TYPE_CHECKING
from . import attributes
from . import exc as orm_exc
from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
+from .base import LoaderCallableStatus
from .base import PASSIVE_OFF
+from .base import PassiveFlag
from .context import _column_descriptions
from .context import ORMCompileState
from .context import ORMSelectCompileState
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import Select
+if TYPE_CHECKING:
+ from .relationships import Relationship
+ from ..sql.elements import ColumnElement
+
def _register_attribute(
prop,
def _load_for_state(self, state, passive):
if not state.key:
- return attributes.ATTR_EMPTY
+ return LoaderCallableStatus.ATTR_EMPTY
- if not passive & attributes.SQL_OK:
- return attributes.PASSIVE_NO_RESULT
+ if not passive & PassiveFlag.SQL_OK:
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
localparent = state.manager.mapper
state.mapper, state, set(group), PASSIVE_OFF
)
- return attributes.ATTR_WAS_SET
+ return LoaderCallableStatus.ATTR_WAS_SET
def _invoke_raise_load(self, state, passive, lazy):
raise sa_exc.InvalidRequestError(
@relationships.Relationship.strategy_for(lazy="raise")
@relationships.Relationship.strategy_for(lazy="raise_on_sql")
@relationships.Relationship.strategy_for(lazy="baked_select")
-class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
+class LazyLoader(
+ AbstractRelationshipLoader, util.MemoizedSlots, log.Identified
+):
"""Provide loading behavior for a :class:`.Relationship`
with "lazy=True", that is loads when first accessed.
"_raise_on_sql",
)
- def __init__(self, parent, strategy_key):
+ _lazywhere: ColumnElement[bool]
+ _bind_to_col: Dict[str, ColumnElement[Any]]
+ _rev_lazywhere: ColumnElement[bool]
+ _rev_bind_to_col: Dict[str, ColumnElement[Any]]
+
+ parent_property: Relationship[Any]
+
+ def __init__(
+ self, parent: Relationship[Any], strategy_key: Tuple[Any, ...]
+ ):
super(LazyLoader, self).__init__(parent, strategy_key)
self._raise_always = self.strategy_opts["lazy"] == "raise"
self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql"
o = state.obj() # strong ref
dict_ = attributes.instance_dict(o)
- if passive & attributes.INIT_OK:
- passive ^= attributes.INIT_OK
+ if passive & PassiveFlag.INIT_OK:
+ passive ^= PassiveFlag.INIT_OK
params = {}
for key, ident, value in param_keys:
if ident is not None:
- if passive and passive & attributes.LOAD_AGAINST_COMMITTED:
+ if passive and passive & PassiveFlag.LOAD_AGAINST_COMMITTED:
value = mapper._get_committed_state_attr_by_column(
state, dict_, ident, passive
)
)
or not state.session_id
):
- return attributes.ATTR_EMPTY
+ return LoaderCallableStatus.ATTR_EMPTY
pending = not state.key
primary_key_identity = None
use_get = self.use_get and (not loadopt or not loadopt._extra_criteria)
- if (not passive & attributes.SQL_OK and not use_get) or (
+ if (not passive & PassiveFlag.SQL_OK and not use_get) or (
not passive & attributes.NON_PERSISTENT_OK and pending
):
- return attributes.PASSIVE_NO_RESULT
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
if (
# we were given lazy="raise"
self._raise_always
# the no_raise history-related flag was not passed
- and not passive & attributes.NO_RAISE
+ and not passive & PassiveFlag.NO_RAISE
and (
# if we are use_get and related_object_ok is disabled,
# which means we are at most looking in the identity map
# PASSIVE_NO_RESULT, don't raise. This is also a
# history-related flag
not use_get
- or passive & attributes.RELATED_OBJECT_OK
+ or passive & PassiveFlag.RELATED_OBJECT_OK
)
):
session = _state_session(state)
if not session:
- if passive & attributes.NO_RAISE:
- return attributes.PASSIVE_NO_RESULT
+ if passive & PassiveFlag.NO_RAISE:
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
raise orm_exc.DetachedInstanceError(
"Parent instance %s is not bound to a Session; "
primary_key_identity = self._get_ident_for_use_get(
session, state, passive
)
- if attributes.PASSIVE_NO_RESULT in primary_key_identity:
- return attributes.PASSIVE_NO_RESULT
- elif attributes.NEVER_SET in primary_key_identity:
- return attributes.NEVER_SET
+ if LoaderCallableStatus.PASSIVE_NO_RESULT in primary_key_identity:
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
+ elif LoaderCallableStatus.NEVER_SET in primary_key_identity:
+ return LoaderCallableStatus.NEVER_SET
if _none_set.issuperset(primary_key_identity):
return None
if (
self.key in state.dict
- and not passive & attributes.DEFERRED_HISTORY_LOAD
+ and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD
):
- return attributes.ATTR_WAS_SET
+ return LoaderCallableStatus.ATTR_WAS_SET
# look for this identity in the identity map. Delegate to the
# Query class in use, as it may have special rules for how it
)
if instance is not None:
- if instance is attributes.PASSIVE_CLASS_MISMATCH:
+ if instance is LoaderCallableStatus.PASSIVE_CLASS_MISMATCH:
return None
else:
return instance
elif (
- not passive & attributes.SQL_OK
- or not passive & attributes.RELATED_OBJECT_OK
+ not passive & PassiveFlag.SQL_OK
+ or not passive & PassiveFlag.RELATED_OBJECT_OK
):
- return attributes.PASSIVE_NO_RESULT
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
return self._emit_lazyload(
session,
def _get_ident_for_use_get(self, session, state, passive):
instance_mapper = state.manager.mapper
- if passive & attributes.LOAD_AGAINST_COMMITTED:
+ if passive & PassiveFlag.LOAD_AGAINST_COMMITTED:
get_attr = instance_mapper._get_committed_state_attr_by_column
else:
get_attr = instance_mapper._get_state_attr_by_column
stmt._compile_options += {"_current_path": effective_path}
if use_get:
- if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE:
self._invoke_raise_load(state, passive, "raise_on_sql")
return loading.load_on_pk_identity(
if (
self.key in state.dict
- and not passive & attributes.DEFERRED_HISTORY_LOAD
+ and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD
):
- return attributes.ATTR_WAS_SET
+ return LoaderCallableStatus.ATTR_WAS_SET
if pending:
if util.has_intersection(orm_util._none_set, params.values()):
elif util.has_intersection(orm_util._never_set, params.values()):
return None
- if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE:
self._invoke_raise_load(state, passive, "raise_on_sql")
stmt._where_criteria = (lazy_clause,)
# "use get" load. the "_RELATED" part means it may return
# instance even if its expired, since this is a mutually-recursive
# load operation.
- flags = attributes.PASSIVE_NO_FETCH_RELATED | attributes.NO_RAISE
+ flags = attributes.PASSIVE_NO_FETCH_RELATED | PassiveFlag.NO_RAISE
else:
- flags = attributes.PASSIVE_OFF | attributes.NO_RAISE
+ flags = attributes.PASSIVE_OFF | PassiveFlag.NO_RAISE
populators["delayed"].append((self.key, load_immediate))
# if the loaded parent objects do not have the foreign key
# to the related item loaded, then degrade into the joined
# version of selectinload
- if attributes.PASSIVE_NO_RESULT in related_ident:
+ if LoaderCallableStatus.PASSIVE_NO_RESULT in related_ident:
query_info = self._fallback_query_info
break
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: allow-untyped-defs, allow-untyped-calls
"""
import typing
from typing import Any
+from typing import Callable
from typing import cast
-from typing import Mapping
-from typing import NoReturn
+from typing import Dict
+from typing import Iterable
from typing import Optional
+from typing import overload
+from typing import Sequence
from typing import Tuple
+from typing import Type
+from typing import TypeVar
from typing import Union
from . import util as orm_util
+from ._typing import insp_is_aliased_class
+from ._typing import insp_is_attribute
+from ._typing import insp_is_mapper
+from ._typing import insp_is_mapper_property
+from .attributes import QueryableAttribute
from .base import InspectionAttr
from .interfaces import LoaderOption
from .path_registry import _DEFAULT_TOKEN
from .path_registry import _WILDCARD_TOKEN
+from .path_registry import AbstractEntityRegistry
+from .path_registry import path_is_property
from .path_registry import PathRegistry
from .path_registry import TokenRegistry
from .util import _orm_full_deannotate
from ..sql import traversals
from ..sql import visitors
from ..sql.base import _generative
+from ..util.typing import Final
+from ..util.typing import Literal
-_RELATIONSHIP_TOKEN = "relationship"
-_COLUMN_TOKEN = "column"
+_RELATIONSHIP_TOKEN: Final[Literal["relationship"]] = "relationship"
+_COLUMN_TOKEN: Final[Literal["column"]] = "column"
+
+_FN = TypeVar("_FN", bound="Callable[..., Any]")
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
+ from ._typing import _InternalEntityType
+ from .context import _MapperEntity
+ from .context import ORMCompileState
+ from .context import QueryContext
+ from .interfaces import _StrategyKey
+ from .interfaces import MapperProperty
from .mapper import Mapper
+ from .path_registry import _PathRepresentation
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _FromClauseArgument
+ from ..sql.cache_key import _CacheKeyTraversalType
+ from ..sql.cache_key import CacheKey
+
+Self_AbstractLoad = TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
+
+_AttrType = Union[str, "QueryableAttribute[Any]"]
-Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
+_WildcardKeyType = Literal["relationship", "column"]
+_StrategySpec = Dict[str, Any]
+_OptsType = Dict[str, Any]
+_AttrGroupType = Tuple[_AttrType, ...]
class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
_is_strategy_option = True
propagate_to_loaders: bool
- def contains_eager(self, attr, alias=None, _is_chain=False):
+ def contains_eager(
+ self: Self_AbstractLoad,
+ attr: _AttrType,
+ alias: Optional[_FromClauseArgument] = None,
+ _is_chain: bool = False,
+ ) -> Self_AbstractLoad:
r"""Indicate that the given attribute should be eagerly loaded from
columns stated manually in the query.
"""
if alias is not None:
if not isinstance(alias, str):
- info = inspect(alias)
- alias = info.selectable
-
+ coerced_alias = coercions.expect(roles.FromClauseRole, alias)
else:
util.warn_deprecated(
"Passing a string name for the 'alias' argument to "
"sqlalchemy.orm.aliased() construct.",
version="1.4",
)
+ coerced_alias = alias
elif getattr(attr, "_of_type", None):
- ot = inspect(attr._of_type)
- alias = ot.selectable
+ assert isinstance(attr, QueryableAttribute)
+ ot: Optional[_InternalEntityType[Any]] = inspect(attr._of_type)
+ assert ot is not None
+ coerced_alias = ot.selectable
+ else:
+ coerced_alias = None
cloned = self._set_relationship_strategy(
attr,
{"lazy": "joined"},
propagate_to_loaders=False,
- opts={"eager_from_alias": alias},
+ opts={"eager_from_alias": coerced_alias},
_reconcile_to_other=True if _is_chain else None,
)
return cloned
- def load_only(self, *attrs):
+ def load_only(
+ self: Self_AbstractLoad, *attrs: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that for a particular entity, only the given list
of column-based attribute names should be loaded; all others will be
deferred.
{"deferred": False, "instrument": True},
)
cloned = cloned._set_column_strategy(
- "*", {"deferred": True, "instrument": True}, {"undefer_pks": True}
+ ("*",),
+ {"deferred": True, "instrument": True},
+ {"undefer_pks": True},
)
return cloned
- def joinedload(self, attr, innerjoin=None):
+ def joinedload(
+ self: Self_AbstractLoad,
+ attr: _AttrType,
+ innerjoin: Optional[bool] = None,
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using joined
eager loading.
)
return loader
- def subqueryload(self, attr):
+ def subqueryload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using
subquery eager loading.
"""
return self._set_relationship_strategy(attr, {"lazy": "subquery"})
- def selectinload(self, attr):
+ def selectinload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using
SELECT IN eager loading.
"""
return self._set_relationship_strategy(attr, {"lazy": "selectin"})
- def lazyload(self, attr):
+ def lazyload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using "lazy"
loading.
"""
return self._set_relationship_strategy(attr, {"lazy": "select"})
- def immediateload(self, attr):
+ def immediateload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using
an immediate load with a per-attribute SELECT statement.
loader = self._set_relationship_strategy(attr, {"lazy": "immediate"})
return loader
- def noload(self, attr):
+ def noload(self: Self_AbstractLoad, attr: _AttrType) -> Self_AbstractLoad:
"""Indicate that the given relationship attribute should remain
unloaded.
return self._set_relationship_strategy(attr, {"lazy": "noload"})
- def raiseload(self, attr, sql_only=False):
+ def raiseload(
+ self: Self_AbstractLoad, attr: _AttrType, sql_only: bool = False
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should raise an error if accessed.
A relationship attribute configured with :func:`_orm.raiseload` will
attr, {"lazy": "raise_on_sql" if sql_only else "raise"}
)
- def defaultload(self, attr):
+ def defaultload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate an attribute should load using its default loader style.
This method is used to link to other loader options further into
"""
return self._set_relationship_strategy(attr, None)
- def defer(self, key, raiseload=False):
+ def defer(
+ self: Self_AbstractLoad, key: _AttrType, raiseload: bool = False
+ ) -> Self_AbstractLoad:
r"""Indicate that the given column-oriented attribute should be
deferred, e.g. not loaded until accessed.
strategy["raiseload"] = True
return self._set_column_strategy((key,), strategy)
- def undefer(self, key):
+ def undefer(self: Self_AbstractLoad, key: _AttrType) -> Self_AbstractLoad:
r"""Indicate that the given column-oriented attribute should be
undeferred, e.g. specified within the SELECT statement of the entity
as a whole.
Examples::
# undefer two columns
- session.query(MyClass).options(undefer("col1"), undefer("col2"))
+ session.query(MyClass).options(
+ undefer(MyClass.col1), undefer(MyClass.col2)
+ )
# undefer all columns specific to a single class using Load + *
session.query(MyClass, MyOtherClass).options(
# undefer a column on a related object
session.query(MyClass).options(
- defaultload(MyClass.items).undefer('text'))
+ defaultload(MyClass.items).undefer(MyClass.text))
:param key: Attribute to be undeferred.
(key,), {"deferred": False, "instrument": True}
)
- def undefer_group(self, name):
+ def undefer_group(self: Self_AbstractLoad, name: str) -> Self_AbstractLoad:
"""Indicate that columns within the given deferred group name should be
undeferred.
"""
return self._set_column_strategy(
- _WILDCARD_TOKEN, None, {f"undefer_group_{name}": True}
+ (_WILDCARD_TOKEN,), None, {f"undefer_group_{name}": True}
)
- def with_expression(self, key, expression):
+ def with_expression(
+ self: Self_AbstractLoad,
+ key: _AttrType,
+ expression: _ColumnExpressionArgument[Any],
+ ) -> Self_AbstractLoad:
r"""Apply an ad-hoc SQL expression to a "deferred expression"
attribute.
"""
- expression = coercions.expect(
- roles.LabeledColumnExprRole, _orm_full_deannotate(expression)
+ expression = _orm_full_deannotate(
+ coercions.expect(roles.LabeledColumnExprRole, expression)
)
return self._set_column_strategy(
(key,), {"query_expression": True}, opts={"expression": expression}
)
- def selectin_polymorphic(self, classes):
+ def selectin_polymorphic(
+ self: Self_AbstractLoad, classes: Iterable[Type[Any]]
+ ) -> Self_AbstractLoad:
"""Indicate an eager load should take place for all attributes
specific to a subclass.
)
return self
- def _coerce_strat(self, strategy):
+ @overload
+ def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey:
+ ...
+
+ @overload
+ def _coerce_strat(self, strategy: Literal[None]) -> None:
+ ...
+
+ def _coerce_strat(
+ self, strategy: Optional[_StrategySpec]
+ ) -> Optional[_StrategyKey]:
if strategy is not None:
- strategy = tuple(sorted(strategy.items()))
- return strategy
+ strategy_key = tuple(sorted(strategy.items()))
+ else:
+ strategy_key = None
+ return strategy_key
@_generative
def _set_relationship_strategy(
self: Self_AbstractLoad,
- attr,
- strategy,
- propagate_to_loaders=True,
- opts=None,
- _reconcile_to_other=None,
+ attr: _AttrType,
+ strategy: Optional[_StrategySpec],
+ propagate_to_loaders: bool = True,
+ opts: Optional[_OptsType] = None,
+ _reconcile_to_other: Optional[bool] = None,
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
self._clone_for_bind_strategy(
(attr,),
- strategy,
+ strategy_key,
_RELATIONSHIP_TOKEN,
opts=opts,
propagate_to_loaders=propagate_to_loaders,
@_generative
def _set_column_strategy(
- self: Self_AbstractLoad, attrs, strategy, opts=None
+ self: Self_AbstractLoad,
+ attrs: Tuple[_AttrType, ...],
+ strategy: Optional[_StrategySpec],
+ opts: Optional[_OptsType] = None,
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
self._clone_for_bind_strategy(
attrs,
- strategy,
+ strategy_key,
_COLUMN_TOKEN,
opts=opts,
attr_group=attrs,
@_generative
def _set_generic_strategy(
- self: Self_AbstractLoad, attrs, strategy, _reconcile_to_other=None
+ self: Self_AbstractLoad,
+ attrs: Tuple[_AttrType, ...],
+ strategy: _StrategySpec,
+ _reconcile_to_other: Optional[bool] = None,
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
self._clone_for_bind_strategy(
attrs,
- strategy,
+ strategy_key,
None,
propagate_to_loaders=True,
reconcile_to_other=_reconcile_to_other,
@_generative
def _set_class_strategy(
- self: Self_AbstractLoad, strategy, opts
+ self: Self_AbstractLoad, strategy: _StrategySpec, opts: _OptsType
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
- self._clone_for_bind_strategy(None, strategy, None, opts=opts)
+ self._clone_for_bind_strategy(None, strategy_key, None, opts=opts)
return self
- def _apply_to_parent(self, parent):
+ def _apply_to_parent(self, parent: Load) -> None:
        """apply this :class:`_orm._AbstractLoad` object as a sub-option of
a :class:`_orm.Load` object.
"""
raise NotImplementedError()
- def options(self: Self_AbstractLoad, *opts) -> NoReturn:
+ def options(
+ self: Self_AbstractLoad, *opts: _AbstractLoad
+ ) -> Self_AbstractLoad:
r"""Apply a series of options as sub-options to this
:class:`_orm._AbstractLoad` object.
raise NotImplementedError()
def _clone_for_bind_strategy(
- self,
- attrs,
- strategy,
- wildcard_key,
- opts=None,
- attr_group=None,
- propagate_to_loaders=True,
- reconcile_to_other=None,
- ):
+ self: Self_AbstractLoad,
+ attrs: Optional[Tuple[_AttrType, ...]],
+ strategy: Optional[_StrategyKey],
+ wildcard_key: Optional[_WildcardKeyType],
+ opts: Optional[_OptsType] = None,
+ attr_group: Optional[_AttrGroupType] = None,
+ propagate_to_loaders: bool = True,
+ reconcile_to_other: Optional[bool] = None,
+ ) -> Self_AbstractLoad:
raise NotImplementedError()
def process_compile_state_replaced_entities(
- self, compile_state, mapper_entities
- ):
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ ) -> None:
if not compile_state.compile_options._enable_eagerloads:
return
not bool(compile_state.current_path),
)
- def process_compile_state(self, compile_state):
+ def process_compile_state(self, compile_state: ORMCompileState) -> None:
if not compile_state.compile_options._enable_eagerloads:
return
and not compile_state.compile_options._for_refresh_state,
)
- def _process(self, compile_state, mapper_entities, raiseerr):
+ def _process(
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ raiseerr: bool,
+ ) -> None:
"""implemented by subclasses"""
raise NotImplementedError()
@classmethod
- def _chop_path(cls, to_chop, path, debug=False):
+ def _chop_path(
+ cls,
+ to_chop: _PathRepresentation,
+ path: PathRegistry,
+ debug: bool = False,
+ ) -> Optional[_PathRepresentation]:
i = -1
for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)):
return to_chop
elif (
c_token != f"{_RELATIONSHIP_TOKEN}:{_WILDCARD_TOKEN}"
- and c_token != p_token.key
+ and c_token != p_token.key # type: ignore
):
return None
continue
elif (
isinstance(c_token, InspectionAttr)
- and c_token.is_mapper
+ and insp_is_mapper(c_token)
and (
- (p_token.is_mapper and c_token.isa(p_token))
+ (insp_is_mapper(p_token) and c_token.isa(p_token))
or (
# a too-liberal check here to allow a path like
# A->A.bs->B->B.cs->C->C.ds, natural path, to chop
# test_of_type.py->test_all_subq_query
#
i >= 2
- and p_token.is_aliased_class
+ and insp_is_aliased_class(p_token)
and p_token._is_with_polymorphic
and c_token in p_token.with_polymorphic_mappers
- # and (breakpoint() or True)
)
)
):
return to_chop[i + 1 :]
-SelfLoad = typing.TypeVar("SelfLoad", bound="Load")
+SelfLoad = TypeVar("SelfLoad", bound="Load")
class Load(_AbstractLoad):
_cache_key_traversal = None
path: PathRegistry
- context: Tuple["_LoadElement", ...]
+ context: Tuple[_LoadElement, ...]
- def __init__(self, entity):
- insp = cast(Union["Mapper", AliasedInsp], inspect(entity))
+ def __init__(self, entity: _EntityType[Any]):
+ insp = cast("Union[Mapper[Any], AliasedInsp[Any]]", inspect(entity))
insp._post_inspect
self.path = insp._path_registry
self.context = ()
self.propagate_to_loaders = False
- def __str__(self):
+ def __str__(self) -> str:
return f"Load({self.path[0]})"
@classmethod
- def _construct_for_existing_path(cls, path):
+ def _construct_for_existing_path(cls, path: PathRegistry) -> Load:
load = cls.__new__(cls)
load.path = path
load.context = ()
load.propagate_to_loaders = False
return load
- def _adjust_for_extra_criteria(self, context):
+ def _adjust_for_extra_criteria(self, context: QueryContext) -> Load:
"""Apply the current bound parameters in a QueryContext to all
occurrences "extra_criteria" stored within this ``Load`` object,
returning a new instance of this ``Load`` object.
"""
orig_query = context.compile_state.select_statement
- orig_cache_key = None
- replacement_cache_key = None
+ orig_cache_key: Optional[CacheKey] = None
+ replacement_cache_key: Optional[CacheKey] = None
- def process(opt):
+ def process(opt: _LoadElement) -> _LoadElement:
if not opt._extra_criteria:
return opt
orig_cache_key = orig_query._generate_cache_key()
replacement_cache_key = context.query._generate_cache_key()
+ assert orig_cache_key is not None
+ assert replacement_cache_key is not None
+
opt._extra_criteria = tuple(
replacement_cache_key._apply_params_to_element(
orig_cache_key, crit
ezero = None
for ent in mapper_entities:
ezero = ent.entity_zero
- if ezero and orm_util._entity_corresponds_to(ezero, path[0]):
+ if ezero and orm_util._entity_corresponds_to(
+ # technically this can be a token also, but this is
+ # safe to pass to _entity_corresponds_to()
+ ezero,
+ cast("_InternalEntityType[Any]", path[0]),
+ ):
return ezero
return None
- def _process(self, compile_state, mapper_entities, raiseerr):
+ def _process(
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ raiseerr: bool,
+ ) -> None:
reconciled_lead_entity = self._reconcile_query_entities_with_us(
mapper_entities, raiseerr
raiseerr,
)
- def _apply_to_parent(self, parent):
+ def _apply_to_parent(self, parent: Load) -> None:
"""apply this :class:`_orm.Load` object as a sub-option of another
:class:`_orm.Load` object.
assert cloned.propagate_to_loaders == self.propagate_to_loaders
if not orm_util._entity_corresponds_to_use_path_impl(
- parent.path[-1], cloned.path[0]
+ cast("_InternalEntityType[Any]", parent.path[-1]),
+ cast("_InternalEntityType[Any]", cloned.path[0]),
):
raise sa_exc.ArgumentError(
f'Attribute "{cloned.path[1]}" does not link '
parent.context += cloned.context
@_generative
- def options(self: SelfLoad, *opts) -> SelfLoad:
+ def options(self: SelfLoad, *opts: _AbstractLoad) -> SelfLoad:
r"""Apply a series of options as sub-options to this
:class:`_orm.Load`
object.
return self
def _clone_for_bind_strategy(
- self,
- attrs,
- strategy,
- wildcard_key,
- opts=None,
- attr_group=None,
- propagate_to_loaders=True,
- reconcile_to_other=None,
- ) -> None:
+ self: SelfLoad,
+ attrs: Optional[Tuple[_AttrType, ...]],
+ strategy: Optional[_StrategyKey],
+ wildcard_key: Optional[_WildcardKeyType],
+ opts: Optional[_OptsType] = None,
+ attr_group: Optional[_AttrGroupType] = None,
+ propagate_to_loaders: bool = True,
+ reconcile_to_other: Optional[bool] = None,
+ ) -> SelfLoad:
# for individual strategy that needs to propagate, set the whole
# Load container to also propagate, so that it shows up in
# InstanceState.load_options
if propagate_to_loaders:
self.propagate_to_loaders = True
- if not self.path.has_entity:
- if self.path.is_token:
+ if self.path.is_token:
+ raise sa_exc.ArgumentError(
+ "Wildcard token cannot be followed by another entity"
+ )
+
+ elif path_is_property(self.path):
+ # re-use the lookup which will raise a nicely formatted
+ # LoaderStrategyException
+ if strategy:
+ self.path.prop._strategy_lookup(self.path.prop, strategy[0])
+ else:
raise sa_exc.ArgumentError(
- "Wildcard token cannot be followed by another entity"
+ f"Mapped attribute '{self.path.prop}' does not "
+ "refer to a mapped entity"
)
- else:
- # re-use the lookup which will raise a nicely formatted
- # LoaderStrategyException
- if strategy:
- self.path.prop._strategy_lookup(
- self.path.prop, strategy[0]
- )
- else:
- raise sa_exc.ArgumentError(
- f"Mapped attribute '{self.path.prop}' does not "
- "refer to a mapped entity"
- )
if attrs is None:
load_element = _ClassStrategyLoad.create(
if wildcard_key is _RELATIONSHIP_TOKEN:
self.path = load_element.path
self.context += (load_element,)
+ return self
def __getstate__(self):
d = self._shallow_to_dict()
self._shallow_from_dict(state)
-SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
+SelfWildcardLoad = TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
class _WildcardLoad(_AbstractLoad):
visitors.ExtendedInternalTraversal.dp_string_multi_dict,
),
]
- cache_key_traversal = None
+ cache_key_traversal: _CacheKeyTraversalType = None
strategy: Optional[Tuple[Any, ...]]
- local_opts: Mapping[str, Any]
+ local_opts: _OptsType
path: Tuple[str, ...]
propagate_to_loaders = False
- def __init__(self):
+ def __init__(self) -> None:
self.path = ()
self.strategy = None
self.local_opts = util.EMPTY_DICT
propagate_to_loaders=True,
reconcile_to_other=None,
):
+ assert attrs is not None
attr = attrs[0]
assert (
wildcard_key
if opts:
self.local_opts = util.immutabledict(opts)
- def options(self: SelfWildcardLoad, *opts) -> SelfWildcardLoad:
+ def options(
+ self: SelfWildcardLoad, *opts: _AbstractLoad
+ ) -> SelfWildcardLoad:
raise NotImplementedError("Star option does not support sub-options")
- def _apply_to_parent(self, parent):
+ def _apply_to_parent(self, parent: Load) -> None:
"""apply this :class:`_orm._WildcardLoad` object as a sub-option of
a :class:`_orm.Load` object.
it may be used as the sub-option of a :class:`_orm.Load` object.
"""
-
attr = self.path[0]
if attr.endswith(_DEFAULT_TOKEN):
attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}"
- effective_path = parent.path.token(attr)
+ effective_path = cast(AbstractEntityRegistry, parent.path).token(attr)
assert effective_path.is_token
entities = [ent.entity_zero for ent in mapper_entities]
current_path = compile_state.current_path
- start_path = self.path
+ start_path: _PathRepresentation = self.path
# TODO: chop_path already occurs in loader.process_compile_state()
# so we will seek to simplify this
if current_path:
- start_path = self._chop_path(start_path, current_path)
- if not start_path:
+ new_path = self._chop_path(start_path, current_path)
+ if not new_path:
return
+ start_path = new_path
# start_path is a single-token tuple
assert start_path and len(start_path) == 1
token = start_path[0]
-
+ assert isinstance(token, str)
entity = self._find_entity_basestring(entities, token, raiseerr)
if not entity:
# we just located, then go through the rest of our path
# tokens and populate into the Load().
+ assert isinstance(token, str)
loader = _TokenStrategyLoad.create(
path_element._path_registry,
token,
return loader
- def _find_entity_basestring(self, entities, token, raiseerr):
+ def _find_entity_basestring(
+ self,
+ entities: Iterable[_InternalEntityType[Any]],
+ token: str,
+ raiseerr: bool,
+ ) -> Optional[_InternalEntityType[Any]]:
if token.endswith(f":{_WILDCARD_TOKEN}"):
if len(list(entities)) != 1:
if raiseerr:
else:
return None
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
d = self._shallow_to_dict()
return d
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
self._shallow_from_dict(state)
_extra_criteria: Tuple[Any, ...]
_reconcile_to_other: Optional[bool]
- strategy: Tuple[Any, ...]
+ strategy: Optional[_StrategyKey]
path: PathRegistry
propagate_to_loaders: bool
- local_opts: Mapping[str, Any]
+ local_opts: util.immutabledict[str, Any]
is_token_strategy: bool
is_class_strategy: bool
- def __hash__(self):
+ def __hash__(self) -> int:
return id(self)
def __eq__(self, other):
return traversals.compare(self, other)
@property
- def is_opts_only(self):
+ def is_opts_only(self) -> bool:
return bool(self.local_opts and self.strategy is None)
- def _clone(self):
+ def _clone(self, **kw: Any) -> _LoadElement:
cls = self.__class__
s = cls.__new__(cls)
self._shallow_copy_to(s)
return s
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
d = self._shallow_to_dict()
d["path"] = self.path.serialize()
return d
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
state["path"] = PathRegistry.deserialize(state["path"])
self._shallow_from_dict(state)
)
def _adjust_effective_path_for_current_path(
- self, effective_path, current_path
- ):
+ self, effective_path: PathRegistry, current_path: PathRegistry
+ ) -> Optional[PathRegistry]:
"""receives the 'current_path' entry from an :class:`.ORMCompileState`
instance, which is set during lazy loads and secondary loader strategy
loads, and adjusts the given path to be relative to the
"""
- chopped_start_path = Load._chop_path(effective_path, current_path)
+ chopped_start_path = Load._chop_path(effective_path.path, current_path)
if not chopped_start_path:
return None
@classmethod
def create(
cls,
- path,
- attr,
- strategy,
- wildcard_key,
- local_opts,
- propagate_to_loaders,
- raiseerr=True,
- attr_group=None,
- reconcile_to_other=None,
- ):
+ path: PathRegistry,
+ attr: Optional[_AttrType],
+ strategy: Optional[_StrategyKey],
+ wildcard_key: Optional[_WildcardKeyType],
+ local_opts: Optional[_OptsType],
+ propagate_to_loaders: bool,
+ raiseerr: bool = True,
+ attr_group: Optional[_AttrGroupType] = None,
+ reconcile_to_other: Optional[bool] = None,
+ ) -> _LoadElement:
"""Create a new :class:`._LoadElement` object."""
opt = cls.__new__(cls)
path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr)
if not path:
- return None
+ return None # type: ignore
assert opt.is_token_strategy == path.is_token
opt.path = path
return opt
- def __init__(self, path, strategy, local_opts, propagate_to_loaders):
+ def __init__(self) -> None:
raise NotImplementedError()
def _prepend_path_from(self, parent):
assert cloned.is_class_strategy == self.is_class_strategy
if not orm_util._entity_corresponds_to_use_path_impl(
- parent.path[-1], cloned.path[0]
+ cast("_InternalEntityType[Any]", parent.path[-1]),
+ cast("_InternalEntityType[Any]", cloned.path[0]),
):
raise sa_exc.ArgumentError(
f'Attribute "{cloned.path[1]}" does not link '
return cloned
@staticmethod
- def _reconcile(replacement, existing):
+ def _reconcile(
+ replacement: _LoadElement, existing: _LoadElement
+ ) -> _LoadElement:
"""define behavior for when two Load objects are to be put into
the context.attributes under the same key.
),
]
- _of_type: Union["Mapper", AliasedInsp, None]
+ _of_type: Union["Mapper[Any]", "AliasedInsp[Any]", None]
_path_with_polymorphic_path: Optional[PathRegistry]
is_class_strategy = False
pwpi = inspect(
orm_util.AliasedInsp._with_polymorphic_factory(
pwpi.mapper.base_mapper,
- pwpi.mapper,
+ (pwpi.mapper,),
aliased=True,
_use_mapper_path=True,
)
start_path = self._path_with_polymorphic_path
if current_path:
- start_path = self._adjust_effective_path_for_current_path(
+ new_path = self._adjust_effective_path_for_current_path(
start_path, current_path
)
- if start_path is None:
+ if new_path is None:
return
+ start_path = new_path
key = ("path_with_polymorphic", start_path.natural_path)
if key in context:
effective_path = self.path
if current_path:
+ assert effective_path is not None
effective_path = self._adjust_effective_path_for_current_path(
effective_path, current_path
)
)
if current_path:
- effective_path = self._adjust_effective_path_for_current_path(
+ new_effective_path = self._adjust_effective_path_for_current_path(
effective_path, current_path
)
- if effective_path is None:
+ if new_effective_path is None:
return []
+ effective_path = new_effective_path
# for a wildcard token, expand out the path we set
# to encompass everything from the query entity on
effective_path = self.path
if current_path:
- effective_path = self._adjust_effective_path_for_current_path(
+ new_effective_path = self._adjust_effective_path_for_current_path(
effective_path, current_path
)
- if effective_path is None:
+ if new_effective_path is None:
return []
+ effective_path = new_effective_path
- return [("loader", cast(PathRegistry, effective_path).natural_path)]
+ return [("loader", effective_path.natural_path)]
-def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad:
-
- lead_element = None
+def _generate_from_keys(
+ meth: Callable[..., _AbstractLoad],
+ keys: Tuple[_AttrType, ...],
+ chained: bool,
+ kw: Any,
+) -> _AbstractLoad:
+ lead_element: Optional[_AbstractLoad] = None
+ attr: Any
for is_default, _keys in (True, keys[0:-1]), (False, keys[-1:]):
for attr in _keys:
if isinstance(attr, str):
return lead_element
-def _parse_attr_argument(attr):
+def _parse_attr_argument(
+ attr: _AttrType,
+) -> Tuple[InspectionAttr, _InternalEntityType[Any], MapperProperty[Any]]:
"""parse an attribute or wildcard argument to produce an
:class:`._AbstractLoad` instance.
"""
try:
- insp = inspect(attr)
+ # TODO: inspect() is annotated as potentially returning None;
+ # in most cases (if not all) None should not actually be a
+ # possible return here, so the annotation needs refinement
+ insp: InspectionAttr = inspect(attr) # type: ignore
except sa_exc.NoInspectionAvailable as err:
raise sa_exc.ArgumentError(
"expected ORM mapped attribute for loader strategy argument"
) from err
- if insp.is_property:
+ lead_entity: _InternalEntityType[Any]
+
+ if insp_is_mapper_property(insp):
lead_entity = insp.parent
prop = insp
- elif insp.is_attribute:
+ elif insp_is_attribute(insp):
lead_entity = insp.parent
prop = insp.prop
else:
return insp, lead_entity, prop
-def loader_unbound_fn(fn):
+def loader_unbound_fn(fn: _FN) -> _FN:
"""decorator that applies docstrings between standalone loader functions
and the loader methods on :class:`._AbstractLoad`.
@loader_unbound_fn
-def contains_eager(*keys, **kw) -> _AbstractLoad:
+def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
return _generate_from_keys(Load.contains_eager, keys, True, kw)
@loader_unbound_fn
-def load_only(*attrs) -> _AbstractLoad:
+def load_only(*attrs: _AttrType) -> _AbstractLoad:
# TODO: attrs against different classes. we likely have to
# add some extra state to Load of some kind
_, lead_element, _ = _parse_attr_argument(attrs[0])
@loader_unbound_fn
-def joinedload(*keys, **kw) -> _AbstractLoad:
+def joinedload(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
return _generate_from_keys(Load.joinedload, keys, False, kw)
@loader_unbound_fn
-def subqueryload(*keys) -> _AbstractLoad:
+def subqueryload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.subqueryload, keys, False, {})
@loader_unbound_fn
-def selectinload(*keys) -> _AbstractLoad:
+def selectinload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.selectinload, keys, False, {})
@loader_unbound_fn
-def lazyload(*keys) -> _AbstractLoad:
+def lazyload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.lazyload, keys, False, {})
@loader_unbound_fn
-def immediateload(*keys) -> _AbstractLoad:
+def immediateload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.immediateload, keys, False, {})
@loader_unbound_fn
-def noload(*keys) -> _AbstractLoad:
+def noload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.noload, keys, False, {})
@loader_unbound_fn
-def raiseload(*keys, **kw) -> _AbstractLoad:
+def raiseload(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
return _generate_from_keys(Load.raiseload, keys, False, kw)
@loader_unbound_fn
-def defaultload(*keys) -> _AbstractLoad:
+def defaultload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.defaultload, keys, False, {})
@loader_unbound_fn
-def defer(key, *addl_attrs, **kw) -> _AbstractLoad:
+def defer(key: _AttrType, *addl_attrs: _AttrType, **kw: Any) -> _AbstractLoad:
if addl_attrs:
util.warn_deprecated(
"The *addl_attrs on orm.defer is deprecated. Please use "
@loader_unbound_fn
-def undefer(key, *addl_attrs) -> _AbstractLoad:
+def undefer(key: _AttrType, *addl_attrs: _AttrType) -> _AbstractLoad:
if addl_attrs:
util.warn_deprecated(
"The *addl_attrs on orm.undefer is deprecated. Please use "
@loader_unbound_fn
-def undefer_group(name) -> _AbstractLoad:
+def undefer_group(name: str) -> _AbstractLoad:
element = _WildcardLoad()
return element.undefer_group(name)
@loader_unbound_fn
-def with_expression(key, expression) -> _AbstractLoad:
+def with_expression(
+ key: _AttrType, expression: _ColumnExpressionArgument[Any]
+) -> _AbstractLoad:
return _generate_from_keys(
Load.with_expression, (key,), False, {"expression": expression}
)
@loader_unbound_fn
-def selectin_polymorphic(base_cls, classes) -> _AbstractLoad:
+def selectin_polymorphic(
+ base_cls: _EntityType[Any], classes: Iterable[Type[Any]]
+) -> _AbstractLoad:
ul = Load(base_cls)
return ul.selectin_polymorphic(classes)
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""private module containing functions used for copying data
from __future__ import annotations
-from . import attributes
from . import exc
from . import util as orm_util
+from .base import PassiveFlag
def populate(
# inline of source_mapper._get_state_attr_by_column
prop = source_mapper._columntoproperty[l]
value = source.manager[prop.key].impl.get(
- source, source_dict, attributes.PASSIVE_OFF
+ source, source_dict, PassiveFlag.PASSIVE_OFF
)
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err)
try:
prop = source_mapper._columntoproperty[r]
source_dict[prop.key] = value
- except exc.UnmappedColumnError:
- _raise_col_to_prop(True, source_mapper, l, source_mapper, r)
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err)
def clear(dest, dest_mapper, synchronize_pairs):
source.obj(), l
)
value = source_mapper._get_state_attr_by_column(
- source, source.dict, l, passive=attributes.PASSIVE_OFF
+ source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF
)
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, None, r, err)
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(
- source, source.dict, l, passive=attributes.PASSIVE_OFF
+ source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF
)
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, None, r, err)
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, None, r, err)
history = uowcommit.get_attribute_history(
- source, prop.key, attributes.PASSIVE_NO_INITIALIZE
+ source, prop.key, PassiveFlag.PASSIVE_NO_INITIALIZE
)
if bool(history.deleted):
return True
import types
import typing
from typing import Any
+from typing import Callable
from typing import cast
from typing import Dict
from typing import FrozenSet
from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InternalEntityType
- from ._typing import _ORMColumnExprArgument
+ from ._typing import _ORMCOLEXPR
from .context import _MapperEntity
from .context import ORMCompileState
from .mapper import Mapper
+ from .query import Query
from .relationships import Relationship
from ..engine import Row
from ..engine import RowMapping
+ from ..sql._typing import _CE
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _EquivalentColumnMap
from ..sql._typing import _FromClauseArgument
from ..sql._typing import _OnClauseArgument
from ..sql._typing import _PropagateAttrsType
+ from ..sql.annotation import _SA
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import BindParameter
from ..sql.selectable import _ColumnsClauseElement
from ..sql.selectable import Alias
+ from ..sql.selectable import Select
from ..sql.selectable import Subquery
from ..sql.visitors import anon_map
+ from ..util.typing import _AnnotationScanType
_T = TypeVar("_T", bound=Any)
expunge: bool
delete_orphan: bool
- def __new__(cls, value_list):
+ def __new__(
+ cls, value_list: Optional[Union[Iterable[str], str]]
+ ) -> CascadeOptions:
if isinstance(value_list, str) or value_list is None:
- return cls.from_string(value_list)
+ return cls.from_string(value_list) # type: ignore
values = set(value_list)
if values.difference(cls._allowed_cascades):
raise sa_exc.ArgumentError(
def _with_polymorphic_factory(
cls,
base: Union[_O, Mapper[_O]],
- classes: Iterable[Type[Any]],
+ classes: Iterable[_EntityType[Any]],
selectable: Union[Literal[False, None], FromClause] = False,
flat: bool = False,
polymorphic_on: Optional[ColumnElement[Any]] = None,
)._aliased_insp
def _adapt_element(
- self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None
- ) -> _ORMColumnExprArgument[_T]:
- assert isinstance(elem, ColumnElement)
+ self, expr: _ORMCOLEXPR, key: Optional[str] = None
+ ) -> _ORMCOLEXPR:
+ assert isinstance(expr, ColumnElement)
d: Dict[str, Any] = {
"parententity": self,
"parentmapper": self.mapper,
}
if key:
d["proxy_key"] = key
+
+ # mypy should infer that this call returns the same type as
+ # the expression passed in, but it does not; hence the ignore
return (
- self._adapter.traverse(elem)
+ self._adapter.traverse(expr) # type: ignore
._annotate(d)
._set_propagate_attrs(
{"compile_state_plugin": "orm", "plugin_subject": self}
)
)
+ if TYPE_CHECKING:
+ # establish compatibility with the _ORMAdapterProto protocol,
+ # which in turn is compatible with _CoreAdapterProto.
+
+ def _orm_adapt_element(
+ self,
+ obj: _CE,
+ key: Optional[str] = None,
+ ) -> _CE:
+ ...
+
+ else:
+ _orm_adapt_element = _adapt_element
+
def _entity_for_mapper(self, mapper):
self_poly = self.with_polymorphic_mappers
if mapper in self_poly:
cloned.name = name
return cloned
- def create_row_processor(self, query, procs, labels):
+ def create_row_processor(
+ self,
+ query: Select[Any],
+ procs: Sequence[Callable[[Row[Any]], Any]],
+ labels: Sequence[str],
+ ) -> Callable[[Row[Any]], Any]:
"""Produce the "row processing" function for this :class:`.Bundle`.
May be overridden by subclasses.
"""
keyed_tuple = result_tuple(labels, [() for l in labels])
- def proc(row):
+ def proc(row: Row[Any]) -> Any:
return keyed_tuple([proc(row) for proc in procs])
return proc
-def _orm_annotate(element, exclude=None):
+def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA:
"""Deep copy the given ClauseElement, annotating each element with the
"_orm_adapt" flag.
return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
-def _orm_deannotate(element):
+def _orm_deannotate(element: _SA) -> _SA:
"""Remove annotations that link a column to a particular mapping.
Note this doesn't affect "remote" and "foreign" annotations
)
-def _orm_full_deannotate(element):
+def _orm_full_deannotate(element: _SA) -> _SA:
return sql_util._deep_deannotate(element)
on_selectable = prop.parent.selectable
else:
prop = None
+ on_selectable = None
if prop:
left_selectable = left_info.selectable
-
+ adapt_from: Optional[FromClause]
if sql_util.clause_is_present(on_selectable, left_selectable):
adapt_from = on_selectable
else:
+ assert isinstance(left_selectable, FromClause)
adapt_from = left_selectable
(
return given.isa(mapper)
-def _getitem(iterable_query, item):
+def _getitem(iterable_query: Query[Any], item: Any) -> Any:
"""calculate __getitem__ in terms of an iterable query object
that also has a slice() method.
isinstance(stop, int) and stop < 0
):
_no_negative_indexes()
- return list(iterable_query)[item]
res = iterable_query.slice(start, stop)
if step is not None:
- return list(res)[None : None : item.step]
+ return list(res)[None : None : item.step] # type: ignore
else:
- return list(res)
+ return list(res) # type: ignore
else:
if item == -1:
_no_negative_indexes()
- return list(iterable_query)[-1]
else:
return list(iterable_query[item : item + 1])[0]
def _extract_mapped_subtype(
- raw_annotation: Union[type, str],
+ raw_annotation: Optional[_AnnotationScanType],
cls: type,
key: str,
attr_cls: Type[Any],
_T = TypeVar("_T", bound=Any)
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
class _HasClauseElement(Protocol):
"""indicates a class that has a __clause_element__() method"""
...
+class _CoreAdapterProto(Protocol):
+ """protocol for the ClauseAdapter/ColumnAdapter.traverse() method."""
+
+ def __call__(self, obj: _CE) -> _CE:
+ ...
+
+
# match column types that are not ORM entities
_NOT_ENTITY = TypeVar(
"_NOT_ENTITY",
return element
+@overload
+def _deep_deannotate(
+ element: Literal[None], values: Optional[Sequence[str]] = None
+) -> Literal[None]:
+ ...
+
+
+@overload
def _deep_deannotate(
element: _SA, values: Optional[Sequence[str]] = None
) -> _SA:
+ ...
+
+
+def _deep_deannotate(
+ element: Optional[_SA], values: Optional[Sequence[str]] = None
+) -> Optional[_SA]:
"""Deep copy the given element, removing annotations."""
cloned: Dict[Any, SupportsAnnotations] = {}
return element
-def _shallow_annotate(
- element: SupportsAnnotations, annotations: _AnnotationDict
-) -> SupportsAnnotations:
+def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA:
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
o1.__dict__.update(other)
return o1
+ if TYPE_CHECKING:
+
+ def __getattr__(self, key: str) -> Any:
+ ...
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ ...
+
+ def __delattr__(self, key: str) -> None:
+ ...
+
class Options(metaclass=_MetaOptions):
"""A cacheable option dictionary with defaults."""
else:
return existing_options, exec_options
+ if TYPE_CHECKING:
+
+ def __getattr__(self, key: str) -> Any:
+ ...
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ ...
+
+ def __delattr__(self, key: str) -> None:
+ ...
+
class CacheableOptions(Options, HasCacheKey):
__slots__ = ()
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import DQLDMLClauseElement
+ from .elements import NamedColumn
from .elements import SQLCoreOperations
from .schema import Column
from .selectable import _ColumnsClauseElement
...
+@overload
+def expect(
+ role: Type[roles.LabeledColumnExprRole[Any]],
+ element: _ColumnExpressionArgument[_T],
+ **kw: Any,
+) -> NamedColumn[_T]:
+ ...
+
+
@overload
def expect(
role: Union[
Type[roles.LimitOffsetRole],
Type[roles.WhereHavingRole],
Type[roles.OnClauseRole],
+ Type[roles.ColumnArgumentRole],
],
element: Any,
**kw: Any,
def params(
self: SelfClauseElement,
- __optionaldict: Optional[Dict[str, Any]] = None,
+ __optionaldict: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> SelfClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
def _replace_params(
self: SelfClauseElement,
unique: bool,
- optionaldict: Optional[Dict[str, Any]],
+ optionaldict: Optional[Mapping[str, Any]],
kwargs: Dict[str, Any],
) -> SelfClauseElement:
{"bindparam": visit_bindparam},
)
- def compare(self, other, **kw):
+ def compare(self, other: ClauseElement, **kw: Any) -> bool:
r"""Compare this :class:`_expression.ClauseElement` to
the given :class:`_expression.ClauseElement`.
return False_._singleton
@classmethod
- def _ifnone(cls, other):
+ def _ifnone(
+ cls, other: Optional[ColumnElement[Any]]
+ ) -> ColumnElement[Any]:
if other is None:
return cls._instance()
else:
) -> Optional[str]:
return name
- def _bind_param(self, operator, obj, type_=None, expanding=False):
+ def _bind_param(
+ self,
+ operator: OperatorType,
+ obj: Any,
+ type_: Optional[TypeEngine[_T]] = None,
+ expanding: bool = False,
+ ) -> BindParameter[_T]:
return BindParameter(
self.key,
obj,
from .base import _expand_cloned
from .base import _from_objects
from .base import _generative
+from .base import _NoArg
from .base import _select_iterables
from .base import CacheableOptions
from .base import ColumnCollection
from .dml import Insert
from .dml import Update
from .elements import KeyedColumnElement
+ from .elements import Label
from .elements import NamedColumn
from .elements import TextClause
from .functions import Function
"""
raise NotImplementedError()
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
"""Return ``True`` if this :class:`.ReturnsRows` is
'derived' from the given :class:`.FromClause`.
"""
return TableSample._construct(self, sampling, name, seed)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
"""Return ``True`` if this :class:`_expression.FromClause` is
'derived' from the given ``FromClause``.
"""
+ LABEL_STYLE_LEGACY_ORM = 3
+
(
LABEL_STYLE_NONE,
LABEL_STYLE_TABLENAME_PLUS_COL,
LABEL_STYLE_DISAMBIGUATE_ONLY,
+ _,
) = list(SelectLabelStyle)
LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY
id(self.right),
)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
return (
# use hash() to ensure direct comparison to annotated works
# as well
"""Legacy for dialects that are referring to Alias.original."""
return self.element
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
if fromclause in self._cloned_set:
return True
return self.element.is_derived_from(fromclause)
def foreign_keys(self):
return self.element.foreign_keys
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
return self.element.is_derived_from(fromclause)
def alias(
def __init__(
self,
- nowait=False,
- read=False,
- of=None,
- skip_locked=False,
- key_share=False,
+ *,
+ nowait: bool = False,
+ read: bool = False,
+ of: Optional[
+ Union[
+ _ColumnExpressionArgument[Any],
+ Sequence[_ColumnExpressionArgument[Any]],
+ ]
+ ] = None,
+ skip_locked: bool = False,
+ key_share: bool = False,
):
"""Represents arguments specified to
:meth:`_expression.Select.for_update`.
return ScalarSelect(self)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[Any]:
"""Return a 'scalar' representation of this selectable, embedded as a
subquery with a label.
@_generative
def with_for_update(
self: SelfGenerativeSelect,
+ *,
nowait: bool = False,
read: bool = False,
of: Optional[
@_generative
def order_by(
- self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfGenerativeSelect,
+ __first: Union[
+ Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of ORDER BY
criteria applied.
"""
- if len(clauses) == 1 and clauses[0] is None:
+ if not clauses and __first is None:
self._order_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
self._order_by_clauses += tuple(
coercions.expect(roles.OrderByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
return self
@_generative
def group_by(
- self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfGenerativeSelect,
+ __first: Union[
+ Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of GROUP BY
criterion applied.
"""
- if len(clauses) == 1 and clauses[0] is None:
+ if not clauses and __first is None:
self._group_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
self._group_by_clauses += tuple(
coercions.expect(roles.GroupByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
return self
) -> GroupedElement:
return SelectStatementGrouping(self)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
for s in self.selects:
if s.is_derived_from(fromclause):
return True
_raw_columns: List[_ColumnsClauseElement]
- _distinct = False
+ _distinct: bool = False
_distinct_on: Tuple[ColumnElement[Any], ...] = ()
_correlate: Tuple[FromClause, ...] = ()
_correlate_except: Optional[Tuple[FromClause, ...]] = None
return iter(self._all_selected_columns)
- def is_derived_from(self, fromclause: FromClause) -> bool:
- if self in fromclause._cloned_set:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
+ if fromclause is not None and self in fromclause._cloned_set:
return True
for f in self._iterate_from_elements():
from typing import Deque
from typing import Dict
from typing import Iterable
+from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
COMPARE_SUCCEEDED = True
-def compare(obj1, obj2, **kw):
+def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
strategy: TraversalComparatorStrategy
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
return strategy.compare(obj1, obj2, **kw)
-def _preconfigure_traversals(target_hierarchy):
+def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
for cls in util.walk_subclasses(target_hierarchy):
if hasattr(cls, "_generate_cache_attrs") and hasattr(
cls, "_traverse_internals"
def __init__(self):
self.stack: Deque[
- Tuple[ExternallyTraversible, ExternallyTraversible]
+ Tuple[
+ Optional[ExternallyTraversible],
+ Optional[ExternallyTraversible],
+ ]
] = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
return (anon_map(), anon_map())
- def compare(self, obj1, obj2, **kw):
+ def compare(
+ self,
+ obj1: ExternallyTraversible,
+ obj2: ExternallyTraversible,
+ **kw: Any,
+ ) -> bool:
stack = self.stack
cache = self.cache
elif left_attrname in attributes_compared:
continue
+ assert left_visit_sym is not None
+ assert left_attrname is not None
+ assert right_attrname is not None
+
dispatch = self.dispatch(left_visit_sym)
assert dispatch, (
f"{self.__class__} has no dispatch for "
self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in zip_longest(left, right, fillvalue=None):
+ if l is None:
+ if r is not None:
+ return COMPARE_FAILED
+ else:
+ continue
+ elif r is None:
+ return COMPARE_FAILED
+
if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
self.anon_map[1], []
):
self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in zip_longest(left, right, fillvalue=None):
+ if l is None:
+ if r is not None:
+ return COMPARE_FAILED
+ else:
+ continue
+ elif r is None:
+ return COMPARE_FAILED
+
if (
l._gen_cache_key(self.anon_map[0], [])
if l._is_has_cache_key
from ._typing import _ColumnExpressionArgument
from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import BinaryExpression
from .elements import TextClause
from .selectable import _JoinTargetElement
from .selectable import _SelectIterable
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.row import Row
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
-def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
+
+def join_condition(
+ a: FromClause,
+ b: FromClause,
+ a_subset: Optional[FromClause] = None,
+ consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
+) -> ColumnElement[bool]:
"""Create a join condition between two tables or selectables.
e.g.::
)
-def find_join_source(clauses, join_to):
+def find_join_source(
+ clauses: List[FromClause], join_to: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the first index and element from the list of
clauses which can be joined against the selectable. returns
return idx
-def find_left_clause_that_matches_given(clauses, join_from):
+def find_left_clause_that_matches_given(
+ clauses: Sequence[FromClause], join_from: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the indexes from the list of
clauses which is derived from the selectable.
return idx
-def visit_binary_product(fn, expr):
+def visit_binary_product(
+ fn: Callable[
+ [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
+ ],
+ expr: ColumnElement[Any],
+) -> None:
"""Produce a traversal of the given expression, delivering
column comparisons to the given function.
a binary comparison is passed as pairs.
"""
- stack: List[ClauseElement] = []
+ stack: List[BinaryExpression[Any]] = []
- def visit(element):
+ def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
if isinstance(element, ScalarSelect):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
elif element.__visit_name__ == "binary" and operators.is_comparison(
- element.operator
+ element.operator # type: ignore
):
- stack.insert(0, element)
- for l in visit(element.left):
- for r in visit(element.right):
+ stack.insert(0, element) # type: ignore
+ for l in visit(element.left): # type: ignore
+ for r in visit(element.right): # type: ignore
fn(stack[0], l, r)
stack.pop(0)
for elem in element.get_children():
return None
-def selectables_overlap(left, right):
+def selectables_overlap(left: FromClause, right: FromClause) -> bool:
"""Return True if left/right have some overlapping selectable"""
return bool(
return "[%s]" % (", ".join(trunc(value) for value in params))
-def adapt_criterion_to_null(crit, nulls):
+def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
"""given criterion containing bind params, convert selected elements
to IS NULL.
return pairs
-_CE = TypeVar("_CE", bound="ClauseElement")
-
-
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
from typing import Callable
from typing import cast
from typing import ClassVar
-from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import overload
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from ..util.typing import Protocol
from ..util.typing import Self
+if TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .elements import ColumnElement
+
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import prefix_anon_map as prefix_anon_map
from ._py_util import cache_anon_map as anon_map
_generate_traversal_dispatch()
+SelfExternallyTraversible = TypeVar(
+ "SelfExternallyTraversible", bound="ExternallyTraversible"
+)
+
+
class ExternallyTraversible(HasTraverseInternals, Visitable):
__slots__ = ()
- _annotations: Collection[Any] = ()
+ _annotations: Mapping[Any, Any] = util.EMPTY_DICT
if typing.TYPE_CHECKING:
+ def _annotate(
+ self: SelfExternallyTraversible, values: _AnnotationDict
+ ) -> SelfExternallyTraversible:
+ ...
+
def get_children(
self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[ExternallyTraversible]:
_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
_TraverseCallableType = Callable[[_ET], None]
...
-class _TraverseTransformCallableType(Protocol):
- def __call__(
- self, element: ExternallyTraversible, **kw: Any
- ) -> Optional[ExternallyTraversible]:
+class _TraverseTransformCallableType(Protocol[_ET]):
+ def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]:
...
def replacement_traverse(
obj: Literal[None],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> None:
...
+@overload
+def replacement_traverse(
+ obj: _CE,
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType[Any],
+) -> _CE:
+ ...
+
+
@overload
def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> ExternallyTraversible:
...
def replacement_traverse(
obj: Optional[ExternallyTraversible],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
newelem = replace(elem)
if newelem is not None:
stop_on.add(id(newelem))
- return newelem
+ return newelem # type: ignore
else:
# base "already seen" on id(), not hash, so that we don't
# replace an Annotated element with its non-annotated one, and
newelem = kw["replace"](elem)
if newelem is not None:
cloned[id_elem] = newelem
- return newelem
+ return newelem # type: ignore
cloned[id_elem] = newelem = elem._clone(**kw)
newelem._copy_internals(clone=clone, **kw)
- return cloned[id_elem]
+ return cloned[id_elem] # type: ignore
if obj is not None:
obj = clone(
EMPTY_SET: FrozenSet[Any] = frozenset()
-def merge_lists_w_ordering(a, b):
+def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
"""merge two lists, maintaining ordering as much as possible.
this is to reconcile vars(cls) with cls.__annotations__.
return x
-def to_column_set(x):
+def to_column_set(x: Any) -> Set[Any]:
if x is None:
return column_set()
if not isinstance(x, column_set):
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
py311 = sys.version_info >= (3, 11)
return result
-def dataclass_fields(cls):
+def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
with a class."""
return []
-def local_dataclass_fields(cls):
+def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
a class, excluding those that originate from a superclass."""
if dataclasses.is_dataclass(cls):
- super_fields = set()
+ super_fields: Set[dataclasses.Field[Any]] = set()
for sup in cls.__bases__:
super_fields.update(dataclass_fields(sup))
return [f for f in dataclasses.fields(cls) if f not in super_fields]
metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
metadata.update(format_argspec_plus(spec, grouped=False))
metadata["name"] = fn.__name__
- code = (
- """\
+
+    # look for "__"-prefixed positional arguments.  By SQLAlchemy
+    # convention, such arguments should be passed positionally rather
+    # than as keywords.  Note that apply_pos does not currently work in
+    # all cases, e.g. when a kw-only indicator "*" is present, so its
+    # use here is limited to the cases we can reliably detect.  As more
+    # kinds of methods come to use @decorator, this area may need
+    # further improvement.
+ if "__" in repr(spec[0]):
+ code = (
+ """\
+def %(name)s%(grouped_args)s:
+ return %(target)s(%(fn)s, %(apply_pos)s)
+"""
+ % metadata
+ )
+ else:
+ code = (
+ """\
def %(name)s%(grouped_args)s:
return %(target)s(%(fn)s, %(apply_kw)s)
"""
- % metadata
- )
+ % metadata
+ )
env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
decorated = cast(
return result
@classmethod
- def memoized_instancemethod(cls, fn: Any) -> Any:
+ def memoized_instancemethod(cls, fn: _F) -> _F:
"""Decorate a method memoize its return value."""
- def oneshot(self, *args, **kw):
+ def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
result = fn(self, *args, **kw)
def memo(*a, **kw):
self._memoized_keys |= {fn.__name__}
return result
- return update_wrapper(oneshot, fn)
+ return update_wrapper(oneshot, fn) # type: ignore
if TYPE_CHECKING:
if TYPE_CHECKING:
from sqlalchemy.engine import default as engine_default # noqa
+ from sqlalchemy.orm import clsregistry as orm_clsregistry # noqa
+ from sqlalchemy.orm import decl_api as orm_decl_api # noqa
+ from sqlalchemy.orm import properties as orm_properties # noqa
from sqlalchemy.orm import relationships as orm_relationships # noqa
from sqlalchemy.orm import session as orm_session # noqa
+ from sqlalchemy.orm import state as orm_state # noqa
from sqlalchemy.orm import util as orm_util # noqa
from sqlalchemy.sql import dml as sql_dml # noqa
from sqlalchemy.sql import functions as sql_functions # noqa
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
"""Topological sorting algorithms."""
from __future__ import annotations
+from typing import Any
+from typing import DefaultDict
+from typing import Iterable
+from typing import Iterator
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+
from .. import util
from ..exc import CircularDependencyError
+_T = TypeVar("_T", bound=Any)
+
__all__ = ["sort", "sort_as_subsets", "find_cycles"]
-def sort_as_subsets(tuples, allitems):
+def sort_as_subsets(
+ tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
+) -> Iterator[Sequence[_T]]:
- edges = util.defaultdict(set)
+ edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[child].add(parent)
yield output
-def sort(tuples, allitems, deterministic_order=True):
+def sort(
+ tuples: Iterable[Tuple[_T, _T]],
+ allitems: Iterable[_T],
+ deterministic_order: bool = True,
+) -> Iterator[_T]:
"""sort the given list of items by dependency.
'tuples' is a list of tuples representing a partial ordering.
yield s
-def find_cycles(tuples, allitems):
+def find_cycles(
+ tuples: Iterable[Tuple[_T, _T]],
+ allitems: Iterable[_T],
+) -> Set[_T]:
# adapted from:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
- edges = util.defaultdict(set)
+ edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[parent].add(child)
nodes_to_test = set(edges)
return output
-def _gen_edges(edges):
- return set([(right, left) for left in edges for right in edges[left]])
+def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]:
+ return {(right, left) for left in edges for right in edges[left]}
from typing import ForwardRef
from typing import Generic
from typing import Iterable
+from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
if compat.py310:
# why they took until py310 to put this in stdlib is beyond me,
# I've been wanting it since py27
- from types import NoneType
+ from types import NoneType as NoneType
else:
NoneType = type(None) # type: ignore
# copied from TypeShed, required in order to implement
# MutableMapping.update()
+_AnnotationScanType = Union[Type[Any], str]
+
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]:
def de_stringify_annotation(
cls: Type[Any],
- annotation: Union[str, Type[Any]],
+ annotation: _AnnotationScanType,
str_cleanup_fn: Optional[Callable[[str], str]] = None,
-) -> Union[str, Type[Any]]:
+) -> Type[Any]:
"""Resolve annotations that may be string based into real objects.
This is particularly important if a module defines "from __future__ import
annotation = eval(annotation, base_globals, None)
except NameError:
pass
- return annotation
+ return annotation # type: ignore
-def is_fwd_ref(type_):
+def is_fwd_ref(type_: _AnnotationScanType) -> bool:
return isinstance(type_, ForwardRef)
-def de_optionalize_union_types(type_):
+@overload
+def de_optionalize_union_types(type_: str) -> str:
+ ...
+
+
+@overload
+def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]:
+ ...
+
+
+def de_optionalize_union_types(
+ type_: _AnnotationScanType,
+) -> _AnnotationScanType:
"""Given a type, filter out ``Union`` types that include ``NoneType``
to not include the ``NoneType``.
"""
if is_optional(type_):
- typ = set(type_.__args__)
+ typ = set(type_.__args__) # type: ignore
typ.discard(NoneType)
return type_
-def make_union_type(*types):
+def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
"""Make a Union type.
This is needed by :func:`.de_optionalize_union_types` which removes
``NoneType`` from a ``Union``.
"""
- return cast(Any, Union).__getitem__(types)
+ return cast(Any, Union).__getitem__(types) # type: ignore
def expand_unions(
...
+_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
+
+
+class RODescriptorReference(Generic[_DESC_co]):
+ """a descriptor that refers to a descriptor.
+
+ same as :class:`.DescriptorReference` but is read-only, so that subclasses
+ can define a subtype as the generically contained element
+
+ """
+
+ def __get__(self, instance: object, owner: Any) -> _DESC_co:
+ ...
+
+ def __set__(self, instance: Any, value: Any) -> NoReturn:
+ ...
+
+ def __delete__(self, instance: Any) -> NoReturn:
+ ...
+
+
+_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
+
+
+class CallableReference(Generic[_FN]):
+ """a descriptor that refers to a callable.
+
+ works around mypy's limitation of not allowing callables assigned
+ as instance variables
+
+
+ """
+
+ def __get__(self, instance: object, owner: Any) -> _FN:
+ ...
+
+ def __set__(self, instance: Any, value: _FN) -> None:
+ ...
+
+ def __delete__(self, instance: Any) -> None:
+ ...
+
+
# $def ro_descriptor_reference(fn: Callable[])
# pass
module = [
- # TODO for ORM, non-strict
- "sqlalchemy.orm.base",
- "sqlalchemy.orm.decl_base",
- "sqlalchemy.orm.descriptor_props",
- "sqlalchemy.orm.identity",
- "sqlalchemy.orm.mapped_collection",
- "sqlalchemy.orm.properties",
- "sqlalchemy.orm.relationships",
- "sqlalchemy.orm.strategy_options",
- "sqlalchemy.orm.state_changes",
-
- # would ideally be strict
- "sqlalchemy.orm.decl_api",
- "sqlalchemy.orm.events",
- "sqlalchemy.orm.query",
"sqlalchemy.engine.reflection",
]
--- /dev/null
+from typing import Any
+from typing import Tuple
+
+from sqlalchemy import select
+from sqlalchemy.orm import composite
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class Point:
+ def __init__(self, x: int, y: int):
+ self.x = x
+ self.y = y
+
+ def __composite_values__(self) -> Tuple[int, int]:
+ return self.x, self.y
+
+ def __repr__(self) -> str:
+ return "Point(x=%r, y=%r)" % (self.x, self.y)
+
+ def __eq__(self, other: Any) -> bool:
+ return (
+ isinstance(other, Point)
+ and other.x == self.x
+ and other.y == self.y
+ )
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+class Vertex(Base):
+ __tablename__ = "vertices"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ x1: Mapped[int]
+ y1: Mapped[int]
+ x2: Mapped[int]
+ y2: Mapped[int]
+
+ # inferred from right hand side
+ start = composite(Point, "x1", "y1")
+
+ # taken from left hand side
+ end: Mapped[Point] = composite(Point, "x2", "y2")
+
+
+v1 = Vertex(start=Point(3, 4), end=Point(5, 6))
+
+stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)]))
+
+# EXPECTED_TYPE: Select[Tuple[Vertex]]
+reveal_type(stmt)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.start)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.end)
+
+# EXPECTED_TYPE: int
+reveal_type(v1.end.y)
)
expected_msg = re.sub(
- r"(int|str|float|bool)",
+ r"\b(int|str|float|bool)\b",
lambda m: rf"builtins.{m.group(0)}\*?",
expected_msg,
)
reg = registry(metadata=metadata)
- reg.map_declaratively(User)
+ mp = reg.map_declaratively(User)
+ assert mp is inspect(User)
+ assert mp is User.__mapper__
def test_undefer_column_name(self):
# TODO: not sure if there was an explicit
class_mapper(User).get_property("props").secondary is user_to_prop
)
+ def test_string_dependency_resolution_schemas_no_base(self):
+ """
+
+ found_during_type_annotation
+
+ """
+
+ reg = registry()
+
+ @reg.mapped
+ class User:
+
+ __tablename__ = "users"
+ __table_args__ = {"schema": "fooschema"}
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ props = relationship(
+ "Prop",
+ secondary="fooschema.user_to_prop",
+ primaryjoin="User.id==fooschema.user_to_prop.c.user_id",
+ secondaryjoin="fooschema.user_to_prop.c.prop_id==Prop.id",
+ backref="users",
+ )
+
+ @reg.mapped
+ class Prop:
+
+ __tablename__ = "props"
+ __table_args__ = {"schema": "fooschema"}
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+
+ user_to_prop = Table(
+ "user_to_prop",
+ reg.metadata,
+ Column("user_id", Integer, ForeignKey("fooschema.users.id")),
+ Column("prop_id", Integer, ForeignKey("fooschema.props.id")),
+ schema="fooschema",
+ )
+ configure_mappers()
+
+ assert (
+ class_mapper(User).get_property("props").secondary is user_to_prop
+ )
+
def test_string_dependency_resolution_annotations(self):
Base = declarative_base()
reg = registry(metadata=metadata)
reg.mapped(User)
reg.mapped(Address)
+
+ reg.metadata.create_all(testing.db)
+ u1 = User(
+ name="u1", addresses=[Address(email="one"), Address(email="two")]
+ )
+ with Session(testing.db) as sess:
+ sess.add(u1)
+ sess.commit()
+ with Session(testing.db) as sess:
+ eq_(
+ sess.query(User).all(),
+ [
+ User(
+ name="u1",
+ addresses=[Address(email="one"), Address(email="two")],
+ )
+ ],
+ )
+
+ def test_map_declaratively(self, metadata):
+ class User(fixtures.ComparableEntity):
+
+ __tablename__ = "users"
+ id = Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ )
+ name = Column("name", String(50))
+ addresses = relationship("Address", backref="user")
+
+ class Address(fixtures.ComparableEntity):
+
+ __tablename__ = "addresses"
+ id = Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ )
+ email = Column("email", String(50))
+ user_id = Column("user_id", Integer, ForeignKey("users.id"))
+
+ reg = registry(metadata=metadata)
+ um = reg.map_declaratively(User)
+ am = reg.map_declaratively(Address)
+
+ is_(User.__mapper__, um)
+ is_(Address.__mapper__, am)
+
reg.metadata.create_all(testing.db)
u1 = User(
name="u1", addresses=[Address(email="one"), Address(email="two")]
import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
+from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import class_mapper
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
Engineer(name="vlad", primary_language="cobol"),
)
+ def test_single_cols_on_sub_base_of_subquery(self):
+ """
+ found_during_type_annotation
+
+ """
+ t = Table("t", Base.metadata, Column("id", Integer, primary_key=True))
+
+ class Person(Base):
+ __table__ = select(t).subquery()
+
+ with expect_raises_message(
+ sa.exc.ArgumentError,
+ r"Can't declare columns on single-table-inherited subclass "
+ r".*Contractor.*; superclass .*Person.* is not mapped to a Table",
+ ):
+
+ class Contractor(Person):
+ contractor_field = Column(String)
+
def test_single_cols_on_sub_base_of_joined(self):
"""test [ticket:3895]"""
'"user".state, "user".zip FROM "user"',
)
+ def test_name_cols_by_str(self, decl_base):
+ @dataclasses.dataclass
+ class Address:
+ street: str
+ state: str
+ zip_: str
+
+ class User(decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ street: Mapped[str]
+ state: Mapped[str]
+
+ # TODO: this needs to be improved, we should be able to say:
+ # zip_: Mapped[str] = mapped_column("zip")
+ # and it should assign to "zip_" for the attribute. not working
+
+ zip_: Mapped[str] = mapped_column(name="zip", key="zip_")
+
+ address: Mapped["Address"] = composite(
+ Address, "street", "state", "zip_"
+ )
+
+ eq_(
+ User.__mapper__.attrs["address"].props,
+ [
+ User.__mapper__.attrs["street"],
+ User.__mapper__.attrs["state"],
+ User.__mapper__.attrs["zip_"],
+ ],
+ )
+ self.assert_compile(
+ select(User),
+ 'SELECT "user".id, "user".name, "user".street, '
+ '"user".state, "user".zip FROM "user"',
+ )
+
def test_cls_annotated_setup(self, decl_base):
@dataclasses.dataclass
class Address:
from sqlalchemy.orm import with_polymorphic
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
from test.orm.test_events import _RemoveListeners
-class ConcreteTest(fixtures.MappedTest):
+class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest):
+ __dialect__ = "default"
+
@classmethod
def define_tables(cls, metadata):
Table(
"sometype",
)
+ # found_during_type_annotation
+ # test the comparator returned by ConcreteInheritedProperty
+ self.assert_compile(Manager.type == "x", "pjoin.type = :type_1")
+
jenn = Engineer("Jenn", "knows how to program")
hacker = Hacker("Karina", "Badass", "knows how to hack")
)
assert a1.c.users_id is not None
+ def test_no_subquery_for_from_statement(self):
+ """
+ found_during_typing
+
+ """
+ User = self.classes.User
+
+ session = fixture_session()
+ q = session.query(User.id).from_statement(text("select * from user"))
+
+ with expect_raises_message(
+ sa.exc.InvalidRequestError,
+ r"Can't call this method on a Query that uses from_statement\(\)",
+ ):
+ q.subquery()
+
def test_reduced_subquery(self):
User = self.classes.User
ua = aliased(User)
"FROM users GROUP BY name",
)
- def test_orm_columns_accepts_text(self):
- from sqlalchemy.orm.base import _orm_columns
-
- t = text("x")
- eq_(_orm_columns(t), [t])
-
def test_order_by_w_eager_one(self):
User = self.classes.User
s = fixture_session()
assert_raises_message(
sa.exc.ArgumentError,
"T1.t1s and back-reference T1.parent are "
- r"both of the same direction symbol\('ONETOMANY'\). Did you "
+ r"both of the same "
+ r"direction .*RelationshipDirection.ONETOMANY.*. Did you "
"mean to set remote_side on the many-to-one side ?",
configure_mappers,
)
assert_raises_message(
sa.exc.ArgumentError,
"T1.t1s and back-reference T1.parent are "
- r"both of the same direction symbol\('MANYTOONE'\). Did you "
+ r"both of the same direction .*RelationshipDirection.MANYTOONE.*."
+ "Did you "
"mean to set remote_side on the many-to-one side ?",
configure_mappers,
)
# can't be sure of ordering here
assert_raises_message(
sa.exc.ArgumentError,
- r"both of the same direction symbol\('ONETOMANY'\). Did you "
+ r"both of the same direction "
+ r".*RelationshipDirection.ONETOMANY.*. Did you "
"mean to set remote_side on the many-to-one side ?",
configure_mappers,
)
# can't be sure of ordering here
assert_raises_message(
sa.exc.ArgumentError,
- r"both of the same direction symbol\('MANYTOONE'\). Did you "
+ r"both of the same direction "
+ r".*RelationshipDirection.MANYTOONE.*. Did you "
"mean to set remote_side on the many-to-one side ?",
configure_mappers,
)