from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
+from ...util import FastIntFlag
+from ...util import parse_user_argument_for_enum
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
pass
-EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0)
-EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1)
-EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2)
-EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol(
- "executemany_values_plus_batch",
- canonical=EXECUTEMANY_BATCH | EXECUTEMANY_VALUES,
-)
+class ExecutemanyMode(FastIntFlag):
+ EXECUTEMANY_PLAIN = 0
+ EXECUTEMANY_BATCH = 1
+ EXECUTEMANY_VALUES = 2
+ EXECUTEMANY_VALUES_PLUS_BATCH = EXECUTEMANY_BATCH | EXECUTEMANY_VALUES
+
+
+(
+ EXECUTEMANY_PLAIN,
+ EXECUTEMANY_BATCH,
+ EXECUTEMANY_VALUES,
+ EXECUTEMANY_VALUES_PLUS_BATCH,
+) = tuple(ExecutemanyMode)
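+
+# Illustrative sketch (assumes FastIntFlag preserves IntFlag bitwise
+# semantics; not part of the change itself): the combined mode decomposes
+# with bitwise tests, e.g.
+#
+#     mode = ExecutemanyMode.EXECUTEMANY_VALUES_PLUS_BATCH
+#     assert mode & ExecutemanyMode.EXECUTEMANY_BATCH
+#     assert mode & ExecutemanyMode.EXECUTEMANY_VALUES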
class PGDialect_psycopg2(_PGDialect_common_psycopg):
# Parse executemany_mode argument, allowing it to be only one of the
# symbol names
- self.executemany_mode = util.symbol.parse_user_argument(
+ self.executemany_mode = parse_user_argument_for_enum(
executemany_mode,
{
EXECUTEMANY_PLAIN: [None],
"""
-_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.immutabledict()
-NO_OPTIONS: Mapping[str, Any] = util.immutabledict()
+_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT
+NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT
class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
)
)
+ def _get_required_transaction(self) -> RootTransaction:
+ trans = self._transaction
+ if trans is None:
+ raise exc.InvalidRequestError("connection is not in a transaction")
+ return trans
+
+ def _get_required_nested_transaction(self) -> NestedTransaction:
+ trans = self._nested_transaction
+ if trans is None:
+ raise exc.InvalidRequestError(
+ "connection is not in a nested transaction"
+ )
+ return trans
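+
+ # Note: an assumption about intended use (not stated here) is that these
+ # two helpers let callers such as commit()/rollback() obtain a
+ # non-Optional transaction for typing purposes, raising
+ # InvalidRequestError rather than asserting when no transaction is active.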
+
def get_transaction(self) -> Optional[RootTransaction]:
"""Return the current root transaction in progress, if any.
self,
func: FunctionElement[Any],
distilled_parameters: _CoreMultiExecuteParams,
- execution_options: _ExecuteOptions,
+ execution_options: _ExecuteOptionsParameter,
) -> Result:
"""Execute a sql.FunctionElement object."""
self,
default: ColumnDefault,
distilled_parameters: _CoreMultiExecuteParams,
- execution_options: _ExecuteOptions,
+ execution_options: _ExecuteOptionsParameter,
) -> Any:
"""Execute a schema.ColumnDefault object."""
self,
ddl: DDLElement,
distilled_parameters: _CoreMultiExecuteParams,
- execution_options: _ExecuteOptions,
+ execution_options: _ExecuteOptionsParameter,
) -> Result:
"""Execute a schema.DDL object."""
self,
elem: Executable,
distilled_parameters: _CoreMultiExecuteParams,
- execution_options: _ExecuteOptions,
+ execution_options: _ExecuteOptionsParameter,
) -> Result:
"""Execute a sql.ClauseElement object."""
self,
statement: str,
parameters: Optional[_DBAPIAnyExecuteParams] = None,
- execution_options: Optional[_ExecuteOptions] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> Result:
r"""Executes a SQL statement construct and returns a
:class:`_engine.CursorResult`.
import typing
from typing import Any
from typing import Callable
+from typing import Optional
from typing import TypeVar
from .. import exc
from .. import util
from ..util._has_cy import HAS_CYEXTENSION
+from ..util.typing import Protocol
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import _distill_params_20 as _distill_params_20
return decorated # type: ignore[return-value]
+class _TConsSubject(Protocol):
+ _trans_context_manager: Optional[TransactionalContext]
+
+
class TransactionalContext:
"""Apply Python context manager behavior to transaction objects.
__slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__")
+ _trans_subject: Optional[_TConsSubject]
+
def _transaction_is_active(self) -> bool:
raise NotImplementedError()
"""
raise NotImplementedError()
- def _get_subject(self) -> Any:
+ def _get_subject(self) -> _TConsSubject:
raise NotImplementedError()
def commit(self) -> None:
raise NotImplementedError()
@classmethod
- def _trans_ctx_check(cls, subject: Any) -> None:
+ def _trans_ctx_check(cls, subject: _TConsSubject) -> None:
trans_context = subject._trans_context_manager
if trans_context:
if not trans_context._transaction_is_active():
from .api import listens_for as listens_for
from .api import NO_RETVAL as NO_RETVAL
from .api import remove as remove
+from .attr import _InstanceLevelDispatch as _InstanceLevelDispatch
from .attr import RefCollection as RefCollection
from .base import _Dispatch as _Dispatch
from .base import _DispatchCommon as _DispatchCommon
# code within this block is **programmatically,
# statically generated** by tools/generate_proxy_methods.py
- def __contains__(self, instance):
+ def __contains__(self, instance: object) -> bool:
r"""Return True if the instance is associated with this session.
.. container:: class_bases
return self._proxied.__contains__(instance)
- def __iter__(self):
+ def __iter__(self) -> Iterator[object]:
r"""Iterate over all pending or persistent instances within this
Session.
return self._proxied.__iter__()
- def add(self, instance: Any, _warn: bool = True) -> None:
+ def add(self, instance: object, _warn: bool = True) -> None:
r"""Place an object in the ``Session``.
.. container:: class_bases
return self._proxied.add(instance, _warn=_warn)
- def add_all(self, instances):
+ def add_all(self, instances: Iterable[object]) -> None:
r"""Add the given collection of instances to this ``Session``.
.. container:: class_bases
**kw,
)
- def expire(self, instance, attribute_names=None):
+ def expire(
+ self, instance: object, attribute_names: Optional[Iterable[str]] = None
+ ) -> None:
r"""Expire the attributes on an instance.
.. container:: class_bases
return self._proxied.expire(instance, attribute_names=attribute_names)
- def expire_all(self):
+ def expire_all(self) -> None:
r"""Expires all persistent instances within this Session.
.. container:: class_bases
return self._proxied.expire_all()
- def expunge(self, instance):
+ def expunge(self, instance: object) -> None:
r"""Remove the `instance` from this ``Session``.
.. container:: class_bases
return self._proxied.expunge(instance)
- def expunge_all(self):
+ def expunge_all(self) -> None:
r"""Remove all object instances from this ``Session``.
.. container:: class_bases
mapper=mapper, clause=clause, bind=bind, **kw
)
- def is_modified(self, instance, include_collections=True):
+ def is_modified(
+ self, instance: object, include_collections: bool = True
+ ) -> bool:
r"""Return ``True`` if the given instance has locally
modified attributes.
return await AsyncSession.close_all()
@classmethod
- def object_session(cls, instance: Any) -> "Session":
+ def object_session(cls, instance: object) -> Optional[Session]:
r"""Return the :class:`.Session` to which an object belongs.
.. container:: class_bases
@classmethod
def identity_key(
cls,
- class_=None,
- ident=None,
+ class_: Optional[Type[Any]] = None,
+ ident: Union[Any, Tuple[Any, ...]] = None,
*,
- instance=None,
- row=None,
- identity_token=None,
- ) -> _IdentityKeyType:
+ instance: Optional[Any] = None,
+ row: Optional[Row] = None,
+ identity_token: Optional[Any] = None,
+ ) -> _IdentityKeyType[Any]:
r"""Return an identity key.
.. container:: class_bases
# code within this block is **programmatically,
# statically generated** by tools/generate_proxy_methods.py
- def __contains__(self, instance):
+ def __contains__(self, instance: object) -> bool:
r"""Return True if the instance is associated with this session.
.. container:: class_bases
return self._proxied.__contains__(instance)
- def __iter__(self):
+ def __iter__(self) -> Iterator[object]:
r"""Iterate over all pending or persistent instances within this
Session.
return self._proxied.__iter__()
- def add(self, instance: Any, _warn: bool = True) -> None:
+ def add(self, instance: object, _warn: bool = True) -> None:
r"""Place an object in the ``Session``.
.. container:: class_bases
return self._proxied.add(instance, _warn=_warn)
- def add_all(self, instances):
+ def add_all(self, instances: Iterable[object]) -> None:
r"""Add the given collection of instances to this ``Session``.
.. container:: class_bases
return self._proxied.add_all(instances)
- def expire(self, instance, attribute_names=None):
+ def expire(
+ self, instance: object, attribute_names: Optional[Iterable[str]] = None
+ ) -> None:
r"""Expire the attributes on an instance.
.. container:: class_bases
return self._proxied.expire(instance, attribute_names=attribute_names)
- def expire_all(self):
+ def expire_all(self) -> None:
r"""Expires all persistent instances within this Session.
.. container:: class_bases
return self._proxied.expire_all()
- def expunge(self, instance):
+ def expunge(self, instance: object) -> None:
r"""Remove the `instance` from this ``Session``.
.. container:: class_bases
return self._proxied.expunge(instance)
- def expunge_all(self):
+ def expunge_all(self) -> None:
r"""Remove all object instances from this ``Session``.
.. container:: class_bases
return self._proxied.expunge_all()
- def is_modified(self, instance, include_collections=True):
+ def is_modified(
+ self, instance: object, include_collections: bool = True
+ ) -> bool:
r"""Return ``True`` if the given instance has locally
modified attributes.
instance, include_collections=include_collections
)
- def in_transaction(self):
+ def in_transaction(self) -> bool:
r"""Return True if this :class:`_orm.Session` has begun a transaction.
.. container:: class_bases
return self._proxied.in_transaction()
- def in_nested_transaction(self):
+ def in_nested_transaction(self) -> bool:
r"""Return True if this :class:`_orm.Session` has begun a nested
transaction, e.g. SAVEPOINT.
return self._proxied.new
@property
- def identity_map(self) -> identity.IdentityMap:
+ def identity_map(self) -> IdentityMap:
r"""Proxy for the :attr:`_orm.Session.identity_map` attribute
on behalf of the :class:`_asyncio.AsyncSession` class.
return self._proxied.identity_map
@identity_map.setter
- def identity_map(self, attr: identity.IdentityMap) -> None:
+ def identity_map(self, attr: IdentityMap) -> None:
self._proxied.identity_map = attr
@property
return self._proxied.info
@classmethod
- def object_session(cls, instance: Any) -> "Session":
+ def object_session(cls, instance: object) -> Optional[Session]:
r"""Return the :class:`.Session` to which an object belongs.
.. container:: class_bases
@classmethod
def identity_key(
cls,
- class_=None,
- ident=None,
+ class_: Optional[Type[Any]] = None,
+ ident: Union[Any, Tuple[Any, ...]] = None,
*,
- instance=None,
- row=None,
- identity_token=None,
- ) -> _IdentityKeyType:
+ instance: Optional[Any] = None,
+ row: Optional[Row] = None,
+ identity_token: Optional[Any] = None,
+ ) -> _IdentityKeyType[Any]:
r"""Return an identity key.
.. container:: class_bases
return None
-def async_session(session):
+def async_session(session: Session) -> AsyncSession:
"""Return the :class:`_asyncio.AsyncSession` which is proxying the given
:class:`_orm.Session` object, if any.
def property(self) -> Any:
return None
- def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]:
+ def adapt_to_entity(
+ self, adapt_to_entity: AliasedInsp[Any]
+ ) -> Comparator[_T]:
# interesting....
return self
class_: Type[_T],
*attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
**kwargs: Any,
-) -> "Composite[_T]":
+) -> Composite[_T]:
...
def composite(
*attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
**kwargs: Any,
-) -> "Composite[Any]":
+) -> Composite[Any]:
...
class_: Any = None,
*attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
**kwargs: Any,
-) -> "Composite[Any]":
+) -> Composite[Any]:
r"""Return a composite column-based property for use with a Mapper.
See the mapping documentation section :ref:`mapper_composite` for a
from __future__ import annotations
+import operator
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
+from .interfaces import UserDefinedOption
+from ..util.typing import Protocol
+from ..util.typing import TypeGuard
if TYPE_CHECKING:
+ from .attributes import AttributeImpl
+ from .attributes import CollectionAttributeImpl
+ from .base import PassiveFlag
+ from .descriptor_props import _CompositeClassProto
from .mapper import Mapper
+ from .state import InstanceState
+ from .util import AliasedClass
from .util import AliasedInsp
+ from ..sql.base import ExecutableOption
-_EntityType = Union[Mapper, AliasedInsp]
+_T = TypeVar("_T", bound=Any)
+
+_O = TypeVar("_O", bound=Any)
+"""The 'ORM mapped object' type.
+I would have preferred this were bound=object; however, it seems
+not to travel in all situations when defined in that way.
+"""
+
+_InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"]
+
+_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"]
+
+
+_InstanceDict = Dict[str, Any]
+
+_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
+
+
+class _LoaderCallable(Protocol):
+ def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any:
+ ...
+
+
+def is_user_defined_option(
+ opt: ExecutableOption,
+) -> TypeGuard[UserDefinedOption]:
+ return not opt._is_core and opt._is_user_defined # type: ignore
+
+
+def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]:
+ return hasattr(obj, "__composite_values__")
+
+
+if TYPE_CHECKING:
+
+ def is_collection_impl(
+ impl: AttributeImpl,
+ ) -> TypeGuard[CollectionAttributeImpl]:
+ ...
+
+else:
+ is_collection_impl = operator.attrgetter("collection")
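+
+
+# Illustrative sketch (hypothetical caller, not from the change): a TypeGuard
+# return type lets a static checker narrow the argument in the guarded
+# branch, e.g.
+#
+#     def apply_option(opt: ExecutableOption) -> None:
+#         if is_user_defined_option(opt):
+#             # within this branch, "opt" is typed as UserDefinedOption
+#             ...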
from collections import namedtuple
import operator
-import typing
from typing import Any
from typing import Callable
+from typing import Collection
+from typing import Dict
from typing import List
from typing import NamedTuple
+from typing import Optional
+from typing import overload
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .base import CALLABLES_OK
from .base import DEFERRED_HISTORY_LOAD
from .base import INIT_OK
-from .base import instance_dict
-from .base import instance_state
+from .base import instance_dict as instance_dict
+from .base import instance_state as instance_state
from .base import instance_str
from .base import LOAD_AGAINST_COMMITTED
from .base import manager_of_class
from .base import PASSIVE_OFF
from .base import PASSIVE_ONLY_PERSISTENT
from .base import PASSIVE_RETURN_NO_VALUE
+from .base import PassiveFlag
from .base import RELATED_OBJECT_OK # noqa
from .base import SQL_OK # noqa
from .base import state_str
from ..sql import traversals
from ..sql import visitors
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+ from .state import InstanceState
from ..sql.dml import _DMLColumnElement
from ..sql.elements import ColumnElement
from ..sql.elements import SQLCoreOperations
is_attribute = True
+ impl: AttributeImpl
+
# PropComparator has a __visit_name__ to participate within
# traversals. Disambiguate the attribute vs. a comparator.
__visit_name__ = "orm_instrumented_attribute"
def __delete__(self, instance):
self.impl.delete(instance_state(instance), instance_dict(instance))
- def __get__(self, instance, owner):
+ @overload
+ def __get__(
+ self, instance: None, owner: Type[Any]
+ ) -> InstrumentedAttribute:
+ ...
+
+ @overload
+ def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]:
+ ...
+
+ def __get__(
+ self, instance: Optional[object], owner: Type[Any]
+ ) -> Union[InstrumentedAttribute, Optional[_T]]:
if instance is None:
return self
class AttributeImpl:
"""internal implementation for instrumented attributes."""
+ collection: bool
+
def __init__(
self,
class_,
state.parents[id_] = False
- def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> History:
raise NotImplementedError()
def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
):
raise AttributeError("%s object does not have a value" % self)
- def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: Dict[str, Any],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> History:
if self.key in dict_:
return History.from_scalar_attribute(self, state, dict_[self.key])
elif self.key in state.committed_state:
def set(
self,
- state,
- dict_,
- value,
- initiator,
- passive=PASSIVE_OFF,
- check_old=None,
- pop=False,
+ state: InstanceState[Any],
+ dict_: Dict[str, Any],
+ value: Any,
+ initiator: Optional[Event],
+ passive: PassiveFlag = PASSIVE_OFF,
+ check_old: Optional[object] = None,
+ pop: bool = False,
):
if self.dispatch._active_history:
old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
if fire_event:
self.dispatch.dispose_collection(state, collection, adapter)
- def _invalidate_collection(self, collection):
+ def _invalidate_collection(self, collection: Collection) -> None:
adapter = getattr(collection, "_sa_adapter")
adapter.invalidated = True
from typing import Generic
from typing import Optional
from typing import overload
-from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .. import inspection
from .. import util
from ..sql.elements import SQLCoreOperations
+from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
from ..util.typing import Concatenate
from ..util.typing import Literal
from ..util.typing import Self
if typing.TYPE_CHECKING:
+ from ._typing import _InternalEntityType
from .attributes import InstrumentedAttribute
from .mapper import Mapper
+ from .state import InstanceState
_T = TypeVar("_T", bound=Any)
+_O = TypeVar("_O", bound=object)
-_IdentityKeyType = Tuple[type, Tuple[Any, ...], Optional[str]]
-
-PASSIVE_NO_RESULT = util.symbol(
- "PASSIVE_NO_RESULT",
+class LoaderCallableStatus(Enum):
+ PASSIVE_NO_RESULT = 0
"""Symbol returned by a loader callable or other attribute/history
retrieval operation when a value could not be determined, based
on loader callable flags.
- """,
-)
+ """
-PASSIVE_CLASS_MISMATCH = util.symbol(
- "PASSIVE_CLASS_MISMATCH",
+ PASSIVE_CLASS_MISMATCH = 1
"""Symbol indicating that an object is locally present for a given
primary key identity but it is not of the requested class. The
- return value is therefore None and no SQL should be emitted.""",
-)
+ return value is therefore None and no SQL should be emitted."""
-ATTR_WAS_SET = util.symbol(
- "ATTR_WAS_SET",
+ ATTR_WAS_SET = 2
"""Symbol returned by a loader callable to indicate the
retrieved value, or values, were assigned to their attributes
on the target object.
- """,
-)
+ """
-ATTR_EMPTY = util.symbol(
- "ATTR_EMPTY",
+ ATTR_EMPTY = 3
"""Symbol used internally to indicate an attribute had no callable.""",
-)
-NO_VALUE = util.symbol(
- "NO_VALUE",
+ NO_VALUE = 4
"""Symbol which may be placed as the 'previous' value of an attribute,
indicating no value was loaded for an attribute when it was modified,
and flags indicated we were not to load it.
- """,
-)
+ """
+
+ NEVER_SET = NO_VALUE
+ """
+ Synonymous with NO_VALUE
+
+ .. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE
+
+ """
+
+
+(
+ PASSIVE_NO_RESULT,
+ PASSIVE_CLASS_MISMATCH,
+ ATTR_WAS_SET,
+ ATTR_EMPTY,
+ NO_VALUE,
+) = tuple(LoaderCallableStatus)
+
NEVER_SET = NO_VALUE
-"""
-Synonymous with NO_VALUE
-.. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE
-"""
-NO_CHANGE = util.symbol(
- "NO_CHANGE",
+class PassiveFlag(FastIntFlag):
+ """Bitflag interface that passes options onto loader callables"""
+
+ NO_CHANGE = 0
"""No callables or SQL should be emitted on attribute access
and no state should change
- """,
- canonical=0,
-)
+ """
-CALLABLES_OK = util.symbol(
- "CALLABLES_OK",
+ CALLABLES_OK = 1
"""Loader callables can be fired off if a value
is not present.
- """,
- canonical=1,
-)
+ """
-SQL_OK = util.symbol(
- "SQL_OK",
- """Loader callables can emit SQL at least on scalar value attributes.""",
- canonical=2,
-)
+ SQL_OK = 2
+ """Loader callables can emit SQL at least on scalar value attributes."""
-RELATED_OBJECT_OK = util.symbol(
- "RELATED_OBJECT_OK",
+ RELATED_OBJECT_OK = 4
"""Callables can use SQL to load related objects as well
as scalar value attributes.
- """,
- canonical=4,
-)
+ """
-INIT_OK = util.symbol(
- "INIT_OK",
+ INIT_OK = 8
"""Attributes should be initialized with a blank
value (None or an empty collection) upon get, if no other
value can be obtained.
- """,
- canonical=8,
-)
+ """
-NON_PERSISTENT_OK = util.symbol(
- "NON_PERSISTENT_OK",
- """Callables can be emitted if the parent is not persistent.""",
- canonical=16,
-)
+ NON_PERSISTENT_OK = 16
+ """Callables can be emitted if the parent is not persistent."""
-LOAD_AGAINST_COMMITTED = util.symbol(
- "LOAD_AGAINST_COMMITTED",
+ LOAD_AGAINST_COMMITTED = 32
"""Callables should use committed values as primary/foreign keys during a
load.
- """,
- canonical=32,
-)
+ """
-NO_AUTOFLUSH = util.symbol(
- "NO_AUTOFLUSH",
+ NO_AUTOFLUSH = 64
"""Loader callables should disable autoflush.""",
- canonical=64,
-)
-NO_RAISE = util.symbol(
- "NO_RAISE",
- """Loader callables should not raise any assertions""",
- canonical=128,
-)
+ NO_RAISE = 128
+ """Loader callables should not raise any assertions"""
-DEFERRED_HISTORY_LOAD = util.symbol(
- "DEFERRED_HISTORY_LOAD",
- """indicates special load of the previous value of an attribute""",
- canonical=256,
-)
+ DEFERRED_HISTORY_LOAD = 256
+ """indicates special load of the previous value of an attribute"""
-# pre-packaged sets of flags used as inputs
-PASSIVE_OFF = util.symbol(
- "PASSIVE_OFF",
- "Callables can be emitted in all cases.",
- canonical=(
+ # pre-packaged sets of flags used as inputs
+ PASSIVE_OFF = (
RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK
- ),
-)
-PASSIVE_RETURN_NO_VALUE = util.symbol(
- "PASSIVE_RETURN_NO_VALUE",
- """PASSIVE_OFF ^ INIT_OK""",
- canonical=PASSIVE_OFF ^ INIT_OK,
-)
-PASSIVE_NO_INITIALIZE = util.symbol(
- "PASSIVE_NO_INITIALIZE",
- "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK",
- canonical=PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK,
-)
-PASSIVE_NO_FETCH = util.symbol(
- "PASSIVE_NO_FETCH", "PASSIVE_OFF ^ SQL_OK", canonical=PASSIVE_OFF ^ SQL_OK
-)
-PASSIVE_NO_FETCH_RELATED = util.symbol(
- "PASSIVE_NO_FETCH_RELATED",
- "PASSIVE_OFF ^ RELATED_OBJECT_OK",
- canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK,
-)
-PASSIVE_ONLY_PERSISTENT = util.symbol(
- "PASSIVE_ONLY_PERSISTENT",
- "PASSIVE_OFF ^ NON_PERSISTENT_OK",
- canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK,
-)
+ )
+ "Callables can be emitted in all cases."
+
+ PASSIVE_RETURN_NO_VALUE = PASSIVE_OFF ^ INIT_OK
+ """PASSIVE_OFF ^ INIT_OK"""
+
+ PASSIVE_NO_INITIALIZE = PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK
+ "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK"
+
+ PASSIVE_NO_FETCH = PASSIVE_OFF ^ SQL_OK
+ "PASSIVE_OFF ^ SQL_OK"
+
+ PASSIVE_NO_FETCH_RELATED = PASSIVE_OFF ^ RELATED_OBJECT_OK
+ "PASSIVE_OFF ^ RELATED_OBJECT_OK"
+
+ PASSIVE_ONLY_PERSISTENT = PASSIVE_OFF ^ NON_PERSISTENT_OK
+ "PASSIVE_OFF ^ NON_PERSISTENT_OK"
+
+
+(
+ NO_CHANGE,
+ CALLABLES_OK,
+ SQL_OK,
+ RELATED_OBJECT_OK,
+ INIT_OK,
+ NON_PERSISTENT_OK,
+ LOAD_AGAINST_COMMITTED,
+ NO_AUTOFLUSH,
+ NO_RAISE,
+ DEFERRED_HISTORY_LOAD,
+ PASSIVE_OFF,
+ PASSIVE_RETURN_NO_VALUE,
+ PASSIVE_NO_INITIALIZE,
+ PASSIVE_NO_FETCH,
+ PASSIVE_NO_FETCH_RELATED,
+ PASSIVE_ONLY_PERSISTENT,
+) = tuple(PassiveFlag)
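+
+# Illustrative sketch (assumes FastIntFlag behaves like IntFlag for bitwise
+# tests): the pre-packaged PASSIVE_* values are bit combinations, so callers
+# can test an individual capability with "&", e.g.
+#
+#     if passive & PassiveFlag.SQL_OK:
+#         ...  # loader callables may emit SQL for this access
+#
+# and the same test is falsy under PASSIVE_NO_FETCH, which is PASSIVE_OFF
+# with the SQL_OK bit removed.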
DEFAULT_MANAGER_ATTR = "_sa_class_manager"
DEFAULT_STATE_ATTR = "_sa_instance_state"
return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
-instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
+if TYPE_CHECKING:
+
+ def instance_state(instance: _O) -> InstanceState[_O]:
+ ...
+
+ def instance_dict(instance: object) -> Dict[str, Any]:
+ ...
-instance_dict = operator.attrgetter("__dict__")
+else:
+ instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
+ instance_dict = operator.attrgetter("__dict__")
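+
+# Illustrative sketch: both accessors read pre-established attributes of a
+# mapped instance, i.e. (assuming "obj" is an instrumented ORM object)
+#
+#     state = instance_state(obj)   # obj._sa_instance_state
+#     dict_ = instance_dict(obj)    # obj.__dict__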
-def instance_str(instance):
+
+def instance_str(instance: object) -> str:
"""Return a string describing an instance."""
return state_str(instance_state(instance))
-def state_str(state):
+def state_str(state: InstanceState[Any]) -> str:
"""Return a string describing an instance via its InstanceState."""
if state is None:
return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj()))
-def state_class_str(state):
+def state_class_str(state: InstanceState[Any]) -> str:
"""Return a string describing an instance's class via its
InstanceState.
"""
return "<%s>" % (state.class_.__name__,)
-def attribute_str(instance, attribute):
+def attribute_str(instance: object, attribute: str) -> str:
return instance_str(instance) + "." + attribute
-def state_attribute_str(state, attribute):
+def state_attribute_str(state: InstanceState[Any], attribute: str) -> str:
return state_str(state) + "." + attribute
-def object_mapper(instance):
+def object_mapper(instance: _T) -> Mapper[_T]:
"""Given an object, return the primary Mapper associated with the object
instance.
return object_state(instance).mapper
-def object_state(instance):
+def object_state(instance: _T) -> InstanceState[_T]:
"""Given an object, return the :class:`.InstanceState`
associated with the object.
@inspection._inspects(object)
-def _inspect_mapped_object(instance):
+def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]:
try:
return instance_state(instance)
except (exc.UnmappedClassError,) + exc.NO_STATE:
return None
-def _class_to_mapper(class_or_mapper):
+def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]:
insp = inspection.inspect(class_or_mapper, False)
if insp is not None:
return insp.mapper
raise exc.UnmappedClassError(class_or_mapper)
-def _mapper_or_none(entity):
+def _mapper_or_none(
+ entity: Union[_T, _InternalEntityType[_T]]
+) -> Optional[Mapper[_T]]:
"""Return the :class:`_orm.Mapper` for the given class or None if the
class is not mapped.
"""
"""
+ __slots__ = ()
+
@util.memoized_property
def info(self) -> Dict[Any, Any]:
"""Info dictionary associated with the object, allowing user-defined
import itertools
from typing import Any
+from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from ..sql.visitors import InternalTraversal
if TYPE_CHECKING:
- from ._typing import _EntityType
+ from ._typing import _InternalEntityType
from ..sql.compiler import _CompilerStackEntry
from ..sql.dml import _DMLTableElement
from ..sql.elements import ColumnElement
statement: Union[Select, FromStatement]
select_statement: Union[Select, FromStatement]
_entities: List[_QueryEntity]
- _polymorphic_adapters: Dict[_EntityType, ORMAdapter]
+ _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
compile_options: Union[
Type[default_compile_options], default_compile_options
]
return compiler.process(compile_state.statement, **kw)
+ @property
+ def column_descriptions(self):
+ """Return a :term:`plugin-enabled` 'column descriptions' structure
+ referring to the columns which are SELECTed by this statement.
+
+ See the section :ref:`queryguide_inspection` for an overview
+ of this feature.
+
+ .. seealso::
+
+ :ref:`queryguide_inspection` - ORM background
+
+ """
+ meth = cast(
+ ORMSelectCompileState, SelectState.get_plugin_class(self)
+ ).get_column_descriptions
+ return meth(self)
+
def _ensure_disambiguated_names(self):
return self
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
+from typing import Type
from typing import TypeVar
from typing import Union
from .. import util
from ..sql import expression
from ..sql import operators
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from .attributes import InstrumentedAttribute
from .properties import MappedColumn
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql.schema import Column
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
+class _CompositeClassProto(Protocol):
+ def __composite_values__(self) -> Tuple[Any, ...]:
+ ...
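+
+# Illustrative sketch (hypothetical class, not from the change): any class
+# supplying __composite_values__ satisfies this protocol, e.g.
+#
+#     class Point:
+#         def __init__(self, x: int, y: int) -> None:
+#             self.x, self.y = x, y
+#         def __composite_values__(self) -> Tuple[int, int]:
+#             return self.x, self.y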
+
+
class DescriptorProperty(MapperProperty[_T]):
""":class:`.MapperProperty` which proxies access to a
user-defined descriptor."""
mapper.class_manager.instrument_attribute(self.key, proxy_attr)
+_CompositeAttrType = Union[
+ str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]"
+]
+
+
class Composite(
_MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T]
):
"""
- composite_class: Union[type, Callable[..., type]]
- attrs: Tuple[
- Union[sql.ColumnElement[Any], "MappedColumn", str, Mapped[Any]], ...
+ composite_class: Union[
+ Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]]
]
+ attrs: Tuple[_CompositeAttrType, ...]
- def __init__(self, class_=None, *attrs, **kwargs):
+ def __init__(
+ self,
+ class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None,
+ *attrs: _CompositeAttrType,
+ active_history: bool = False,
+ deferred: bool = False,
+ group: Optional[str] = None,
+ comparator_factory: Optional[Type[Comparator]] = None,
+ info: Optional[Dict[Any, Any]] = None,
+ ):
super().__init__()
if isinstance(class_, (Mapped, str, sql.ColumnElement)):
self.composite_class = class_
self.attrs = attrs
- self.active_history = kwargs.get("active_history", False)
- self.deferred = kwargs.get("deferred", False)
- self.group = kwargs.get("group", None)
- self.comparator_factory = kwargs.pop(
- "comparator_factory", self.__class__.Comparator
+ self.active_history = active_history
+ self.deferred = deferred
+ self.group = group
+ self.comparator_factory = (
+ comparator_factory
+ if comparator_factory is not None
+ else self.__class__.Comparator
)
self._generated_composite_accessor = None
- if "info" in kwargs:
- self.info = kwargs.pop("info")
+ if info is not None:
+ self.info = info
util.set_creation_order(self)
self._create_descriptor()
super().instrument_class(mapper)
self._setup_event_handlers()
- def _composite_values_from_instance(self, value):
+ def _composite_values_from_instance(
+ self, value: _CompositeClassProto
+ ) -> Tuple[Any, ...]:
if self._generated_composite_accessor:
return self._generated_composite_accessor(value)
else:
from __future__ import annotations
+from typing import Any
+from typing import Optional
+from typing import Type
+
from .. import exc as sa_exc
from .. import util
from ..exc import MultipleResultsFound # noqa
"""An mapping operation was requested for an unknown instance."""
@util.preload_module("sqlalchemy.orm.base")
- def __init__(self, obj, msg=None):
+ def __init__(self, obj: object, msg: Optional[str] = None):
base = util.preloaded.orm_base
if not msg:
"was called." % (name, name)
)
except UnmappedClassError:
- msg = _default_unmapped(type(obj))
+ msg = f"Class '{_safe_cls_name(type(obj))}' is not mapped"
if isinstance(obj, type):
msg += (
"; was a class (%s) supplied where an instance was "
class UnmappedClassError(UnmappedError):
"""An mapping operation was requested for an unknown class."""
- def __init__(self, cls, msg=None):
+ def __init__(self, cls: Type[object], msg: Optional[str] = None):
if not msg:
msg = _default_unmapped(cls)
UnmappedError.__init__(self, msg)
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return self.__class__, (None, self.args[0])
@util.preload_module("sqlalchemy.orm.base")
-def _default_unmapped(cls):
+def _default_unmapped(cls) -> Optional[str]:
base = util.preloaded.orm_base
try:
name = _safe_cls_name(cls)
if not mappers:
- return "Class '%s' is not mapped" % name
+ return f"Class '{name}' is not mapped"
+ else:
+ return None
from __future__ import annotations
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
import weakref
from . import util as orm_util
from .. import exc as sa_exc
+if TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from .state import InstanceState
+
+
+_T = TypeVar("_T", bound=Any)
+
+_O = TypeVar("_O", bound=object)
+
class IdentityMap:
- def __init__(self):
+ _wr: weakref.ref[IdentityMap]
+
+ _dict: Dict[_IdentityKeyType[Any], Any]
+ _modified: Set[InstanceState[Any]]
+
+ def __init__(self) -> None:
self._dict = {}
self._modified = set()
self._wr = weakref.ref(self)
- def _kill(self):
- self._add_unpresent = _killed
+ def _kill(self) -> None:
+ self._add_unpresent = _killed # type: ignore
+
+ def all_states(self) -> List[InstanceState[Any]]:
+ raise NotImplementedError()
+
+ def contains_state(self, state: InstanceState[Any]) -> bool:
+ raise NotImplementedError()
+
+ def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
+ raise NotImplementedError()
+
+ def safe_discard(self, state: InstanceState[Any]) -> None:
+ raise NotImplementedError()
+
+ def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
+ raise NotImplementedError()
+
+ def get(
+ self, key: _IdentityKeyType[_O], default: Optional[_O] = None
+ ) -> Optional[_O]:
+ raise NotImplementedError()
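+
+ # Illustrative note: identity-map keys are (class, primary-key tuple,
+ # identity_token) tuples as produced by Mapper.identity_key_from_primary_key;
+ # e.g. for a hypothetical User class, self.get((User, (5,), None)) would
+ # return the persistent User with primary key 5, or None if not present.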
def keys(self):
return self._dict.keys()
- def replace(self, state):
+ def values(self) -> Iterable[object]:
+ raise NotImplementedError()
+
+ def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]:
+ raise NotImplementedError()
+
+ def add(self, state: InstanceState[Any]) -> bool:
raise NotImplementedError()
- def add(self, state):
+ def _fast_discard(self, state: InstanceState[Any]) -> None:
raise NotImplementedError()
- def _add_unpresent(self, state, key):
+ def _add_unpresent(
+ self, state: InstanceState[Any], key: _IdentityKeyType[Any]
+ ) -> None:
"""optional inlined form of add() which can assume item isn't present
in the map"""
self.add(state)
- def update(self, dict_):
- raise NotImplementedError("IdentityMap uses add() to insert data")
-
- def clear(self):
- raise NotImplementedError("IdentityMap uses remove() to remove data")
-
- def _manage_incoming_state(self, state):
+ def _manage_incoming_state(self, state: InstanceState[Any]) -> None:
state._instance_dict = self._wr
if state.modified:
self._modified.add(state)
- def _manage_removed_state(self, state):
+ def _manage_removed_state(self, state: InstanceState[Any]) -> None:
del state._instance_dict
if state.modified:
self._modified.discard(state)
- def _dirty_states(self):
+ def _dirty_states(self) -> Set[InstanceState[Any]]:
return self._modified
- def check_modified(self):
+ def check_modified(self) -> bool:
"""return True if any InstanceStates present have been marked
as 'modified'.
"""
return bool(self._modified)
- def has_key(self, key):
+ def has_key(self, key: _IdentityKeyType[Any]) -> bool:
return key in self
- def popitem(self):
- raise NotImplementedError("IdentityMap uses remove() to remove data")
-
- def pop(self, key, *args):
- raise NotImplementedError("IdentityMap uses remove() to remove data")
-
- def setdefault(self, key, default=None):
- raise NotImplementedError("IdentityMap uses add() to insert data")
-
- def __len__(self):
+ def __len__(self) -> int:
return len(self._dict)
- def copy(self):
- raise NotImplementedError()
-
- def __setitem__(self, key, value):
- raise NotImplementedError("IdentityMap uses add() to insert data")
-
- def __delitem__(self, key):
- raise NotImplementedError("IdentityMap uses remove() to remove data")
-
class WeakInstanceDict(IdentityMap):
- def __getitem__(self, key):
+ _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]]
+
+ def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
state = self._dict[key]
o = state.obj()
if o is None:
raise KeyError(key)
return o
- def __contains__(self, key):
+ def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
try:
if key in self._dict:
state = self._dict[key]
else:
return o is not None
- def contains_state(self, state):
+ def contains_state(self, state: InstanceState[Any]) -> bool:
if state.key in self._dict:
try:
return self._dict[state.key] is state
else:
return False
- def replace(self, state):
+ def replace(
+ self, state: InstanceState[Any]
+ ) -> Optional[InstanceState[Any]]:
if state.key in self._dict:
try:
existing = self._dict[state.key]
except KeyError:
# catch gc removed the key after we just checked for it
- pass
+ existing = None
else:
if existing is not state:
self._manage_removed_state(existing)
self._manage_incoming_state(state)
return existing
- def add(self, state):
+ def add(self, state: InstanceState[Any]) -> bool:
key = state.key
# inline of self.__contains__
if key in self._dict:
self._manage_incoming_state(state)
return True
- def _add_unpresent(self, state, key):
+ def _add_unpresent(
+ self, state: InstanceState[Any], key: _IdentityKeyType[Any]
+ ) -> None:
# inlined form of add() called by loading.py
self._dict[key] = state
state._instance_dict = self._wr
- def get(self, key, default=None):
+ def get(
+ self, key: _IdentityKeyType[_O], default: Optional[_O] = None
+ ) -> Optional[_O]:
if key not in self._dict:
return default
try:
return default
return o
- def items(self):
+ def items(self) -> List[Tuple[_IdentityKeyType[Any], object]]:
values = self.all_states()
result = []
for state in values:
result.append((state.key, value))
return result
- def values(self):
+ def values(self) -> List[object]:
values = self.all_states()
result = []
for state in values:
return result
- def __iter__(self):
+ def __iter__(self) -> Iterator[_IdentityKeyType[Any]]:
return iter(self.keys())
- def all_states(self):
+ def all_states(self) -> List[InstanceState[Any]]:
return list(self._dict.values())
- def _fast_discard(self, state):
+ def _fast_discard(self, state: InstanceState[Any]) -> None:
# used by InstanceState for state being
# GC'ed, inlines _manage_removed_state
try:
if st is state:
self._dict.pop(state.key, None)
- def discard(self, state):
+ def discard(self, state: InstanceState[Any]) -> None:
self.safe_discard(state)
- def safe_discard(self, state):
+ def safe_discard(self, state: InstanceState[Any]) -> None:
if state.key in self._dict:
try:
st = self._dict[state.key]
self._manage_removed_state(state)
-def _killed(state, key):
+def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn:
# external function to avoid creating cycles when assigned to
# the IdentityMap
raise sa_exc.InvalidRequestError(
from __future__ import annotations
+from typing import Any
+from typing import Dict
+from typing import Generic
+from typing import Set
+from typing import TYPE_CHECKING
+from typing import TypeVar
+
from . import base
from . import collections
from . import exc
from . import interfaces
from . import state
from .. import util
+from ..event import EventTarget
from ..util import HasMemoized
+from ..util.typing import Protocol
+if TYPE_CHECKING:
+ from .attributes import InstrumentedAttribute
+ from .mapper import Mapper
+ from ..event import dispatcher
+_T = TypeVar("_T", bound=Any)
DEL_ATTR = util.symbol("DEL_ATTR")
-class ClassManager(HasMemoized, dict):
+class _ExpiredAttributeLoaderProto(Protocol):
+ def __call__(
+ self,
+ state: state.InstanceState[Any],
+ toload: Set[str],
+ passive: base.PassiveFlag,
+ ):
+ ...
+
+
+class ClassManager(
+ HasMemoized,
+ Dict[str, "InstrumentedAttribute[Any]"],
+ Generic[_T],
+ EventTarget,
+):
"""Tracks state information at the class level."""
+ dispatch: dispatcher[ClassManager]
+
MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
STATE_ATTR = base.DEFAULT_STATE_ATTR
_state_setter = staticmethod(util.attrsetter(STATE_ATTR))
- expired_attribute_loader = None
+ expired_attribute_loader: _ExpiredAttributeLoaderProto
"previously known as deferred_scalar_loader"
init_method = None
factory = None
- mapper = None
+
declarative_scan = None
registry = None
return frozenset([attr.impl for attr in self.values()])
@util.memoized_property
- def mapper(self):
+ def mapper(self) -> Mapper[_T]:
# raises unless self.mapper has been assigned
raise exc.UnmappedClassError(self.class_)
def teardown_instance(self, instance):
delattr(instance, self.STATE_ATTR)
- def _serialize(self, state, state_dict):
+ def _serialize(
+ self, state: state.InstanceState, state_dict: Dict[str, Any]
+ ) -> _SerializeManager:
return _SerializeManager(state, state_dict)
def _new_state_if_none(self, instance):
"""
- def __init__(self, state, d):
+ def __init__(self, state: state.InstanceState[Any], d: Dict[str, Any]):
self.class_ = state.class_
manager = state.manager
manager.dispatch.pickle(state, d)
"""
+ _is_core = False
+
+ _is_user_defined = False
+
_is_compile_state = False
_is_criteria_option = False
_is_legacy_option = False
+ _is_user_defined = True
+
propagate_to_loaders = False
"""if True, indicate this option should be carried along
to "secondary" Query objects produced during lazy loads
from __future__ import annotations
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
+from .context import FromStatement
from . import attributes
from . import exc as orm_exc
from . import path_registry
from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
+from .base import PassiveFlag
from .util import _none_set
from .util import state_str
from .. import exc as sa_exc
from ..engine.result import FrozenResult
from ..engine.result import SimpleResultMetaData
from ..sql import util as sql_util
+from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import SelectState
+if TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from .base import LoaderCallableStatus
+ from .context import FromStatement
+ from .interfaces import ORMOption
+ from .mapper import Mapper
+ from .session import Session
+ from .state import InstanceState
+ from ..engine.interfaces import _ExecuteOptions
+ from ..sql import Select
+ from ..sql.base import Executable
+ from ..sql.selectable import ForUpdateArg
+
+_T = TypeVar("_T", bound=Any)
+_O = TypeVar("_O", bound=object)
_new_runid = util.counter()
session.autoflush = autoflush
-def get_from_identity(session, mapper, key, passive):
+def get_from_identity(
+ session: Session,
+ mapper: Mapper[_O],
+ key: _IdentityKeyType[_O],
+ passive: PassiveFlag,
+) -> Union[Optional[_O], LoaderCallableStatus]:
"""Look up the given key in the given session's identity map,
check the object for expired state if found.
def load_on_ident(
- session,
- statement,
- key,
- load_options=None,
- refresh_state=None,
- with_for_update=None,
- only_load_props=None,
- no_autoflush=False,
- bind_arguments=util.EMPTY_DICT,
- execution_options=util.EMPTY_DICT,
+ session: Session,
+ statement: Union[Select, FromStatement],
+ key: Optional[_IdentityKeyType],
+ *,
+ load_options: Optional[Sequence[ORMOption]] = None,
+ refresh_state: Optional[InstanceState[Any]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ only_load_props: Optional[Iterable[str]] = None,
+ no_autoflush: bool = False,
+ bind_arguments: Mapping[str, Any] = util.EMPTY_DICT,
+ execution_options: _ExecuteOptions = util.EMPTY_DICT,
):
"""Load the given identity key from the database."""
if key is not None:
def load_on_pk_identity(
- session,
- statement,
- primary_key_identity,
- load_options=None,
- refresh_state=None,
- with_for_update=None,
- only_load_props=None,
- identity_token=None,
- no_autoflush=False,
- bind_arguments=util.EMPTY_DICT,
- execution_options=util.EMPTY_DICT,
+ session: Session,
+ statement: Union[Select, FromStatement],
+ primary_key_identity: Optional[Tuple[Any, ...]],
+ *,
+ load_options: Optional[Sequence[ORMOption]] = None,
+ refresh_state: Optional[InstanceState[Any]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ only_load_props: Optional[Iterable[str]] = None,
+ identity_token: Optional[Any] = None,
+ no_autoflush: bool = False,
+ bind_arguments: Mapping[str, Any] = util.EMPTY_DICT,
+ execution_options: _ExecuteOptions = util.EMPTY_DICT,
):
"""Load the given primary key identity from the database."""
import sys
import threading
from typing import Any
+from typing import Callable
from typing import Generic
+from typing import Iterator
+from typing import Optional
+from typing import Tuple
from typing import Type
-from typing import TypeVar
+from typing import TYPE_CHECKING
import weakref
from . import attributes
from . import loading
from . import properties
from . import util as orm_util
+from ._typing import _O
from .base import _class_to_mapper
from .base import _state_mapper
from .base import class_mapper
+from .base import PassiveFlag
from .base import state_str
from .interfaces import _MappedAttribute
from .interfaces import EXT_SKIP
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import HasMemoized
-_mapper_registries = weakref.WeakKeyDictionary()
-
+if TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
+ from .instrumentation import ClassManager
+ from .state import InstanceState
+ from ..sql.elements import ColumnElement
+ from ..sql.schema import Column
-_MC = TypeVar("_MC")
+_mapper_registries = weakref.WeakKeyDictionary()
def _all_registries():
sql_base.MemoizedHasCacheKey,
InspectionAttr,
log.Identified,
- Generic[_MC],
+ Generic[_O],
):
"""Defines an association between a Python class and a database table or
other relational structure, so that ORM operations against the class may
_dispose_called = False
_ready_for_configure = False
- class_: Type[_MC]
+ class_: Type[_O]
"""The class to which this :class:`_orm.Mapper` is mapped."""
+ _identity_class: Type[_O]
+
+ always_refresh: bool
+
@util.deprecated_params(
non_primary=(
"1.3",
)
def __init__(
self,
- class_: Type[_MC],
+ class_: Type[_O],
local_table=None,
properties=None,
primary_key=None,
"""
- primary_key = None
+ primary_key: Tuple[Column[Any], ...]
"""An iterable containing the collection of :class:`_schema.Column`
objects
which comprise the 'primary key' of the mapped table, from the
"""
- class_ = None
+ class_: Type[_O]
"""The Python class which this :class:`_orm.Mapper` maps.
This is a *read only* attribute determined during mapper construction.
"""
- class_manager = None
+ class_manager: ClassManager[_O]
"""The :class:`.ClassManager` which maintains event listeners
and class-bound descriptors for this :class:`_orm.Mapper`.
else self.persist_selectable.description,
)
- def _is_orphan(self, state):
+ def _is_orphan(self, state: InstanceState[_O]) -> bool:
orphan_possible = False
for mapper in self.iterate_to_root():
for (key, cls) in mapper._delete_orphans:
identity_token,
)
- def identity_key_from_primary_key(self, primary_key, identity_token=None):
+ def identity_key_from_primary_key(
+ self,
+ primary_key: Tuple[Any, ...],
+ identity_token: Optional[Any] = None,
+ ) -> _IdentityKeyType[_O]:
"""Return an identity-map key for use in storing/retrieving an
item from an identity map.
:param primary_key: A list of values indicating the identifier.
"""
- return self._identity_class, tuple(primary_key), identity_token
+ return (
+ self._identity_class,
+ tuple(primary_key),
+ identity_token,
+ )
- def identity_key_from_instance(self, instance):
+ def identity_key_from_instance(self, instance: _O) -> _IdentityKeyType[_O]:
"""Return the identity key for the given instance, based on
its primary key attributes.
return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
def _identity_key_from_state(
- self, state, passive=attributes.PASSIVE_RETURN_NO_VALUE
- ):
+ self,
+ state: InstanceState[_O],
+ passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE,
+ ) -> _IdentityKeyType[_O]:
dict_ = state.dict
manager = state.manager
return (
state.identity_token,
)
- def primary_key_from_instance(self, instance):
+ def primary_key_from_instance(self, instance: _O) -> Tuple[Any, ...]:
"""Return the list of primary key values for the given
instance.
return {self._columntoproperty[col].key for col in self._all_pk_cols}
def _get_state_attr_by_column(
- self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
- ):
+ self,
+ state: InstanceState[_O],
+ dict_: _InstanceDict,
+ column: Column[Any],
+ passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE,
+ ) -> Any:
prop = self._columntoproperty[column]
return state.manager[prop.key].impl.get(state, dict_, passive=passive)
def _subclass_load_via_in_mapper(self):
return self._subclass_load_via_in(self)
- def cascade_iterator(self, type_, state, halt_on=None):
+ def cascade_iterator(
+ self,
+ type_: str,
+ state: InstanceState[_O],
+ halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+ ) -> Iterator[
+ Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict]
+ ]:
r"""Iterate each element and its mapper in an object graph,
for all relationships that meet the given cascade rule.
from itertools import chain
import logging
from typing import Any
+from typing import Sequence
from typing import Tuple
from typing import Union
p = p[0:-1]
return p
- def serialize(self):
+ def serialize(self) -> Sequence[Any]:
path = self.path
return self._serialize_path(path)
@classmethod
- def deserialize(cls, path: Tuple) -> "PathRegistry":
+ def deserialize(cls, path: Sequence[Any]) -> PathRegistry:
assert path is not None
p = cls._deserialize_path(path)
return cls.coerce(p)
from itertools import groupby
from itertools import zip_longest
import operator
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from . import attributes
from . import evaluator
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+if TYPE_CHECKING:
+ from .mapper import Mapper
+ from .session import SessionTransaction
+ from .state import InstanceState
+
+_O = TypeVar("_O", bound=object)
+
def _bulk_insert(
- mapper,
- mappings,
- session_transaction,
- isstates,
- return_defaults,
- render_nulls,
-):
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+) -> None:
base_mapper = mapper.base_mapper
if session_transaction.session.connection_callable:
def _bulk_update(
- mapper, mappings, session_transaction, isstates, update_changed_only
-):
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+) -> None:
base_mapper = mapper.base_mapper
search_keys = mapper._primary_key_propkeys
from __future__ import annotations
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
from . import exc as orm_exc
from .base import class_mapper
from .session import Session
from ..util import ThreadLocalRegistry
from ..util import warn
from ..util import warn_deprecated
+from ..util.typing import Protocol
+
+if TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from .identity import IdentityMap
+ from .interfaces import ORMOption
+ from .mapper import Mapper
+ from .query import Query
+ from .session import _EntityBindKey
+ from .session import _PKIdentityArgument
+ from .session import _SessionBind
+ from .session import sessionmaker
+ from .session import SessionTransaction
+ from ..engine import Connection
+ from ..engine import Engine
+ from ..engine import Result
+ from ..engine import Row
+ from ..engine.interfaces import _CoreAnyExecuteParams
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..engine.result import ScalarResult
+ from ..sql._typing import _ColumnsClauseArgument
+ from ..sql.base import Executable
+ from ..sql.elements import ClauseElement
+ from ..sql.selectable import ForUpdateArg
+
+
+class _QueryDescriptorType(Protocol):
+ def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]:
+ ...
+
+
+_O = TypeVar("_O", bound=object)
__all__ = ["scoped_session", "ScopedSessionMixin"]
class ScopedSessionMixin:
+ session_factory: sessionmaker
+ _support_async: bool
+ registry: ScopedRegistry[Session]
+
@property
- def _proxied(self):
- return self.registry()
+ def _proxied(self) -> Session:
+ return self.registry() # type: ignore
- def __call__(self, **kw):
+ def __call__(self, **kw: Any) -> Session:
r"""Return the current :class:`.Session`, creating it
using the :attr:`.scoped_session.session_factory` if not present.
)
return sess
- def configure(self, **kwargs):
+ def configure(self, **kwargs: Any) -> None:
"""reconfigure the :class:`.sessionmaker` used by this
:class:`.scoped_session`.
"autoflush",
"no_autoflush",
"info",
- "autocommit",
],
)
class scoped_session(ScopedSessionMixin):
"""
- _support_async = False
+ _support_async: bool = False
- session_factory = None
+ session_factory: sessionmaker
"""The `session_factory` provided to `__init__` is stored in this
attribute and may be accessed at a later time. This can be useful when
a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the
database is needed."""
- def __init__(self, session_factory, scopefunc=None):
+ def __init__(
+ self,
+ session_factory: sessionmaker,
+ scopefunc: Optional[Callable[[], Any]] = None,
+ ):
"""Construct a new :class:`.scoped_session`.
:param session_factory: a factory to create new :class:`.Session`
else:
self.registry = ThreadLocalRegistry(session_factory)
- def remove(self):
+ def remove(self) -> None:
"""Dispose of the current :class:`.Session`, if present.
This will first call :meth:`.Session.close` method
self.registry().close()
self.registry.clear()
- def query_property(self, query_cls=None):
+ def query_property(
+ self, query_cls: Optional[Type[Query[Any]]] = None
+ ) -> _QueryDescriptorType:
"""return a class property which produces a :class:`_query.Query`
object
against the class and the current :class:`.Session` when called.
"""
class query:
- def __get__(s, instance, owner):
+ def __get__(
+ s, instance: Any, owner: Type[Any]
+ ) -> Optional[Query[Any]]:
try:
mapper = class_mapper(owner)
- if mapper:
- if query_cls:
- # custom query class
- return query_cls(mapper, session=self.registry())
- else:
- # session's configured query class
- return self.registry().query(mapper)
+ assert mapper is not None
+ if query_cls:
+ # custom query class
+ return query_cls(mapper, session=self.registry())
+ else:
+ # session's configured query class
+ return self.registry().query(mapper)
except orm_exc.UnmappedClassError:
return None
# code within this block is **programmatically,
# statically generated** by tools/generate_proxy_methods.py
- def __contains__(self, instance):
+ def __contains__(self, instance: object) -> bool:
r"""Return True if the instance is associated with this session.
.. container:: class_bases
return self._proxied.__contains__(instance)
- def __iter__(self):
+ def __iter__(self) -> Iterator[object]:
r"""Iterate over all pending or persistent instances within this
Session.
return self._proxied.__iter__()
- def add(self, instance: Any, _warn: bool = True) -> None:
+ def add(self, instance: object, _warn: bool = True) -> None:
r"""Place an object in the ``Session``.
.. container:: class_bases
return self._proxied.add(instance, _warn=_warn)
- def add_all(self, instances):
+ def add_all(self, instances: Iterable[object]) -> None:
r"""Add the given collection of instances to this ``Session``.
.. container:: class_bases
return self._proxied.add_all(instances)
- def begin(self, nested=False, _subtrans=False):
+ def begin(
+ self, nested: bool = False, _subtrans: bool = False
+ ) -> SessionTransaction:
r"""Begin a transaction, or nested transaction,
on this :class:`.Session`, if one is not already begun.
return self._proxied.begin(nested=nested, _subtrans=_subtrans)
- def begin_nested(self):
+ def begin_nested(self) -> SessionTransaction:
r"""Begin a "nested" transaction on this Session, e.g. SAVEPOINT.
.. container:: class_bases
return self._proxied.begin_nested()
- def close(self):
+ def close(self) -> None:
r"""Close out the transactional resources and ORM objects used by this
:class:`_orm.Session`.
def connection(
self,
bind_arguments: Optional[Dict[str, Any]] = None,
- execution_options: Optional["_ExecuteOptions"] = None,
+ execution_options: Optional[_ExecuteOptions] = None,
) -> "Connection":
r"""Return a :class:`_engine.Connection` object corresponding to this
:class:`.Session` object's transactional state.
bind_arguments=bind_arguments, execution_options=execution_options
)
- def delete(self, instance):
+ def delete(self, instance: object) -> None:
r"""Mark an instance as deleted.
.. container:: class_bases
def execute(
self,
- statement: "Executable",
- params: Optional["_ExecuteParams"] = None,
- execution_options: "_ExecuteOptions" = util.EMPTY_DICT,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ):
+ ) -> Result:
r"""Execute a SQL expression construct.
.. container:: class_bases
_add_event=_add_event,
)
- def expire(self, instance, attribute_names=None):
+ def expire(
+ self, instance: object, attribute_names: Optional[Iterable[str]] = None
+ ) -> None:
r"""Expire the attributes on an instance.
.. container:: class_bases
return self._proxied.expire(instance, attribute_names=attribute_names)
- def expire_all(self):
+ def expire_all(self) -> None:
r"""Expires all persistent instances within this Session.
.. container:: class_bases
return self._proxied.expire_all()
- def expunge(self, instance):
+ def expunge(self, instance: object) -> None:
r"""Remove the `instance` from this ``Session``.
.. container:: class_bases
return self._proxied.expunge(instance)
- def expunge_all(self):
+ def expunge_all(self) -> None:
r"""Remove all object instances from this ``Session``.
.. container:: class_bases
return self._proxied.expunge_all()
- def flush(self, objects=None):
+ def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
r"""Flush all the object changes to the database.
.. container:: class_bases
def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- execution_options=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
r"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
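A minimal sketch of the ``get()`` behavior described here, assuming a hypothetical ``User`` model and an in-memory SQLite engine (not part of this change)::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(id=5, name="ed"))
        session.commit()

        # checks the identity map first, emitting a SELECT by primary key
        # only when the object is not already present
        assert session.get(User, 5) is not None
        assert session.get(User, -1) is None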
def get_bind(
self,
- mapper=None,
- clause=None,
- bind=None,
- _sa_skip_events=None,
- _sa_skip_for_implicit_returning=False,
- ):
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ _sa_skip_events: Optional[bool] = None,
+ _sa_skip_for_implicit_returning: bool = False,
+ ) -> Union[Engine, Connection]:
r"""Return a "bind" to which this :class:`.Session` is bound.
.. container:: class_bases
_sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning,
)
- def is_modified(self, instance, include_collections=True):
+ def is_modified(
+ self, instance: object, include_collections: bool = True
+ ) -> bool:
r"""Return ``True`` if the given instance has locally
modified attributes.
def bulk_save_objects(
self,
- objects,
- return_defaults=False,
- update_changed_only=True,
- preserve_order=True,
- ):
+ objects: Iterable[object],
+ return_defaults: bool = False,
+ update_changed_only: bool = True,
+ preserve_order: bool = True,
+ ) -> None:
r"""Perform a bulk save of the given list of objects.
.. container:: class_bases
)
def bulk_insert_mappings(
- self, mapper, mappings, return_defaults=False, render_nulls=False
- ):
+ self,
+ mapper: Mapper[Any],
+ mappings: Iterable[Dict[str, Any]],
+ return_defaults: bool = False,
+ render_nulls: bool = False,
+ ) -> None:
r"""Perform a bulk insert of the given list of mapping dictionaries.
.. container:: class_bases
render_nulls=render_nulls,
)
- def bulk_update_mappings(self, mapper, mappings):
+ def bulk_update_mappings(
+ self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]]
+ ) -> None:
r"""Perform a bulk update of the given list of mapping dictionaries.
.. container:: class_bases
return self._proxied.bulk_update_mappings(mapper, mappings)
- def merge(self, instance, load=True, options=None):
+ def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
r"""Copy the state of a given instance into a corresponding instance
within this :class:`.Session`.
return self._proxied.merge(instance, load=load, options=options)
- def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query:
+ def query(
+ self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ ) -> Query[Any]:
r"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
return self._proxied.query(*entities, **kwargs)
- def refresh(self, instance, attribute_names=None, with_for_update=None):
+ def refresh(
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
r"""Expire and refresh attributes on the given instance.
.. container:: class_bases
with_for_update=with_for_update,
)
- def rollback(self):
+ def rollback(self) -> None:
r"""Rollback the current transaction in progress.
.. container:: class_bases
def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ **kw: Any,
+ ) -> Any:
r"""Execute a statement and return a scalar result.
.. container:: class_bases
def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
r"""Execute a statement and return the results as scalars.
.. container:: class_bases
return self._proxied.new
@property
- def identity_map(self) -> identity.IdentityMap:
+ def identity_map(self) -> IdentityMap:
r"""Proxy for the :attr:`_orm.Session.identity_map` attribute
on behalf of the :class:`_orm.scoping.scoped_session` class.
return self._proxied.identity_map
@identity_map.setter
- def identity_map(self, attr: identity.IdentityMap) -> None:
+ def identity_map(self, attr: IdentityMap) -> None:
self._proxied.identity_map = attr
@property
return self._proxied.info
- @property
- def autocommit(self) -> Any:
- r"""Proxy for the :attr:`_orm.Session.autocommit` attribute
- on behalf of the :class:`_orm.scoping.scoped_session` class.
-
- """ # noqa: E501
-
- return self._proxied.autocommit
-
- @autocommit.setter
- def autocommit(self, attr: Any) -> None:
- self._proxied.autocommit = attr
-
@classmethod
def close_all(cls) -> None:
r"""Close *all* sessions in memory.
return Session.close_all()
@classmethod
- def object_session(cls, instance: Any) -> "Session":
+ def object_session(cls, instance: object) -> Optional[Session]:
r"""Return the :class:`.Session` to which an object belongs.
.. container:: class_bases
@classmethod
def identity_key(
cls,
- class_=None,
- ident=None,
+ class_: Optional[Type[Any]] = None,
+ ident: Union[Any, Tuple[Any, ...]] = None,
*,
- instance=None,
- row=None,
- identity_token=None,
- ) -> _IdentityKeyType:
+ instance: Optional[Any] = None,
+ row: Optional[Row] = None,
+ identity_token: Optional[Any] = None,
+ ) -> _IdentityKeyType[Any]:
r"""Return an identity key.
.. container:: class_bases
import sys
import typing
from typing import Any
+from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import NoReturn
from typing import Optional
-from typing import overload
+from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
from . import persistence
from . import query
from . import state as statelib
+from ._typing import is_composite_class
+from ._typing import is_user_defined_option
from .base import _class_to_mapper
-from .base import _IdentityKeyType
from .base import _none_set
from .base import _state_mapper
from .base import instance_str
+from .base import LoaderCallableStatus
from .base import object_mapper
from .base import object_state
+from .base import PassiveFlag
from .base import state_str
+from .context import FromStatement
+from .context import ORMCompileState
+from .identity import IdentityMap
from .query import Query
from .state import InstanceState
from .state_changes import _StateChange
from ..engine import Connection
from ..engine import Engine
from ..engine.util import TransactionalContext
+from ..event import dispatcher
+from ..event import EventTarget
from ..inspection import inspect
from ..sql import coercions
from ..sql import dml
from ..sql import roles
+from ..sql import Select
from ..sql import visitors
from ..sql.base import CompileState
+from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import IdentitySet
from ..util.typing import Literal
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
+ from .interfaces import ORMOption
+ from .interfaces import UserDefinedOption
from .mapper import Mapper
+ from .path_registry import PathRegistry
+ from ..engine import Result
from ..engine import Row
+ from ..engine.base import Transaction
+ from ..engine.base import TwoPhaseTransaction
+ from ..engine.interfaces import _CoreAnyExecuteParams
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..engine.result import ScalarResult
+ from ..event import _InstanceLevelDispatch
from ..sql._typing import _ColumnsClauseArgument
- from ..sql._typing import _ExecuteOptions
- from ..sql._typing import _ExecuteParams
from ..sql.base import Executable
+ from ..sql.elements import ClauseElement
from ..sql.schema import Table
__all__ = [
"object_session",
]
-_sessions = weakref.WeakValueDictionary()
+_sessions: weakref.WeakValueDictionary[
+ int, Session
+] = weakref.WeakValueDictionary()
"""Weak-referencing dictionary of :class:`.Session` objects.
"""
+_O = TypeVar("_O", bound=object)
statelib._sessions = _sessions
+_PKIdentityArgument = Union[Any, Tuple[Any, ...]]
-def _state_session(state):
+_EntityBindKey = Union[Type[_O], "Mapper[_O]"]
+_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"]
+_SessionBind = Union["Engine", "Connection"]
+
+
+class _ConnectionCallableProto(Protocol):
+ """a callable that returns a :class:`.Connection` given an instance.
+
+ This callable, when present on a :class:`.Session`, is called only from the
+ ORM's persistence mechanism (i.e. the unit of work flush process) to allow
+ for connection-per-instance schemes (i.e. horizontal sharding) to be used
+ at persistence time.
+
+ This callable is not present on a plain :class:`.Session`; it is
+ established when using the horizontal sharding extension.
+
+ """
+
+ def __call__(
+ self,
+ mapper: Optional[Mapper[Any]] = None,
+ instance: Optional[object] = None,
+ **kw: Any,
+ ) -> Connection:
+ ...
+
+
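A rough sketch of a callable with the shape this protocol describes; the shard registry and the ``shard_id`` attribute are assumptions for illustration, and the real horizontal sharding extension obtains its connection differently::

    from typing import Any, Dict, Optional

    from sqlalchemy import create_engine
    from sqlalchemy.engine import Connection, Engine
    from sqlalchemy.orm import Mapper

    # assumption: a registry of per-shard engines keyed by a hypothetical
    # "shard_id" attribute on mapped instances
    shards: Dict[str, Engine] = {
        "east": create_engine("sqlite://"),
        "west": create_engine("sqlite://"),
    }


    def connection_for_instance(
        mapper: Optional[Mapper[Any]] = None,
        instance: Optional[object] = None,
        **kw: Any,
    ) -> Connection:
        """Return a Connection chosen from the instance being flushed."""
        shard_id = getattr(instance, "shard_id", "east")
        return shards[shard_id].connect()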
+def _state_session(state: InstanceState[Any]) -> Optional[Session]:
"""Given an :class:`.InstanceState`, return the :class:`.Session`
associated, if any.
"""
close_all_sessions()
- @classmethod
- @overload
- def identity_key(
- cls,
- class_: type,
- ident: Tuple[Any, ...],
- *,
- identity_token: Optional[str],
- ) -> _IdentityKeyType:
- ...
-
- @classmethod
- @overload
- def identity_key(cls, *, instance: Any) -> _IdentityKeyType:
- ...
-
- @classmethod
- @overload
- def identity_key(
- cls, class_: type, *, row: "Row", identity_token: Optional[str]
- ) -> _IdentityKeyType:
- ...
-
@classmethod
@util.preload_module("sqlalchemy.orm.util")
def identity_key(
cls,
- class_=None,
- ident=None,
+ class_: Optional[Type[Any]] = None,
+ ident: Union[Any, Tuple[Any, ...]] = None,
*,
- instance=None,
- row=None,
- identity_token=None,
- ) -> _IdentityKeyType:
+ instance: Optional[Any] = None,
+ row: Optional[Row] = None,
+ identity_token: Optional[Any] = None,
+ ) -> _IdentityKeyType[Any]:
"""Return an identity key.
This is an alias of :func:`.util.identity_key`.
)
@classmethod
- def object_session(cls, instance: Any) -> "Session":
+ def object_session(cls, instance: object) -> Optional[Session]:
"""Return the :class:`.Session` to which an object belongs.
This is an alias of :func:`.object_session`.
"_update_execution_options",
)
- session: "Session"
- statement: "Executable"
- parameters: "_ExecuteParams"
- execution_options: "_ExecuteOptions"
- local_execution_options: "_ExecuteOptions"
+ session: Session
+ statement: Executable
+ parameters: Optional[_CoreAnyExecuteParams]
+ execution_options: _ExecuteOptions
+ local_execution_options: _ExecuteOptions
bind_arguments: Dict[str, Any]
- _compile_state_cls: Type[context.ORMCompileState]
- _starting_event_idx: Optional[int]
+ _compile_state_cls: Optional[Type[ORMCompileState]]
+ _starting_event_idx: int
_events_todo: List[Any]
- _update_execution_options: Optional["_ExecuteOptions"]
+ _update_execution_options: Optional[_ExecuteOptions]
def __init__(
self,
- session: "Session",
- statement: "Executable",
- parameters: "_ExecuteParams",
- execution_options: "_ExecuteOptions",
+ session: Session,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams],
+ execution_options: _ExecuteOptions,
bind_arguments: Dict[str, Any],
- compile_state_cls: Type[context.ORMCompileState],
- events_todo: List[Any],
+ compile_state_cls: Optional[Type[ORMCompileState]],
+ events_todo: List[_InstanceLevelDispatch[Session]],
):
self.session = session
self.statement = statement
self._compile_state_cls = compile_state_cls
self._events_todo = list(events_todo)
- def _remaining_events(self):
+ def _remaining_events(self) -> List[_InstanceLevelDispatch[Session]]:
return self._events_todo[self._starting_event_idx + 1 :]
def invoke_statement(
self,
- statement=None,
- params=None,
- execution_options=None,
- bind_arguments=None,
- ):
+ statement: Optional[Executable] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ ) -> Result:
"""Execute the statement represented by this
:class:`.ORMExecuteState`, without re-invoking events that have
already proceeded.
:param statement: optional statement to be invoked, in place of the
statement currently represented by :attr:`.ORMExecuteState.statement`.
- :param params: optional dictionary of parameters which will be merged
- into the existing :attr:`.ORMExecuteState.parameters` of this
- :class:`.ORMExecuteState`.
+ :param params: optional dictionary of parameters or list of parameters
+ which will be merged into the existing
+ :attr:`.ORMExecuteState.parameters` of this :class:`.ORMExecuteState`.
+
+ .. versionchanged:: 2.0 a list of parameter dictionaries is accepted
+ for executemany executions.
:param execution_options: optional dictionary of execution options
will be merged into the existing
_bind_arguments.update(bind_arguments)
_bind_arguments["_sa_skip_events"] = True
+ _params: Optional[_CoreAnyExecuteParams]
if params:
- _params = dict(self.parameters)
- _params.update(params)
+ if self.is_executemany:
+ _params = []
+ exec_many_parameters = cast(
+ "List[Dict[str, Any]]", self.parameters
+ )
+ for _existing_params, _new_params in itertools.zip_longest(
+ exec_many_parameters,
+ cast("List[Dict[str, Any]]", params),
+ ):
+ if _existing_params is None or _new_params is None:
+ raise sa_exc.InvalidRequestError(
+ f"Can't apply executemany parameters to "
+ f"statement; number of parameter sets passed to "
+ f"Session.execute() ({len(exec_many_parameters)}) "
+ f"does not match number of parameter sets given "
+ f"to ORMExecuteState.invoke_statement() "
+ f"({len(params)})"
+ )
+ _existing_params = dict(_existing_params)
+ _existing_params.update(_new_params)
+ _params.append(_existing_params)
+ else:
+ _params = dict(cast("Dict[str, Any]", self.parameters))
+ _params.update(cast("Dict[str, Any]", params))
else:
_params = self.parameters
)
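A hedged sketch of how ``invoke_statement()`` is typically used from a ``do_orm_execute`` event handler; the ``populate_existing`` option and the select-only filter are arbitrary choices for illustration::

    from typing import Optional

    from sqlalchemy import event
    from sqlalchemy.engine import Result
    from sqlalchemy.orm import ORMExecuteState, Session


    @event.listens_for(Session, "do_orm_execute")
    def _reexecute_selects(state: ORMExecuteState) -> Optional[Result]:
        # re-invoke the current statement with extra execution options
        # merged in; events that already proceeded are not re-invoked,
        # and returning the Result pre-empts the default execution
        if state.is_select:
            return state.invoke_statement(
                execution_options={"populate_existing": True}
            )
        return None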
@property
- def bind_mapper(self):
+ def bind_mapper(self) -> Optional[Mapper[Any]]:
"""Return the :class:`_orm.Mapper` that is the primary "bind" mapper.
For an :class:`_orm.ORMExecuteState` object invoking an ORM
return self.bind_arguments.get("mapper", None)
@property
- def all_mappers(self):
+ def all_mappers(self) -> Sequence[Mapper[Any]]:
"""Return a sequence of all :class:`_orm.Mapper` objects that are
involved at the top level of this statement.
"""
if not self.is_orm_statement:
return []
- elif self.is_select:
+ elif isinstance(self.statement, (Select, FromStatement)):
result = []
seen = set()
for d in self.statement.column_descriptions:
seen.add(insp.mapper)
result.append(insp.mapper)
return result
- elif self.is_update or self.is_delete:
+ elif self.statement.is_dml and self.bind_mapper:
return [self.bind_mapper]
else:
return []
@property
- def is_orm_statement(self):
+ def is_orm_statement(self) -> bool:
"""return True if the operation is an ORM statement.
This indicates that the select(), update(), or delete() being
return self._compile_state_cls is not None
@property
- def is_select(self):
+ def is_executemany(self) -> bool:
+ """return True if the parameters are a multi-element list of
+ dictionaries with more than one dictionary.
+
+ .. versionadded:: 2.0
+
+ """
+ return isinstance(self.parameters, list)
+
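To illustrate the distinction (a minimal sketch; the Core table and SQLite URL are assumptions)::

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
        insert,
    )
    from sqlalchemy.orm import Session

    metadata = MetaData()
    user_table = Table(
        "user",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )
    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with Session(engine) as session:
        # single dictionary: ORMExecuteState.parameters is a dict,
        # is_executemany is False
        session.execute(insert(user_table), {"id": 1, "name": "a"})

        # list of dictionaries: parameters is a list, is_executemany is True
        session.execute(
            insert(user_table),
            [{"id": 2, "name": "b"}, {"id": 3, "name": "c"}],
        )
        session.commit()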
+ @property
+ def is_select(self) -> bool:
"""return True if this is a SELECT operation."""
return self.statement.is_select
@property
- def is_insert(self):
+ def is_insert(self) -> bool:
"""return True if this is an INSERT operation."""
return self.statement.is_dml and self.statement.is_insert
@property
- def is_update(self):
+ def is_update(self) -> bool:
"""return True if this is an UPDATE operation."""
return self.statement.is_dml and self.statement.is_update
@property
- def is_delete(self):
+ def is_delete(self) -> bool:
"""return True if this is a DELETE operation."""
return self.statement.is_dml and self.statement.is_delete
@property
- def _is_crud(self):
+ def _is_crud(self) -> bool:
return isinstance(self.statement, (dml.Update, dml.Delete))
- def update_execution_options(self, **opts):
+ def update_execution_options(self, **opts: Any) -> None:
+ """Update the local execution options with new values."""
# TODO: no coverage
self.local_execution_options = self.local_execution_options.union(opts)
- def _orm_compile_options(self):
+ def _orm_compile_options(
+ self,
+ ) -> Optional[
+ Union[
+ context.ORMCompileState.default_compile_options,
+ Type[context.ORMCompileState.default_compile_options],
+ ]
+ ]:
if not self.is_select:
return None
opts = self.statement._compile_options
- if opts.isinstance(context.ORMCompileState.default_compile_options):
- return opts
+ if opts is not None and opts.isinstance(
+ context.ORMCompileState.default_compile_options
+ ):
+ return opts # type: ignore
else:
return None
@property
- def lazy_loaded_from(self):
+ def lazy_loaded_from(self) -> Optional[InstanceState[Any]]:
"""An :class:`.InstanceState` that is using this statement execution
for a lazy load operation.
return self.load_options._lazy_loaded_from
@property
- def loader_strategy_path(self):
+ def loader_strategy_path(self) -> Optional[PathRegistry]:
"""Return the :class:`.PathRegistry` for the current load path.
This object represents the "path" in a query along relationships
return None
@property
- def is_column_load(self):
+ def is_column_load(self) -> bool:
"""Return True if the operation is refreshing column-oriented
attributes on an existing ORM object.
return opts is not None and opts._for_refresh_state
@property
- def is_relationship_load(self):
+ def is_relationship_load(self) -> bool:
"""Return True if this load is loading objects on behalf of a
relationship.
return path is not None and not path.is_root
@property
- def load_options(self):
+ def load_options(
+ self,
+ ) -> Union[
+ context.QueryContext.default_load_options,
+ Type[context.QueryContext.default_load_options],
+ ]:
"""Return the load_options that will be used for this execution."""
if not self.is_select:
)
@property
- def update_delete_options(self):
+ def update_delete_options(
+ self,
+ ) -> Union[
+ persistence.BulkUDCompileState.default_update_options,
+ Type[persistence.BulkUDCompileState.default_update_options],
+ ]:
"""Return the update_delete_options that will be used for this
execution."""
)
@property
- def user_defined_options(self):
+ def user_defined_options(self) -> Sequence[UserDefinedOption]:
"""The sequence of :class:`.UserDefinedOptions` that have been
associated with the statement being invoked.
return [
opt
for opt in self.statement._with_options
- if not opt._is_compile_state and not opt._is_legacy_option
+ if is_user_defined_option(opt)
]
"""
- _rollback_exception = None
+ _rollback_exception: Optional[BaseException] = None
+
+ _connections: Dict[
+ Union[Engine, Connection], Tuple[Connection, Transaction, bool, bool]
+ ]
+ session: Session
+ _parent: Optional[SessionTransaction]
+
+ _state: SessionTransactionState
+
+ _new: weakref.WeakKeyDictionary[InstanceState[Any], object]
+ _deleted: weakref.WeakKeyDictionary[InstanceState[Any], object]
+ _dirty: weakref.WeakKeyDictionary[InstanceState[Any], object]
+ _key_switches: weakref.WeakKeyDictionary[
+ InstanceState[Any], Tuple[Any, Any]
+ ]
def __init__(
self,
- session,
- parent=None,
- nested=False,
- autobegin=False,
+ session: Session,
+ parent: Optional[SessionTransaction] = None,
+ nested: bool = False,
+ autobegin: bool = False,
):
TransactionalContext._trans_ctx_check(session)
self.session.dispatch.after_transaction_create(self.session, self)
- def _raise_for_prerequisite_state(self, operation_name, state):
+ def _raise_for_prerequisite_state(
+ self, operation_name: str, state: SessionTransactionState
+ ) -> NoReturn:
if state is SessionTransactionState.DEACTIVE:
if self._rollback_exception:
raise sa_exc.PendingRollbackError(
)
@property
- def parent(self):
+ def parent(self) -> Optional[SessionTransaction]:
"""The parent :class:`.SessionTransaction` of this
:class:`.SessionTransaction`.
"""
return self._parent
- nested = False
+ nested: bool = False
"""Indicates if this is a nested, or SAVEPOINT, transaction.
When :attr:`.SessionTransaction.nested` is True, it is expected
"""
@property
- def is_active(self):
+ def is_active(self) -> bool:
return (
self.session is not None
and self._state is SessionTransactionState.ACTIVE
)
@property
- def _is_transaction_boundary(self):
+ def _is_transaction_boundary(self) -> bool:
return self.nested or not self._parent
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
)
- def connection(self, bindkey, execution_options=None, **kwargs):
+ def connection(
+ self,
+ bindkey: Optional[Mapper[Any]],
+ execution_options: Optional[_ExecuteOptions] = None,
+ **kwargs: Any,
+ ) -> Connection:
bind = self.session.get_bind(bindkey, **kwargs)
return self._connection_for_bind(bind, execution_options)
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
)
- def _begin(self, nested=False):
+ def _begin(self, nested: bool = False) -> SessionTransaction:
return SessionTransaction(self.session, self, nested=nested)
- def _iterate_self_and_parents(self, upto=None):
+ def _iterate_self_and_parents(
+ self, upto: Optional[SessionTransaction] = None
+ ) -> Iterable[SessionTransaction]:
current = self
- result = ()
+ result: Tuple[SessionTransaction, ...] = ()
while current:
result += (current,)
if current._parent is upto:
return result
- def _take_snapshot(self, autobegin=False):
+ def _take_snapshot(self, autobegin: bool = False) -> None:
if not self._is_transaction_boundary:
- self._new = self._parent._new
- self._deleted = self._parent._deleted
- self._dirty = self._parent._dirty
- self._key_switches = self._parent._key_switches
+ parent = self._parent
+ assert parent is not None
+ self._new = parent._new
+ self._deleted = parent._deleted
+ self._dirty = parent._dirty
+ self._key_switches = parent._key_switches
return
if not autobegin and not self.session._flushing:
self._dirty = weakref.WeakKeyDictionary()
self._key_switches = weakref.WeakKeyDictionary()
- def _restore_snapshot(self, dirty_only=False):
+ def _restore_snapshot(self, dirty_only: bool = False) -> None:
"""Restore the restoration state taken before a transaction began.
Corresponds to a rollback.
if not dirty_only or s.modified or s in self._dirty:
s._expire(s.dict, self.session.identity_map._modified)
- def _remove_snapshot(self):
+ def _remove_snapshot(self) -> None:
"""Remove the restoration state taken before a transaction began.
Corresponds to a commit.
)
self._deleted.clear()
elif self.nested:
- self._parent._new.update(self._new)
- self._parent._dirty.update(self._dirty)
- self._parent._deleted.update(self._deleted)
- self._parent._key_switches.update(self._key_switches)
+ parent = self._parent
+ assert parent is not None
+ parent._new.update(self._new)
+ parent._dirty.update(self._dirty)
+ parent._deleted.update(self._deleted)
+ parent._key_switches.update(self._key_switches)
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
)
- def _connection_for_bind(self, bind, execution_options):
+ def _connection_for_bind(
+ self,
+ bind: _SessionBind,
+ execution_options: Optional[_ExecuteOptions],
+ ) -> Connection:
if bind in self._connections:
if execution_options:
if execution_options:
conn = conn.execution_options(**execution_options)
+ transaction: Transaction
if self.session.twophase and self._parent is None:
transaction = conn.begin_twophase()
elif self.nested:
# if given a future connection already in a transaction, don't
# commit that transaction unless it is a savepoint
if conn.in_nested_transaction():
- transaction = conn.get_nested_transaction()
+ transaction = conn._get_required_nested_transaction()
else:
- transaction = conn.get_transaction()
+ transaction = conn._get_required_transaction()
should_commit = False
else:
transaction = conn.begin()
self.session.dispatch.after_begin(self.session, self, conn)
return conn
- def prepare(self):
+ def prepare(self) -> None:
if self._parent is not None or not self.session.twophase:
raise sa_exc.InvalidRequestError(
"'twophase' mode not enabled, or not root transaction; "
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED
)
- def _prepare_impl(self):
+ def _prepare_impl(self) -> None:
if self._parent is None or self.nested:
self.session.dispatch.before_commit(self.session)
stx = self.session._transaction
+ assert stx is not None
if stx is not self:
for subtransaction in stx._iterate_self_and_parents(upto=self):
subtransaction.commit()
if self._parent is None and self.session.twophase:
try:
for t in set(self._connections.values()):
- t[1].prepare()
+ cast("TwoPhaseTransaction", t[1]).prepare()
except:
with util.safe_reraise():
self.rollback()
self.close()
if _to_root and self._parent:
- return self._parent.commit(_to_root=True)
-
- return self._parent
+ self._parent.commit(_to_root=True)
@_StateChange.declare_states(
(
),
SessionTransactionState.CLOSED,
)
- def rollback(self, _capture_exception=False, _to_root=False):
+ def rollback(
+ self, _capture_exception: bool = False, _to_root: bool = False
+ ) -> None:
stx = self.session._transaction
+ assert stx is not None
if stx is not self:
for subtransaction in stx._iterate_self_and_parents(upto=self):
subtransaction.close()
if self._parent and _capture_exception:
self._parent._rollback_exception = sys.exc_info()[1]
- if rollback_err:
+ if rollback_err and rollback_err[1]:
raise rollback_err[1].with_traceback(rollback_err[2])
sess.dispatch.after_soft_rollback(sess, self)
if _to_root and self._parent:
- return self._parent.rollback(_to_root=True)
- return self._parent
+ self._parent.rollback(_to_root=True)
@_StateChange.declare_states(
_StateChangeStates.ANY, SessionTransactionState.CLOSED
)
- def close(self, invalidate=False):
+ def close(self, invalidate: bool = False) -> None:
if self.nested:
self.session._nested_transaction = (
self._previous_nested_transaction
self._state = SessionTransactionState.CLOSED
sess = self.session
- self.session = None
- self._connections = None
+ # TODO: these two None sets were historically after the
+ # event hook below, and in 2.0 I changed it this way for some reason,
+ # and I remember there being a reason, but not what it was.
+ # Why do we need to get rid of them at all? test_memusage::CycleTest
+ # passes with these commented out.
+ # self.session = None # type: ignore
+ # self._connections = None # type: ignore
sess.dispatch.after_transaction_end(sess, self)
- def _get_subject(self):
+ def _get_subject(self) -> Session:
return self.session
- def _transaction_is_active(self):
+ def _transaction_is_active(self) -> bool:
return self._state is SessionTransactionState.ACTIVE
- def _transaction_is_closed(self):
+ def _transaction_is_closed(self) -> bool:
return self._state is SessionTransactionState.CLOSED
- def _rollback_can_be_called(self):
+ def _rollback_can_be_called(self) -> bool:
return self._state not in (COMMITTED, CLOSED)
-class Session(_SessionClassMethods):
+class Session(_SessionClassMethods, EventTarget):
"""Manages persistence operations for ORM-mapped objects.
The Session's usage paradigm is described at :doc:`/orm/session`.
_is_asyncio = False
- identity_map: identity.IdentityMap
- _new: Dict["InstanceState", Any]
- _deleted: Dict["InstanceState", Any]
+ dispatch: dispatcher[Session]
+
+ identity_map: IdentityMap
+ """A mapping of object identities to objects themselves.
+
+ Iterating through ``Session.identity_map.values()`` provides
+ access to the full set of persistent objects (i.e., those
+ that have row identity) currently in the session.
+
+ .. seealso::
+
+ :func:`.identity_key` - helper function to produce the keys used
+ in this dictionary.
+
+ """
+
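A brief sketch of the iteration described in this docstring (the ``User`` model and SQLite URL are assumptions)::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(id=1, name="a"))
        session.commit()
        session.get(User, 1)  # load / refresh the row into the identity map

        # every persistent object currently tracked by this session
        for obj in session.identity_map.values():
            print(obj)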
+ _new: Dict[InstanceState[Any], Any]
+ _deleted: Dict[InstanceState[Any], Any]
bind: Optional[Union[Engine, Connection]]
- __binds: Dict[
- Union[type, "Mapper", "Table"],
- Union[engine.Engine, engine.Connection],
- ]
- _flusing: bool
+ __binds: Dict[_SessionBindKey, _SessionBind]
+ _flushing: bool
_warn_on_events: bool
_transaction: Optional[SessionTransaction]
_nested_transaction: Optional[SessionTransaction]
expire_on_commit: bool
enable_baked_queries: bool
twophase: bool
- _query_cls: Type[Query]
+ _query_cls: Type[Query[Any]]
def __init__(
self,
- bind: Optional[Union[engine.Engine, engine.Connection]] = None,
+ bind: Optional[_SessionBind] = None,
autoflush: bool = True,
future: Literal[True] = True,
expire_on_commit: bool = True,
twophase: bool = False,
- binds: Optional[
- Dict[
- Union[type, "Mapper", "Table"],
- Union[engine.Engine, engine.Connection],
- ]
- ] = None,
+ binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None,
enable_baked_queries: bool = True,
info: Optional[Dict[Any, Any]] = None,
- query_cls: Optional[Type[query.Query]] = None,
+ query_cls: Optional[Type[Query[Any]]] = None,
autocommit: Literal[False] = False,
):
r"""Construct a new Session.
_sessions[self.hash_key] = self
# used by sqlalchemy.engine.util.TransactionalContext
- _trans_context_manager = None
+ _trans_context_manager: Optional[TransactionalContext] = None
- connection_callable = None
+ connection_callable: Optional[_ConnectionCallableProto] = None
- def __enter__(self):
+ def __enter__(self) -> Session:
return self
- def __exit__(self, type_, value, traceback):
+ def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
self.close()
@contextlib.contextmanager
- def _maker_context_manager(self):
+ def _maker_context_manager(self) -> Iterator[Session]:
with self:
with self.begin():
yield self
- def in_transaction(self):
+ def in_transaction(self) -> bool:
"""Return True if this :class:`_orm.Session` has begun a transaction.
.. versionadded:: 1.4
"""
return self._transaction is not None
- def in_nested_transaction(self):
+ def in_nested_transaction(self) -> bool:
"""Return True if this :class:`_orm.Session` has begun a nested
transaction, e.g. SAVEPOINT.
"""
return self._nested_transaction is not None
- def get_transaction(self):
+ def get_transaction(self) -> Optional[SessionTransaction]:
"""Return the current root transaction in progress, if any.
.. versionadded:: 1.4
trans = trans._parent
return trans
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[SessionTransaction]:
"""Return the current nested transaction in progress, if any.
.. versionadded:: 1.4
return self._nested_transaction
@util.memoized_property
- def info(self):
+ def info(self) -> Dict[Any, Any]:
"""A user-modifiable dictionary.
The initial value of this dictionary can be populated using the
"""
return {}
- def _autobegin(self):
+ def _autobegin_t(self) -> SessionTransaction:
if self._transaction is None:
trans = SessionTransaction(self, autobegin=True)
assert self._transaction is trans
- return True
+ return trans
- return False
+ return self._transaction
- def begin(self, nested=False, _subtrans=False):
+ def begin(
+ self, nested: bool = False, _subtrans: bool = False
+ ) -> SessionTransaction:
"""Begin a transaction, or nested transaction,
on this :class:`.Session`, if one is not already begun.
"""
- if self._autobegin():
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
+
if not nested and not _subtrans:
- return self._transaction
+ return trans
- if self._transaction is not None:
+ if trans is not None:
if _subtrans or nested:
- trans = self._transaction._begin(nested=nested)
+ trans = trans._begin(nested=nested)
assert self._transaction is trans
if nested:
self._nested_transaction = trans
trans = SessionTransaction(self)
assert self._transaction is trans
- return self._transaction # needed for __enter__/__exit__ hook
+ if TYPE_CHECKING:
+ assert self._transaction is not None
+
+ return trans # needed for __enter__/__exit__ hook
- def begin_nested(self):
+ def begin_nested(self) -> SessionTransaction:
"""Begin a "nested" transaction on this Session, e.g. SAVEPOINT.
The target database(s) and associated drivers must support SQL
"""
return self.begin(nested=True)
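A minimal SAVEPOINT sketch, assuming a hypothetical ``User`` model and an in-memory SQLite engine (illustrative only)::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(id=1, name="kept"))
        session.flush()

        nested = session.begin_nested()          # SAVEPOINT
        session.add(User(id=2, name="discarded"))
        session.flush()
        nested.rollback()                        # ROLLBACK TO SAVEPOINT; row 2 is gone

        session.commit()                         # row 1 is committed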
- def rollback(self):
+ def rollback(self) -> None:
"""Rollback the current transaction in progress.
If no transaction is in progress, this method is a pass-through.
:ref:`unitofwork_transaction`
"""
- if self._transaction is None:
- if not self._autobegin():
- raise sa_exc.InvalidRequestError("No transaction is begun.")
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
- self._transaction.commit(_to_root=True)
+ trans.commit(_to_root=True)
def prepare(self) -> None:
"""Prepare the current transaction in progress for two phase commit.
:exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
"""
- if self._transaction is None:
- if not self._autobegin():
- raise sa_exc.InvalidRequestError("No transaction is begun.")
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
- self._transaction.prepare()
+ trans.prepare()
def connection(
self,
bind_arguments: Optional[Dict[str, Any]] = None,
- execution_options: Optional["_ExecuteOptions"] = None,
+ execution_options: Optional[_ExecuteOptions] = None,
) -> "Connection":
r"""Return a :class:`_engine.Connection` object corresponding to this
:class:`.Session` object's transactional state.
execution_options=execution_options,
)
- def _connection_for_bind(self, engine, execution_options=None, **kw):
+ def _connection_for_bind(
+ self,
+ engine: _SessionBind,
+ execution_options: Optional[_ExecuteOptions] = None,
+ **kw: Any,
+ ) -> Connection:
TransactionalContext._trans_ctx_check(self)
- if self._transaction is None:
- assert self._autobegin()
- return self._transaction._connection_for_bind(
- engine, execution_options
- )
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
+ return trans._connection_for_bind(engine, execution_options)
def execute(
self,
- statement: "Executable",
- params: Optional["_ExecuteParams"] = None,
- execution_options: "_ExecuteOptions" = util.EMPTY_DICT,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ):
+ ) -> Result:
r"""Execute a SQL expression construct.
Returns a :class:`_engine.Result` object representing
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
+ if TYPE_CHECKING:
+ assert isinstance(compile_state_cls, ORMCompileState)
else:
compile_state_cls = None
)
for idx, fn in enumerate(events_todo):
orm_exec_state._starting_event_idx = idx
- result = fn(orm_exec_state)
- if result:
- return result
+ fn_result: Optional[Result] = fn(orm_exec_state)
+ if fn_result:
+ return fn_result
statement = orm_exec_state.statement
execution_options = orm_exec_state.local_execution_options
bind = self.get_bind(**bind_arguments)
conn = self._connection_for_bind(bind)
- result = conn.execute(statement, params or {}, execution_options)
+ result: Result = conn.execute(
+ statement, params or {}, execution_options
+ )
if compile_state_cls:
result = compile_state_cls.orm_setup_cursor_result(
def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ **kw: Any,
+ ) -> Any:
"""Execute a statement and return a scalar result.
Usage and parameters are the same as that of
def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
"""Execute a statement and return the results as scalars.
Usage and parameters are the same as that of
**kw,
).scalars()
- def close(self):
+ def close(self) -> None:
"""Close out the transactional resources and ORM objects used by this
:class:`_orm.Session`.
"""
self._close_impl(invalidate=False)
- def invalidate(self):
+ def invalidate(self) -> None:
"""Close this Session, using connection invalidation.
This is a variant of :meth:`.Session.close` that will additionally
"""
self._close_impl(invalidate=True)
- def _close_impl(self, invalidate):
+ def _close_impl(self, invalidate: bool) -> None:
self.expunge_all()
if self._transaction is not None:
for transaction in self._transaction._iterate_self_and_parents():
transaction.close(invalidate)
- def expunge_all(self):
+ def expunge_all(self) -> None:
"""Remove all object instances from this ``Session``.
This is equivalent to calling ``expunge(obj)`` on all objects in this
statelib.InstanceState._detach_states(all_states, self)
- def _add_bind(self, key, bind):
+ def _add_bind(self, key: _SessionBindKey, bind: _SessionBind) -> None:
try:
insp = inspect(key)
except sa_exc.NoInspectionAvailable as err:
"Not an acceptable bind target: %s" % key
)
- def bind_mapper(self, mapper, bind):
+ def bind_mapper(
+ self, mapper: _EntityBindKey[_O], bind: _SessionBind
+ ) -> None:
"""Associate a :class:`_orm.Mapper` or arbitrary Python class with a
"bind", e.g. an :class:`_engine.Engine` or
:class:`_engine.Connection`.
"""
self._add_bind(mapper, bind)
- def bind_table(self, table, bind):
+ def bind_table(self, table: Table, bind: _SessionBind) -> None:
"""Associate a :class:`_schema.Table` with a "bind", e.g. an
:class:`_engine.Engine`
or :class:`_engine.Connection`.
def get_bind(
self,
- mapper=None,
- clause=None,
- bind=None,
- _sa_skip_events=None,
- _sa_skip_for_implicit_returning=False,
- ):
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ _sa_skip_events: Optional[bool] = None,
+ _sa_skip_for_implicit_returning: bool = False,
+ ) -> Union[Engine, Connection]:
"""Return a "bind" to which this :class:`.Session` is bound.
The "bind" is usually an instance of :class:`_engine.Engine`,
# look more closely at the mapper.
if mapper is not None:
try:
- mapper = inspect(mapper)
+ inspected_mapper = inspect(mapper)
except sa_exc.NoInspectionAvailable as err:
if isinstance(mapper, type):
raise exc.UnmappedClassError(mapper) from err
else:
raise
+ else:
+ inspected_mapper = None
# match up the mapper or clause in the __binds
if self.__binds:
# matching mappers and selectables to entries in the
# binds dictionary; supported use case.
- if mapper:
- for cls in mapper.class_.__mro__:
+ if inspected_mapper:
+ for cls in inspected_mapper.class_.__mro__:
if cls in self.__binds:
return self.__binds[cls]
if clause is None:
- clause = mapper.persist_selectable
+ clause = inspected_mapper.persist_selectable
if clause is not None:
plugin_subject = clause._propagate_attrs.get(
for obj in visitors.iterate(clause):
if obj in self.__binds:
+ if TYPE_CHECKING:
+ assert isinstance(obj, Table)
return self.__binds[obj]
# none of the __binds matched, but we have a fallback bind.
return self.bind
context = []
- if mapper is not None:
- context.append("mapper %s" % mapper)
+ if inspected_mapper is not None:
+ context.append(f"mapper {inspected_mapper}")
if clause is not None:
context.append("SQL expression")
raise sa_exc.UnboundExecutionError(
- "Could not locate a bind configured on %s or this Session."
- % (", ".join(context),),
+ f"Could not locate a bind configured on "
+ f'{", ".join(context)} or this Session.'
)
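A hedged sketch of per-entity bind resolution through the ``binds`` dictionary consulted above; the two engines and models are assumptions for illustration::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)


    class AuditLog(Base):
        __tablename__ = "audit_log"
        id = Column(Integer, primary_key=True)
        message = Column(String)


    primary = create_engine("sqlite://")
    audit = create_engine("sqlite://")

    # classes (and their subclasses) resolve through the binds mapping
    session = Session(binds={User: primary, AuditLog: audit})

    assert session.get_bind(User) is primary
    assert session.get_bind(AuditLog) is audit
    session.close()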
- def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query:
+ def query(
+ self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ ) -> Query[Any]:
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
def _identity_lookup(
self,
- mapper,
- primary_key_identity,
- identity_token=None,
- passive=attributes.PASSIVE_OFF,
- lazy_loaded_from=None,
- ):
+ mapper: Mapper[_O],
+ primary_key_identity: Union[Any, Tuple[Any, ...]],
+ identity_token: Any = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ lazy_loaded_from: Optional[InstanceState[Any]] = None,
+ ) -> Union[Optional[_O], LoaderCallableStatus]:
"""Locate an object in the identity map.
Given a primary key identity, constructs an identity key and then
)
return loading.get_from_identity(self, mapper, key, passive)
- @property
+ @util.non_memoized_property
@contextlib.contextmanager
- def no_autoflush(self):
+ def no_autoflush(self) -> Iterator[Session]:
"""Return a context manager that disables autoflush.
e.g.::
finally:
self.autoflush = autoflush
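A short usage sketch of the context manager returned here (the ``User`` model and SQLite URL are assumptions)::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        with session.no_autoflush:
            # queries inside this block will not trigger a premature flush
            # of the half-constructed pending object
            pending = User(id=1)
            session.add(pending)
            session.query(User).filter_by(name="someone else").first()
        session.commit()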
- def _autoflush(self):
+ def _autoflush(self) -> None:
if self.autoflush and not self._flushing:
try:
self.flush()
)
raise e.with_traceback(sys.exc_info()[2])
- def refresh(self, instance, attribute_names=None, with_for_update=None):
+ def refresh(
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
"""Expire and refresh attributes on the given instance.
The selected attributes will first be expired as they would when using
"A blank dictionary is ambiguous."
)
- with_for_update = query.ForUpdateArg._from_argument(with_for_update)
+ with_for_update = ForUpdateArg._from_argument(with_for_update)
stmt = sql.select(object_mapper(instance))
if (
"Could not refresh instance '%s'" % instance_str(instance)
)
- def expire_all(self):
+ def expire_all(self) -> None:
"""Expires all persistent instances within this Session.
When any attributes on a persistent instance is next accessed,
for state in self.identity_map.all_states():
state._expire(state.dict, self.identity_map._modified)
- def expire(self, instance, attribute_names=None):
+ def expire(
+ self, instance: object, attribute_names: Optional[Iterable[str]] = None
+ ) -> None:
"""Expire the attributes on an instance.
Marks the attributes of an instance as out of date. When an expired
raise exc.UnmappedInstanceError(instance) from err
self._expire_state(state, attribute_names)
- def _expire_state(self, state, attribute_names):
+ def _expire_state(
+ self,
+ state: InstanceState[Any],
+ attribute_names: Optional[Iterable[str]],
+ ) -> None:
self._validate_persistent(state)
if attribute_names:
state._expire_attributes(state.dict, attribute_names)
for o, m, st_, dct_ in cascaded:
self._conditional_expire(st_)
- def _conditional_expire(self, state, autoflush=None):
+ def _conditional_expire(
+ self, state: InstanceState[Any], autoflush: Optional[bool] = None
+ ) -> None:
"""Expire a state if persistent, else expunge if pending"""
if state.key:
self._new.pop(state)
state._detach(self)
- def expunge(self, instance):
+ def expunge(self, instance: object) -> None:
"""Remove the `instance` from this ``Session``.
This will free all internal references to the instance. Cascading
)
self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded])
- def _expunge_states(self, states, to_transient=False):
+ def _expunge_states(
+ self, states: Iterable[InstanceState[Any]], to_transient: bool = False
+ ) -> None:
for state in states:
if state in self._new:
self._new.pop(state)
states, self, to_transient=to_transient
)
- def _register_persistent(self, states):
+ def _register_persistent(self, states: Set[InstanceState[Any]]) -> None:
"""Register all persistent objects from a flush.
This is used both for pending objects moving to the persistent
# state has already replaced this one in the identity
# map (see test/orm/test_naturalpks.py ReversePKsTest)
self.identity_map.safe_discard(state)
- if state in self._transaction._key_switches:
- orig_key = self._transaction._key_switches[state][0]
+ trans = self._transaction
+ assert trans is not None
+ if state in trans._key_switches:
+ orig_key = trans._key_switches[state][0]
else:
orig_key = state.key
- self._transaction._key_switches[state] = (
+ trans._key_switches[state] = (
orig_key,
instance_key,
)
for state in set(states).intersection(self._new):
self._new.pop(state)
- def _register_altered(self, states):
+ def _register_altered(self, states: Iterable[InstanceState[Any]]) -> None:
if self._transaction:
for state in states:
if state in self._new:
else:
self._transaction._dirty[state] = True
- def _remove_newly_deleted(self, states):
+ def _remove_newly_deleted(
+ self, states: Iterable[InstanceState[Any]]
+ ) -> None:
persistent_to_deleted = self.dispatch.persistent_to_deleted or None
for state in states:
if self._transaction:
if persistent_to_deleted is not None:
persistent_to_deleted(self, state)
- def add(self, instance: Any, _warn: bool = True) -> None:
+ def add(self, instance: object, _warn: bool = True) -> None:
"""Place an object in the ``Session``.
Its state will be persisted to the database on the next flush
self._save_or_update_state(state)
- def add_all(self, instances):
+ def add_all(self, instances: Iterable[object]) -> None:
"""Add the given collection of instances to this ``Session``."""
if self._warn_on_events:
for instance in instances:
self.add(instance, _warn=False)
- def _save_or_update_state(self, state):
+ def _save_or_update_state(self, state: InstanceState[Any]) -> None:
state._orphaned_outside_of_session = False
self._save_or_update_impl(state)
):
self._save_or_update_impl(st_)
- def delete(self, instance):
+ def delete(self, instance: object) -> None:
"""Mark an instance as deleted.
The database delete operation occurs upon ``flush()``.
self._delete_impl(state, instance, head=True)
- def _delete_impl(self, state, obj, head):
+ def _delete_impl(
+ self, state: InstanceState[Any], obj: object, head: bool
+ ) -> None:
if state.key is None:
if head:
cascade_states = list(
state.manager.mapper.cascade_iterator("delete", state)
)
+ else:
+ cascade_states = None
self._deleted[state] = obj
if head:
+ if TYPE_CHECKING:
+ assert cascade_states is not None
for o, m, st_, dct_ in cascade_states:
self._delete_impl(st_, o, False)
def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- execution_options=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
entity,
ident,
loading.load_on_pk_identity,
- options,
+ options=options,
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
def _get_impl(
self,
- entity,
- primary_key_identity,
- db_load_fn,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- execution_options=None,
- ):
+ entity: _EntityBindKey[_O],
+ primary_key_identity: _PKIdentityArgument,
+ db_load_fn: Callable[..., _O],
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Optional[_O]:
# convert composite types to individual args
- if hasattr(primary_key_identity, "__composite_values__"):
+ if is_composite_class(primary_key_identity):
primary_key_identity = primary_key_identity.__composite_values__()
- mapper = inspect(entity)
+ mapper: Optional[Mapper[_O]] = inspect(entity)
- if not mapper or not mapper.is_mapper:
+ if mapper is None or not mapper.is_mapper:
raise sa_exc.ArgumentError(
"Expected mapped class or mapper, got: %r" % entity
)
is_dict = isinstance(primary_key_identity, dict)
if not is_dict:
primary_key_identity = util.to_list(
- primary_key_identity, default=(None,)
+ primary_key_identity, default=[None]
)
if len(primary_key_identity) != len(mapper.primary_key):
if instance is not None:
# reject calls for id in identity map but class
# mismatch.
- if not issubclass(instance.__class__, mapper.class_):
+ if not isinstance(instance, mapper.class_):
return None
return instance
- elif instance is attributes.PASSIVE_CLASS_MISMATCH:
- return None
+
+ # TODO: this was being tested before, but this is not possible
+ assert instance is not LoaderCallableStatus.PASSIVE_CLASS_MISMATCH
# set_label_style() not strictly necessary, however this will ensure
# that tablename_colname style is used which at the moment is
LABEL_STYLE_TABLENAME_PLUS_COL
)
if with_for_update is not None:
- statement._for_update_arg = query.ForUpdateArg._from_argument(
+ statement._for_update_arg = ForUpdateArg._from_argument(
with_for_update
)
load_options=load_options,
)
- def merge(self, instance, load=True, options=None):
+ def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
"""Copy the state of a given instance into a corresponding instance
within this :class:`.Session`.
if self._warn_on_events:
self._flush_warning("Session.merge()")
- _recursive = {}
- _resolve_conflict_map = {}
+ _recursive: Dict[InstanceState[Any], object] = {}
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {}
if load:
# flush current contents if we expect to load data
def _merge(
self,
- state,
- state_dict,
- load=True,
- options=None,
- _recursive=None,
- _resolve_conflict_map=None,
- ):
+ state: InstanceState[_O],
+ state_dict: _InstanceDict,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ load: bool,
+ _recursive: Dict[InstanceState[Any], object],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> _O:
mapper = _state_mapper(state)
if state in _recursive:
- return _recursive[state]
+ return cast(_O, _recursive[state])
new_instance = False
key = state.key
+ merged: Optional[_O]
+
if key is None:
if state in self._new:
util.warn(
"load=False."
)
key = mapper._identity_key_from_state(state)
- key_is_persistent = attributes.NEVER_SET not in key[1] and (
+ key_is_persistent = LoaderCallableStatus.NEVER_SET not in key[
+ 1
+ ] and (
not _none_set.intersection(key[1])
or (
mapper.allow_partial_pks
if merged is None:
if key_is_persistent and key in _resolve_conflict_map:
- merged = _resolve_conflict_map[key]
+ merged = cast(_O, _resolve_conflict_map[key])
elif not load:
if state.modified:
state,
state_dict,
mapper.version_id_col,
- passive=attributes.PASSIVE_NO_INITIALIZE,
+ passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
)
merged_version = mapper._get_state_attr_by_column(
merged_state,
merged_dict,
mapper.version_id_col,
- passive=attributes.PASSIVE_NO_INITIALIZE,
+ passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
)
if (
- existing_version is not attributes.PASSIVE_NO_RESULT
- and merged_version is not attributes.PASSIVE_NO_RESULT
+ existing_version
+ is not LoaderCallableStatus.PASSIVE_NO_RESULT
+ and merged_version
+ is not LoaderCallableStatus.PASSIVE_NO_RESULT
and existing_version != merged_version
):
raise exc.StaleDataError(
merged_state.manager.dispatch.load(merged_state, None)
return merged
- def _validate_persistent(self, state):
+ def _validate_persistent(self, state: InstanceState[Any]) -> None:
if not self.identity_map.contains_state(state):
raise sa_exc.InvalidRequestError(
"Instance '%s' is not persistent within this Session"
% state_str(state)
)
- def _save_impl(self, state):
+ def _save_impl(self, state: InstanceState[Any]) -> None:
if state.key is not None:
raise sa_exc.InvalidRequestError(
"Object '%s' already has an identity - "
if to_attach:
self._after_attach(state, obj)
- def _update_impl(self, state, revert_deletion=False):
+ def _update_impl(
+ self, state: InstanceState[Any], revert_deletion: bool = False
+ ) -> None:
if state.key is None:
raise sa_exc.InvalidRequestError(
"Instance '%s' is not persisted" % state_str(state)
elif revert_deletion:
self.dispatch.deleted_to_persistent(self, state)
- def _save_or_update_impl(self, state):
+ def _save_or_update_impl(self, state: InstanceState[Any]) -> None:
if state.key is None:
self._save_impl(state)
else:
self._update_impl(state)
- def enable_relationship_loading(self, obj):
+ def enable_relationship_loading(self, obj: object) -> None:
"""Associate an object with this :class:`.Session` for related
object loading.
if to_attach:
self._after_attach(state, obj)
- def _before_attach(self, state, obj):
- self._autobegin()
+ def _before_attach(self, state: InstanceState[Any], obj: object) -> bool:
+ self._autobegin_t()
if state.session_id == self.hash_key:
return False
return True
- def _after_attach(self, state, obj):
+ def _after_attach(self, state: InstanceState[Any], obj: object) -> None:
state.session_id = self.hash_key
if state.modified and state._strong_obj is None:
state._strong_obj = obj
else:
self.dispatch.transient_to_pending(self, state)
- def __contains__(self, instance):
+ def __contains__(self, instance: object) -> bool:
"""Return True if the instance is associated with this session.
The instance may be pending or persistent within the Session for a
raise exc.UnmappedInstanceError(instance) from err
return self._contains_state(state)
- def __iter__(self):
+ def __iter__(self) -> Iterator[object]:
"""Iterate over all pending or persistent instances within this
Session.
list(self._new.values()) + list(self.identity_map.values())
)
- def _contains_state(self, state):
+ def _contains_state(self, state: InstanceState[Any]) -> bool:
return state in self._new or self.identity_map.contains_state(state)
- def flush(self, objects=None):
+ def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
"""Flush all the object changes to the database.
Writes out all pending object creations, deletions and modifications
finally:
self._flushing = False
- def _flush_warning(self, method):
+ def _flush_warning(self, method: Any) -> None:
util.warn(
"Usage of the '%s' operation is not currently supported "
"within the execution stage of the flush process. "
"event listeners or connection-level operations instead." % method
)
- def _is_clean(self):
+ def _is_clean(self) -> bool:
return (
not self.identity_map.check_modified()
and not self._deleted
and not self._new
)
- def _flush(self, objects=None):
+ def _flush(self, objects: Optional[Sequence[object]] = None) -> None:
dirty = self._dirty_states
if not dirty and not self._deleted and not self._new:
def bulk_save_objects(
self,
- objects,
- return_defaults=False,
- update_changed_only=True,
- preserve_order=True,
- ):
+ objects: Iterable[object],
+ return_defaults: bool = False,
+ update_changed_only: bool = True,
+ preserve_order: bool = True,
+ ) -> None:
"""Perform a bulk save of the given list of objects.
The bulk save feature allows mapped objects to be used as the
"""
+ obj_states: Iterable[InstanceState[Any]]
+
obj_states = (attributes.instance_state(obj) for obj in objects)
if not preserve_order:
key=lambda state: (id(state.mapper), state.key is not None),
)
- def grouping_key(state):
+ def grouping_key(
+ state: InstanceState[_O],
+ ) -> Tuple[Mapper[_O], bool]:
return (state.mapper, state.key is not None)
for (mapper, isupdate), states in itertools.groupby(
)
def bulk_insert_mappings(
- self, mapper, mappings, return_defaults=False, render_nulls=False
- ):
+ self,
+ mapper: Mapper[Any],
+ mappings: Iterable[Dict[str, Any]],
+ return_defaults: bool = False,
+ render_nulls: bool = False,
+ ) -> None:
"""Perform a bulk insert of the given list of mapping dictionaries.
The bulk insert feature allows plain Python dictionaries to be used as
render_nulls,
)
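A minimal usage sketch for the ``bulk_insert_mappings()`` signature annotated above, not part of the diff itself; the ``User`` model, table name, and in-memory SQLite engine are illustrative assumptions.

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Session


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as sess:
    # each dictionary supplies column values directly; no mapped
    # instances are constructed for the INSERT
    sess.bulk_insert_mappings(
        User,
        [{"id": 1, "name": "u1"}, {"id": 2, "name": "u2"}],
    )
    sess.commit()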
- def bulk_update_mappings(self, mapper, mappings):
+ def bulk_update_mappings(
+ self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]]
+ ) -> None:
"""Perform a bulk update of the given list of mapping dictionaries.
The bulk update feature allows plain Python dictionaries to be used as
def _bulk_save_mappings(
self,
- mapper,
- mappings,
- isupdate,
- isstates,
- return_defaults,
- update_changed_only,
- render_nulls,
- ):
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ isupdate: bool,
+ isstates: bool,
+ return_defaults: bool,
+ update_changed_only: bool,
+ render_nulls: bool,
+ ) -> None:
mapper = _class_to_mapper(mapper)
self._flushing = True
finally:
self._flushing = False
- def is_modified(self, instance, include_collections=True):
+ def is_modified(
+ self, instance: object, include_collections: bool = True
+ ) -> bool:
r"""Return ``True`` if the given instance has locally
modified attributes.
continue
(added, unchanged, deleted) = attr.impl.get_history(
- state, dict_, passive=attributes.NO_CHANGE
+ state, dict_, passive=PassiveFlag.NO_CHANGE
)
if added or deleted:
return False
@property
- def is_active(self):
+ def is_active(self) -> bool:
"""True if this :class:`.Session` not in "partial rollback" state.
.. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins
"""
return self._transaction is None or self._transaction.is_active
- identity_map = None
- """A mapping of object identities to objects themselves.
-
- Iterating through ``Session.identity_map.values()`` provides
- access to the full set of persistent objects (i.e., those
- that have row identity) currently in the session.
-
- .. seealso::
-
- :func:`.identity_key` - helper function to produce the keys used
- in this dictionary.
-
- """
-
@property
- def _dirty_states(self):
+ def _dirty_states(self) -> Iterable[InstanceState[Any]]:
"""The set of all persistent states considered dirty.
This method returns all states that were modified including
return self.identity_map._dirty_states()
@property
- def dirty(self):
+ def dirty(self) -> IdentitySet:
"""The set of all persistent instances considered dirty.
E.g.::
attributes, use the :meth:`.Session.is_modified` method.
"""
- return util.IdentitySet(
+ return IdentitySet(
[
state.obj()
for state in self._dirty_states
)
@property
- def deleted(self):
+ def deleted(self) -> IdentitySet:
"The set of all instances marked as 'deleted' within this ``Session``"
return util.IdentitySet(list(self._deleted.values()))
@property
- def new(self):
+ def new(self) -> IdentitySet:
"The set of all instances marked as 'new' within this ``Session``."
return util.IdentitySet(list(self._new.values()))
"""
+ class_: Type[Session]
+
def __init__(
self,
- bind=None,
- class_=Session,
- autoflush=True,
- expire_on_commit=True,
- info=None,
- **kw,
+ bind: Optional[_SessionBind] = None,
+ class_: Type[Session] = Session,
+ autoflush: bool = True,
+ expire_on_commit: bool = True,
+ info: Optional[Dict[Any, Any]] = None,
+ **kw: Any,
):
r"""Construct a new :class:`.sessionmaker`.
# events can be associated with it specifically.
self.class_ = type(class_.__name__, (class_,), {})
- def begin(self):
+ def begin(self) -> contextlib.AbstractContextManager[Session]:
"""Produce a context manager that both provides a new
:class:`_orm.Session` as well as a transaction that commits.
session = self()
return session._maker_context_manager()
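A short sketch of the context-manager behavior described in the docstring above, assuming an in-memory SQLite engine; the ``text()`` statements stand in for real work.

from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")
Session = sessionmaker(engine)

with Session.begin() as session:
    # commits when the block exits normally, rolls back on error
    session.execute(text("CREATE TABLE t (x INTEGER)"))
    session.execute(text("INSERT INTO t (x) VALUES (1)"))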
- def __call__(self, **local_kw):
+ def __call__(self, **local_kw: Any) -> Session:
"""Produce a new :class:`.Session` object using the configuration
established in this :class:`.sessionmaker`.
local_kw.setdefault(k, v)
return self.class_(**local_kw)
- def configure(self, **new_kw):
+ def configure(self, **new_kw: Any) -> None:
"""(Re)configure the arguments for this sessionmaker.
e.g.::
"""
self.kw.update(new_kw)
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(class_=%r, %s)" % (
self.__class__.__name__,
self.class_.__name__,
)
-def close_all_sessions():
+def close_all_sessions() -> None:
"""Close all sessions in memory.
This function consults a global registry of all :class:`.Session` objects
sess.close()
-def make_transient(instance):
+def make_transient(instance: object) -> None:
"""Alter the state of the given instance so that it is :term:`transient`.
.. note::
del state._deleted
-def make_transient_to_detached(instance):
+def make_transient_to_detached(instance: object) -> None:
"""Make the given transient instance :term:`detached`.
.. note::
state._expire_attributes(state.dict, state.unloaded_expirable)
-def object_session(instance):
+def object_session(instance: object) -> Optional[Session]:
"""Return the :class:`.Session` to which the given instance belongs.
This is essentially the same as the :attr:`.InstanceState.session`
from __future__ import annotations
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
import weakref
from . import base
from . import exc as orm_exc
from . import interfaces
+from ._typing import is_collection_impl
from .base import ATTR_WAS_SET
from .base import INIT_OK
+from .base import LoaderCallableStatus
from .base import NEVER_SET
from .base import NO_VALUE
from .base import PASSIVE_NO_INITIALIZE
from .. import exc as sa_exc
from .. import inspection
from .. import util
+from ..util.typing import Protocol
+if TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
+ from ._typing import _LoaderCallable
+ from .attributes import AttributeImpl
+ from .attributes import History
+ from .base import LoaderCallableStatus
+ from .base import PassiveFlag
+ from .identity import IdentityMap
+ from .instrumentation import ClassManager
+ from .interfaces import ORMOption
+ from .mapper import Mapper
+ from .session import Session
+ from ..engine import Row
+ from ..ext.asyncio.session import async_session as _async_provider
+ from ..ext.asyncio.session import AsyncSession
-# late-populated by session.py
-_sessions = None
+_T = TypeVar("_T", bound=Any)
-# optionally late-provided by sqlalchemy.ext.asyncio.session
-_async_provider = None
+if TYPE_CHECKING:
+ _sessions: weakref.WeakValueDictionary[int, Session]
+else:
+ # late-populated by session.py
+ _sessions = None
+
+
+if not TYPE_CHECKING:
+ # optionally late-provided by sqlalchemy.ext.asyncio.session
+
+ _async_provider = None # noqa
+
+
+class _InstanceDictProto(Protocol):
+ def __call__(self) -> Optional[IdentityMap]:
+ ...
@inspection._self_inspects
-class InstanceState(interfaces.InspectionAttrInfo):
+class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
"""tracks state information at the instance level.
The :class:`.InstanceState` is a key object used by the
"""
- session_id = None
- key = None
- runid = None
- load_options = ()
- load_path = PathRegistry.root
- insert_order = None
- _strong_obj = None
- modified = False
- expired = False
- _deleted = False
- _load_pending = False
- _orphaned_outside_of_session = False
- is_instance = True
- identity_token = None
- _last_known_values = ()
-
- callables = ()
+ __slots__ = (
+ "__dict__",
+ "__weakref__",
+ "class_",
+ "manager",
+ "obj",
+ "committed_state",
+ "expired_attributes",
+ )
+
+ manager: ClassManager[_T]
+ session_id: Optional[int] = None
+ key: Optional[_IdentityKeyType[_T]] = None
+ runid: Optional[int] = None
+ load_options: Tuple[ORMOption, ...] = ()
+ load_path: PathRegistry = PathRegistry.root
+ insert_order: Optional[int] = None
+ _strong_obj: Optional[object] = None
+ obj: weakref.ref[_T]
+
+ committed_state: Dict[str, Any]
+
+ modified: bool = False
+ expired: bool = False
+ _deleted: bool = False
+ _load_pending: bool = False
+ _orphaned_outside_of_session: bool = False
+ is_instance: bool = True
+ identity_token: object = None
+ _last_known_values: Optional[Dict[str, Any]] = None
+
+ _instance_dict: _InstanceDictProto
+ """A weak reference, or in the default case a plain callable, that
+ returns a reference to the current :class:`.IdentityMap`, if any.
+
+ """
+ if not TYPE_CHECKING:
+
+ def _instance_dict(self):
+ """default 'weak reference' for _instance_dict"""
+ return None
+
+ expired_attributes: Set[str]
+ """The set of keys which are 'expired' to be loaded by
+ the manager's deferred scalar loader, assuming no pending
+ changes.
+
+ See also the ``unmodified`` collection, which is intersected
+ against this set when a refresh operation occurs."""
+
+ callables: Dict[str, Callable[[InstanceState[_T], PassiveFlag], Any]]
"""A namespace where a per-state loader callable can be associated.
In SQLAlchemy 1.0, this is only used for lazy loaders / deferred
"""
- def __init__(self, obj, manager):
+ if not TYPE_CHECKING:
+ callables = util.EMPTY_DICT
+
+ def __init__(self, obj: _T, manager: ClassManager[_T]):
self.class_ = obj.__class__
self.manager = manager
self.obj = weakref.ref(obj, self._cleanup)
self.committed_state = {}
self.expired_attributes = set()
- expired_attributes = None
- """The set of keys which are 'expired' to be loaded by
- the manager's deferred scalar loader, assuming no pending
- changes.
-
- see also the ``unmodified`` collection which is intersected
- against this set when a refresh operation occurs."""
-
@util.memoized_property
- def attrs(self):
+ def attrs(self) -> util.ReadOnlyProperties[AttributeState]:
"""Return a namespace representing each attribute on
the mapped object, including its current value
and history.
"""
return util.ReadOnlyProperties(
- dict((key, AttributeState(self, key)) for key in self.manager)
+ {key: AttributeState(self, key) for key in self.manager}
)
@property
- def transient(self):
+ def transient(self) -> bool:
"""Return ``True`` if the object is :term:`transient`.
.. seealso::
return self.key is None and not self._attached
@property
- def pending(self):
+ def pending(self) -> bool:
"""Return ``True`` if the object is :term:`pending`.
return self.key is None and self._attached
@property
- def deleted(self):
+ def deleted(self) -> bool:
"""Return ``True`` if the object is :term:`deleted`.
An object that is in the deleted state is guaranteed to
return self.key is not None and self._attached and self._deleted
@property
- def was_deleted(self):
+ def was_deleted(self) -> bool:
"""Return True if this object is or was previously in the
"deleted" state and has not been reverted to persistent.
return self._deleted
@property
- def persistent(self):
+ def persistent(self) -> bool:
"""Return ``True`` if the object is :term:`persistent`.
An object that is in the persistent state is guaranteed to
return self.key is not None and self._attached and not self._deleted
@property
- def detached(self):
+ def detached(self) -> bool:
"""Return ``True`` if the object is :term:`detached`.
.. seealso::
"""
return self.key is not None and not self._attached
- @property
+ @util.non_memoized_property
@util.preload_module("sqlalchemy.orm.session")
- def _attached(self):
+ def _attached(self) -> bool:
return (
self.session_id is not None
and self.session_id in util.preloaded.orm_session._sessions
)
- def _track_last_known_value(self, key):
+ def _track_last_known_value(self, key: str) -> None:
"""Track the last known value of a particular key after expiration
operations.
"""
- if key not in self._last_known_values:
- self._last_known_values = dict(self._last_known_values)
- self._last_known_values[key] = NO_VALUE
+ lkv = self._last_known_values
+ if lkv is None:
+ self._last_known_values = lkv = {}
+ if key not in lkv:
+ lkv[key] = NO_VALUE
@property
- def session(self):
+ def session(self) -> Optional[Session]:
"""Return the owning :class:`.Session` for this instance,
or ``None`` if none available.
return None
@property
- def async_session(self):
+ def async_session(self) -> Optional[AsyncSession]:
"""Return the owning :class:`_asyncio.AsyncSession` for this instance,
or ``None`` if none available.
return None
@property
- def object(self):
+ def object(self) -> Optional[_T]:
"""Return the mapped object represented by this
- :class:`.InstanceState`."""
+ :class:`.InstanceState`.
+
+ Returns ``None`` if the object has been garbage collected.
+
+ """
return self.obj()
@property
- def identity(self):
+ def identity(self) -> Optional[Tuple[Any, ...]]:
"""Return the mapped identity of the mapped object.
This is the primary key identity as persisted by the ORM
which can always be passed directly to
return self.key[1]
@property
- def identity_key(self):
+ def identity_key(self) -> Optional[_IdentityKeyType[_T]]:
"""Return the identity key for the mapped object.
This is the key used to locate the object within
"""
- # TODO: just change .key to .identity_key across
- # the board ? probably
return self.key
@util.memoized_property
- def parents(self):
+ def parents(self) -> Dict[int, InstanceState[Any]]:
return {}
@util.memoized_property
- def _pending_mutations(self):
+ def _pending_mutations(self) -> Dict[str, PendingCollection]:
return {}
@util.memoized_property
- def _empty_collections(self):
+ def _empty_collections(self) -> Dict[Any, Any]:
return {}
@util.memoized_property
- def mapper(self):
+ def mapper(self) -> Mapper[_T]:
"""Return the :class:`_orm.Mapper` used for this mapped object."""
return self.manager.mapper
@property
- def has_identity(self):
+ def has_identity(self) -> bool:
"""Return ``True`` if this object has an identity key.
This should always have the same value as the
return bool(self.key)
@classmethod
- def _detach_states(self, states, session, to_transient=False):
+ def _detach_states(
+ self,
+ states: Iterable[InstanceState[_T]],
+ session: Session,
+ to_transient: bool = False,
+ ) -> None:
persistent_to_detached = (
session.dispatch.persistent_to_detached or None
)
state._strong_obj = None
- def _detach(self, session=None):
+ def _detach(self, session: Optional[Session] = None) -> None:
if session:
InstanceState._detach_states([self], session)
else:
self.session_id = self._strong_obj = None
- def _dispose(self):
+ def _dispose(self) -> None:
+ # used by the test suite, apparently
self._detach()
- del self.obj
- def _cleanup(self, ref):
+ def _cleanup(self, ref: weakref.ref[_T]) -> None:
"""Weakref callback cleanup.
This callable cleans out the state when it is being garbage
# assert self not in instance_dict._modified
self.session_id = self._strong_obj = None
- del self.obj
-
- def obj(self):
- return None
@property
- def dict(self):
+ def dict(self) -> _InstanceDict:
"""Return the instance dict used by the object.
Under normal circumstances, this is always synonymous
else:
return {}
- def _initialize_instance(*mixed, **kwargs):
+ def _initialize_instance(*mixed: Any, **kwargs: Any) -> None:
self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa
manager = self.manager
manager.dispatch.init(self, args, kwargs)
try:
- return manager.original_init(*mixed[1:], **kwargs)
+ manager.original_init(*mixed[1:], **kwargs)
except:
with util.safe_reraise():
manager.dispatch.init_failure(self, args, kwargs)
- def get_history(self, key, passive):
+ def get_history(self, key: str, passive: PassiveFlag) -> History:
return self.manager[key].impl.get_history(self, self.dict, passive)
- def get_impl(self, key):
+ def get_impl(self, key: str) -> AttributeImpl:
return self.manager[key].impl
- def _get_pending_mutation(self, key):
+ def _get_pending_mutation(self, key: str) -> PendingCollection:
if key not in self._pending_mutations:
self._pending_mutations[key] = PendingCollection()
return self._pending_mutations[key]
- def __getstate__(self):
- state_dict = {"instance": self.obj()}
+ def __getstate__(self) -> Dict[str, Any]:
+ state_dict = {
+ "instance": self.obj(),
+ "class_": self.class_,
+ "committed_state": self.committed_state,
+ "expired_attributes": self.expired_attributes,
+ }
state_dict.update(
(k, self.__dict__[k])
for k in (
- "committed_state",
"_pending_mutations",
"modified",
"expired",
return state_dict
- def __setstate__(self, state_dict):
+ def __setstate__(self, state_dict: Dict[str, Any]) -> None:
inst = state_dict["instance"]
if inst is not None:
self.obj = weakref.ref(inst, self._cleanup)
self.class_ = inst.__class__
else:
- # None being possible here generally new as of 0.7.4
- # due to storage of state in "parents". "class_"
- # also new.
- self.obj = None
+ self.obj = lambda: None # type: ignore
self.class_ = state_dict["class_"]
self.committed_state = state_dict.get("committed_state", {})
- self._pending_mutations = state_dict.get("_pending_mutations", {})
- self.parents = state_dict.get("parents", {})
+ self._pending_mutations = state_dict.get("_pending_mutations", {}) # type: ignore # noqa: E501
+ self.parents = state_dict.get("parents", {}) # type: ignore
self.modified = state_dict.get("modified", False)
self.expired = state_dict.get("expired", False)
if "info" in state_dict:
if "callables" in state_dict:
self.callables = state_dict["callables"]
- try:
- self.expired_attributes = state_dict["expired_attributes"]
- except KeyError:
- self.expired_attributes = set()
- # 0.9 and earlier compat
- for k in list(self.callables):
- if self.callables[k] is self:
- self.expired_attributes.add(k)
- del self.callables[k]
+ self.expired_attributes = state_dict["expired_attributes"]
else:
if "expired_attributes" in state_dict:
self.expired_attributes = state_dict["expired_attributes"]
]
)
if self.key:
- try:
- self.identity_token = self.key[2]
- except IndexError:
- # 1.1 and earlier compat before identity_token
- assert len(self.key) == 2
- self.key = self.key + (None,)
- self.identity_token = None
+ self.identity_token = self.key[2]
if "load_path" in state_dict:
self.load_path = PathRegistry.deserialize(state_dict["load_path"])
state_dict["manager"](self, inst, state_dict)
- def _reset(self, dict_, key):
+ def _reset(self, dict_: _InstanceDict, key: str) -> None:
"""Remove the given attribute and any
callables associated with it."""
old = dict_.pop(key, None)
- if old is not None and self.manager[key].impl.collection:
- self.manager[key].impl._invalidate_collection(old)
+ manager_impl = self.manager[key].impl
+ if old is not None and is_collection_impl(manager_impl):
+ manager_impl._invalidate_collection(old)
self.expired_attributes.discard(key)
if self.callables:
self.callables.pop(key, None)
- def _copy_callables(self, from_):
+ def _copy_callables(self, from_: InstanceState[Any]) -> None:
if "callables" in from_.__dict__:
self.callables = dict(from_.callables)
@classmethod
- def _instance_level_callable_processor(cls, manager, fn, key):
+ def _instance_level_callable_processor(
+ cls, manager: ClassManager[_T], fn: _LoaderCallable, key: Any
+ ) -> Callable[[InstanceState[_T], _InstanceDict, Row], None]:
impl = manager[key].impl
- if impl.collection:
+ if is_collection_impl(impl):
+ fixed_impl = impl
- def _set_callable(state, dict_, row):
+ def _set_callable(
+ state: InstanceState[_T], dict_: _InstanceDict, row: Row
+ ) -> None:
if "callables" not in state.__dict__:
state.callables = {}
old = dict_.pop(key, None)
if old is not None:
- impl._invalidate_collection(old)
+ fixed_impl._invalidate_collection(old)
state.callables[key] = fn
else:
- def _set_callable(state, dict_, row):
+ def _set_callable(
+ state: InstanceState[_T], dict_: _InstanceDict, row: Row
+ ) -> None:
if "callables" not in state.__dict__:
state.callables = {}
state.callables[key] = fn
return _set_callable
- def _expire(self, dict_, modified_set):
+ def _expire(
+ self, dict_: _InstanceDict, modified_set: Set[InstanceState[Any]]
+ ) -> None:
self.expired = True
if self.modified:
modified_set.discard(self)
if self._last_known_values:
self._last_known_values.update(
- (k, dict_[k]) for k in self._last_known_values if k in dict_
+ {k: dict_[k] for k in self._last_known_values if k in dict_}
)
for key in self.manager._all_key_set.intersection(dict_):
self.manager.dispatch.expire(self, None)
- def _expire_attributes(self, dict_, attribute_names, no_loader=False):
+ def _expire_attributes(
+ self,
+ dict_: _InstanceDict,
+ attribute_names: Iterable[str],
+ no_loader: bool = False,
+ ) -> None:
pending = self.__dict__.get("_pending_mutations", None)
callables = self.callables
if callables and key in callables:
del callables[key]
old = dict_.pop(key, NO_VALUE)
- if impl.collection and old is not NO_VALUE:
+ if is_collection_impl(impl) and old is not NO_VALUE:
impl._invalidate_collection(old)
- if (
- self._last_known_values
- and key in self._last_known_values
- and old is not NO_VALUE
- ):
- self._last_known_values[key] = old
+ lkv = self._last_known_values
+ if lkv is not None and key in lkv and old is not NO_VALUE:
+ lkv[key] = old
self.committed_state.pop(key, None)
if pending:
self.manager.dispatch.expire(self, attribute_names)
- def _load_expired(self, state, passive):
+ def _load_expired(
+ self, state: InstanceState[_T], passive: PassiveFlag
+ ) -> LoaderCallableStatus:
"""__call__ allows the InstanceState to act as a deferred
callable for loading expired attributes, which is also
serializable (picklable).
return ATTR_WAS_SET
@property
- def unmodified(self):
+ def unmodified(self) -> Set[str]:
"""Return the set of keys which have no uncommitted changes"""
return set(self.manager).difference(self.committed_state)
- def unmodified_intersection(self, keys):
+ def unmodified_intersection(self, keys: Iterable[str]) -> Set[str]:
"""Return self.unmodified.intersection(keys)."""
return (
)
@property
- def unloaded(self):
+ def unloaded(self) -> Set[str]:
"""Return the set of keys which do not have a loaded value.
This includes expired attributes and any other attribute that
)
@property
- def unloaded_expirable(self):
+ def unloaded_expirable(self) -> Set[str]:
"""Return the set of keys which do not have a loaded value.
This includes expired attributes and any other attribute that
return self.unloaded
@property
- def _unloaded_non_object(self):
+ def _unloaded_non_object(self) -> Set[str]:
return self.unloaded.intersection(
attr
for attr in self.manager
if self.manager[attr].impl.accepts_scalar_loader
)
- def _instance_dict(self):
- return None
-
def _modified_event(
- self, dict_, attr, previous, collection=False, is_userland=False
- ):
+ self,
+ dict_: _InstanceDict,
+ attr: AttributeImpl,
+ previous: Any,
+ collection: bool = False,
+ is_userland: bool = False,
+ ) -> None:
if attr:
if not attr.send_modified_events:
return
)
if attr.key not in self.committed_state or is_userland:
if collection:
+ if TYPE_CHECKING:
+ assert is_collection_impl(attr)
if previous is NEVER_SET:
if attr.key in dict_:
previous = dict_[attr.key]
previous = attr.copy(previous)
self.committed_state[attr.key] = previous
- if attr.key in self._last_known_values:
- self._last_known_values[attr.key] = NO_VALUE
+ lkv = self._last_known_values
+ if lkv is not None and attr.key in lkv:
+ lkv[attr.key] = NO_VALUE
# assert self._strong_obj is None or self.modified
pass
else:
if session._transaction is None:
- session._autobegin()
+ session._autobegin_t()
if inst is None and attr:
raise orm_exc.ObjectDereferencedError(
% (self.manager[attr.key], base.state_class_str(self))
)
- def _commit(self, dict_, keys):
+ def _commit(self, dict_: _InstanceDict, keys: Iterable[str]) -> None:
"""Commit attributes.
This is used by a partial-attribute load operation to mark committed
):
del self.callables[key]
- def _commit_all(self, dict_, instance_dict=None):
+ def _commit_all(
+ self, dict_: _InstanceDict, instance_dict: Optional[IdentityMap] = None
+ ) -> None:
"""commit all attributes unconditionally.
This is used after a flush() or a full load/refresh
self._commit_all_states([(self, dict_)], instance_dict)
@classmethod
- def _commit_all_states(self, iter_, instance_dict=None):
+ def _commit_all_states(
+ self,
+ iter_: Iterable[Tuple[InstanceState[Any], _InstanceDict]],
+ instance_dict: Optional[IdentityMap] = None,
+ ) -> None:
"""Mass / highly inlined version of commit_all()."""
for state, dict_ in iter_:
"""
- def __init__(self, state, key):
+ __slots__ = ("state", "key")
+
+ state: InstanceState[Any]
+ key: str
+
+ def __init__(self, state: InstanceState[Any], key: str):
self.state = state
self.key = key
@property
- def loaded_value(self):
+ def loaded_value(self) -> Any:
"""The current value of this attribute as loaded from the database.
If the value has not been loaded, or is otherwise not present
return self.state.dict.get(self.key, NO_VALUE)
@property
- def value(self):
+ def value(self) -> Any:
"""Return the value of this attribute.
This operation is equivalent to accessing the object's
)
@property
- def history(self):
+ def history(self) -> History:
"""Return the current **pre-flush** change history for
this attribute, via the :class:`.History` interface.
"""
return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE)
- def load_history(self):
+ def load_history(self) -> History:
"""Return the current **pre-flush** change history for
this attribute, via the :class:`.History` interface.
"""
- def __init__(self):
+ __slots__ = ("deleted_items", "added_items")
+
+ deleted_items: util.IdentitySet
+ added_items: util.OrderedIdentitySet
+
+ def __init__(self) -> None:
self.deleted_items = util.IdentitySet()
self.added_items = util.OrderedIdentitySet()
- def append(self, value):
+ def append(self, value: Any) -> None:
if value in self.deleted_items:
self.deleted_items.remove(value)
else:
self.added_items.add(value)
- def remove(self, value):
+ def remove(self, value: Any) -> None:
if value in self.added_items:
self.added_items.remove(value)
else:
from typing import Callable
from typing import Optional
from typing import Tuple
+from typing import TypeVar
from typing import Union
from .. import exc as sa_exc
from .. import util
from ..util.typing import Literal
+_F = TypeVar("_F", bound=Callable[..., Any])
+
class _StateChangeState(Enum):
pass
Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
],
moves_to: _StateChangeState,
- ) -> Callable[..., Any]:
+ ) -> Callable[[_F], _F]:
"""Method decorator declaring valid states.
:param prerequisite_states: sequence of acceptable prerequisite
from __future__ import annotations
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Set
+from typing import TYPE_CHECKING
+
from . import attributes
from . import exc as orm_exc
from . import util as orm_util
from ..util import topological
+if TYPE_CHECKING:
+ from .dependency import DependencyProcessor
+ from .interfaces import MapperProperty
+ from .mapper import Mapper
+ from .session import Session
+ from .session import SessionTransaction
+ from .state import InstanceState
+
+
def track_cascade_events(descriptor, prop):
"""Establish event listeners on object attributes which handle
cascade-on-set/append.
class UOWTransaction:
- def __init__(self, session):
+ session: Session
+ transaction: SessionTransaction
+ attributes: Dict[str, Any]
+ deps: util.defaultdict[Mapper[Any], Set[DependencyProcessor]]
+ mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]]
+
+ def __init__(self, session: Session):
self.session = session
# dictionary used by external actors to
def register_object(
self,
- state,
- isdelete=False,
- listonly=False,
- cancel_delete=False,
- operation=None,
- prop=None,
- ):
+ state: InstanceState[Any],
+ isdelete: bool = False,
+ listonly: bool = False,
+ cancel_delete: bool = False,
+ operation: Optional[str] = None,
+ prop: Optional[MapperProperty] = None,
+ ) -> bool:
if not self.session._contains_state(state):
# this condition is normal when objects are registered
# as part of a relationship cascade operation. it should
[a for a in self.postsort_actions.values() if not a.disabled]
).difference(cycles)
- def execute(self):
+ def execute(self) -> None:
postsort_actions = self._generate_actions()
postsort_actions = sorted(
for rec in topological.sort(self.dependencies, postsort_actions):
rec.execute(self)
- def finalize_flush_changes(self):
+ def finalize_flush_changes(self) -> None:
"""Mark processed objects as clean / deleted after a successful
flush().
from typing import Any
from typing import Generic
from typing import Optional
-from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
from . import attributes # noqa
from .base import _class_to_mapper # noqa
-from .base import _IdentityKeyType
from .base import _never_set # noqa
from .base import _none_set # noqa
from .base import attribute_str # noqa
from ..util.typing import is_origin_of
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
+ from ._typing import _IdentityKeyType
+ from ._typing import _InternalEntityType
from .mapper import Mapper
from ..engine import Row
from ..sql._typing import _PropagateAttrsType
return sql.union_all(*result).alias(aliasname)
-@overload
def identity_key(
- class_: type, ident: Tuple[Any, ...], *, identity_token: Optional[str]
-) -> _IdentityKeyType:
- ...
-
-
-@overload
-def identity_key(*, instance: Any) -> _IdentityKeyType:
- ...
-
-
-@overload
-def identity_key(
- class_: type, *, row: "Row", identity_token: Optional[str]
-) -> _IdentityKeyType:
- ...
-
-
-def identity_key(
- class_=None, ident=None, *, instance=None, row=None, identity_token=None
+ class_: Optional[Type[Any]] = None,
+ ident: Union[Any, Tuple[Any, ...]] = None,
+ *,
+ instance: Optional[Any] = None,
+ row: Optional[Row] = None,
+ identity_token: Optional[Any] = None,
) -> _IdentityKeyType:
r"""Generate "identity key" tuples, as are used as keys in the
:attr:`.Session.identity_map` dictionary.
sql_base.HasCacheKey,
InspectionAttr,
MemoizedSlots,
+ Generic[_T],
):
"""Provide an inspection interface for an
:class:`.AliasedClass` object.
def __init__(
self,
- entity,
- inspected,
+ entity: _EntityType,
+ inspected: _InternalEntityType,
selectable,
name,
with_polymorphic_mappers,
return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
+def _cleanup_mapped_str_annotation(annotation):
+ # fix up an annotation that arrives in the form
+ # 'Mapped[List[Address]]' so that it instead looks like
+ # 'Mapped[List["Address"]]', which allows us to extract
+ # "Address" as a string
+ mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
+ if mm and mm.group(1) == "Mapped":
+ stack = []
+ inner = mm
+ while True:
+ stack.append(inner.group(1))
+ g2 = inner.group(2)
+ inner = re.match(r"^(.+?)\[(.+)\]$", g2)
+ if inner is None:
+ stack.append(g2)
+ break
+
+ # stack: ['Mapped', 'List', 'Address']
+ if not re.match(r"""^["'].*["']$""", stack[-1]):
+ stack[-1] = f'"{stack[-1]}"'
+ # stack: ['Mapped', 'List', '"Address"']
+
+ annotation = "[".join(stack) + ("]" * (len(stack) - 1))
+ return annotation
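A quick demonstration of the rewrite performed above; it assumes the helper remains a private function in ``orm/util.py`` as the new test later in this patch states, so the import path is an assumption.

from sqlalchemy.orm.util import _cleanup_mapped_str_annotation

# only the innermost un-quoted name of a Mapped[...] annotation is
# wrapped in quotes, so the string can be evaluated before the
# referenced class exists
print(_cleanup_mapped_str_annotation("Mapped[List[Address]]"))
# -> Mapped[List["Address"]]
print(_cleanup_mapped_str_annotation("Mapped[int]"))
# -> Mapped["int"]
print(_cleanup_mapped_str_annotation("SomethingElse[int]"))
# -> SomethingElse[int]  (left alone; only Mapped[...] is rewritten)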
+
+
def _extract_mapped_subtype(
raw_annotation: Union[type, str],
cls: type,
)
return None
- annotated = de_stringify_annotation(cls, raw_annotation)
+ annotated = de_stringify_annotation(
+ cls, raw_annotation, _cleanup_mapped_str_annotation
+ )
if is_dataclass_field:
return annotated
from typing import TypeVar
from typing import Union
+from sqlalchemy.sql.base import Executable
from . import roles
from .. import util
from ..inspection import Inspectable
def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
...
- def is_select_base(t: ReturnsRows) -> TypeGuard[SelectBase]:
+ def is_select_base(
+ t: Union[Executable, ReturnsRows]
+ ) -> TypeGuard[SelectBase]:
...
- def is_select_statement(t: ReturnsRows) -> TypeGuard[Select]:
+ def is_select_statement(
+ t: Union[Executable, ReturnsRows]
+ ) -> TypeGuard[Select]:
...
def is_table(t: FromClause) -> TypeGuard[TableClause]:
return None
@classmethod
- def _get_plugin_class_for_plugin(cls, statement, plugin_name):
+ def _get_plugin_class_for_plugin(
+ cls, statement: Executable, plugin_name: str
+ ) -> Optional[Type[CompileState]]:
try:
return cls.plugins[
(plugin_name, statement._effective_plugin_target)
)
@classmethod
- def isinstance(cls, klass):
+ def isinstance(cls, klass: Type[Any]) -> bool:
return issubclass(cls, klass)
@hybridmethod
_is_has_cache_key = False
+ _is_core = True
+
def _clone(self, **kw):
"""Create a shallow copy of this ExecutableOption."""
c = self.__class__.__new__(self.__class__)
schema.SchemaEventTarget,
HasCacheKey,
Options,
- util.langhelpers._symbol,
+ util.langhelpers.symbol,
),
)
and not hasattr(element, "__clause_element__")
skip_locked: bool
@classmethod
- def _from_argument(cls, with_for_update):
+ def _from_argument(
+ cls, with_for_update: Union[ForUpdateArg, None, bool, Dict[str, Any]]
+ ) -> Optional[ForUpdateArg]:
if isinstance(with_for_update, ForUpdateArg):
return with_for_update
elif with_for_update in (None, False):
elif with_for_update is True:
return ForUpdateArg()
else:
- return ForUpdateArg(**with_for_update)
+ return ForUpdateArg(**cast("Dict[str, Any]", with_for_update))
def __eq__(self, other):
return (
from .config import db
from .config import fixture
from .config import requirements as requires
+from .config import skip_test
from .exclusions import _is_excluded
from .exclusions import _server_version
from .exclusions import against as _against
from .langhelpers import duck_type_collection as duck_type_collection
from .langhelpers import ellipses_string as ellipses_string
from .langhelpers import EnsureKWArg as EnsureKWArg
+from .langhelpers import FastIntFlag as FastIntFlag
from .langhelpers import format_argspec_init as format_argspec_init
from .langhelpers import format_argspec_plus as format_argspec_plus
from .langhelpers import generic_fn_descriptor as generic_fn_descriptor
from typing import Iterable
from typing import Iterator
from typing import List
+from typing import Mapping
from typing import Optional
from typing import overload
from typing import Set
return result
-def coerce_to_immutabledict(d):
+def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
if not d:
return EMPTY_DICT
elif isinstance(d, immutabledict):
_DT = TypeVar("_DT", bound=Any)
+_F = TypeVar("_F", bound=Any)
+
class Properties(Generic[_T]):
"""Provide a __getattr__/__setattr__ interface over a dict."""
_data: Dict[str, _T]
- def __init__(self, data):
+ def __init__(self, data: Dict[str, _T]):
object.__setattr__(self, "_data", data)
def __len__(self) -> int:
def __iter__(self) -> Iterator[_T]:
return iter(list(self._data.values()))
- def __dir__(self):
+ def __dir__(self) -> List[str]:
return dir(super(Properties, self)) + [
str(k) for k in self._data.keys()
]
- def __add__(self, other):
- return list(self) + list(other)
+ def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
+ return list(self) + list(other) # type: ignore
- def __setitem__(self, key, obj):
+ def __setitem__(self, key: str, obj: _T) -> None:
self._data[key] = obj
def __getitem__(self, key: str) -> _T:
return self._data[key]
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> None:
del self._data[key]
- def __setattr__(self, key, obj):
+ def __setattr__(self, key: str, obj: _T) -> None:
self._data[key] = obj
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {"_data": self._data}
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_data", state["_data"])
def __getattr__(self, key: str) -> _T:
def __contains__(self, key: str) -> bool:
return key in self._data
- def as_readonly(self) -> "ReadOnlyProperties[_T]":
+ def as_readonly(self) -> ReadOnlyProperties[_T]:
"""Return an immutable proxy for this :class:`.Properties`."""
return ReadOnlyProperties(self._data)
- def update(self, value):
+ def update(self, value: Dict[str, _T]) -> None:
self._data.update(value)
@overload
def has_key(self, key: str) -> bool:
return key in self._data
- def clear(self):
+ def clear(self) -> None:
self._data.clear()
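A small sketch of the dict-proxy behavior typed above; it assumes ``Properties`` is re-exported at ``sqlalchemy.util`` alongside ``ReadOnlyProperties`` (which is used elsewhere in this patch), otherwise the private ``_collections`` module is its home.

from sqlalchemy.util import Properties

props = Properties({"x": 1})
assert props.x == 1             # attribute access proxies into the dict
assert props["x"] == 1          # item access reaches the same data
props.y = 2                     # __setattr__ writes through to the dict
assert "y" in props

readonly = props.as_readonly()  # immutable proxy over the same dict
assert readonly.y == 2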
class OrderedIdentitySet(IdentitySet):
- def __init__(self, iterable=None):
+ def __init__(self, iterable: Optional[Iterable[Any]] = None):
IdentitySet.__init__(self)
self._members = OrderedDict()
if iterable:
scopefunc: _ScopeFuncType
registry: Any
- def __init__(self, createfunc, scopefunc):
+ def __init__(
+ self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
+ ):
"""Construct a new :class:`.ScopedRegistry`.
:param createfunc: A creation function that will generate
"""
- def __init__(self, iterable=None):
+ _members: Dict[int, Any]
+
+ def __init__(self, iterable: Optional[Iterable[Any]] = None):
self._members = dict()
if iterable:
self.update(iterable)
- def add(self, value):
+ def add(self, value: Any) -> None:
self._members[id(value)] = value
- def __contains__(self, value):
+ def __contains__(self, value: Any) -> bool:
return id(value) in self._members
- def remove(self, value):
+ def remove(self, value: Any) -> None:
del self._members[id(value)]
- def discard(self, value):
+ def discard(self, value: Any) -> None:
try:
self.remove(value)
except KeyError:
pass
- def pop(self):
+ def pop(self) -> Any:
try:
pair = self._members.popitem()
return pair[1]
except KeyError:
raise KeyError("pop from an empty set")
- def clear(self):
+ def clear(self) -> None:
self._members.clear()
- def __cmp__(self, other):
+ def __cmp__(self, other: Any) -> NoReturn:
raise TypeError("cannot compare sets using cmp()")
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, IdentitySet):
return self._members == other._members
else:
return False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
if isinstance(other, IdentitySet):
return self._members != other._members
else:
return True
- def issubset(self, iterable):
+ def issubset(self, iterable: Iterable[Any]) -> bool:
if isinstance(iterable, self.__class__):
other = iterable
else:
return False
return True
- def __le__(self, other):
+ def __le__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issubset(other)
- def __lt__(self, other):
+ def __lt__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) < len(other) and self.issubset(other)
- def issuperset(self, iterable):
+ def issuperset(self, iterable: Iterable[Any]) -> bool:
if isinstance(iterable, self.__class__):
other = iterable
else:
return False
return True
- def __ge__(self, other):
+ def __ge__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issuperset(other)
- def __gt__(self, other):
+ def __gt__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) > len(other) and self.issuperset(other)
- def union(self, iterable):
+ def union(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__class__()
members = self._members
result._members.update(members)
result._members.update((id(obj), obj) for obj in iterable)
return result
- def __or__(self, other):
+ def __or__(self, other: Any) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.union(other)
- def update(self, iterable):
+ def update(self, iterable: Iterable[Any]) -> None:
self._members.update((id(obj), obj) for obj in iterable)
- def __ior__(self, other):
+ def __ior__(self, other: Any) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.update(other)
return self
- def difference(self, iterable):
+ def difference(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
other: Collection[Any]
}
return result
- def __sub__(self, other):
+ def __sub__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.difference(other)
- def difference_update(self, iterable):
+ def difference_update(self, iterable: Iterable[Any]) -> None:
self._members = self.difference(iterable)._members
- def __isub__(self, other):
+ def __isub__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.difference_update(other)
return self
- def intersection(self, iterable):
+ def intersection(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
other: Collection[Any]
}
return result
- def __and__(self, other):
+ def __and__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.intersection(other)
- def intersection_update(self, iterable):
+ def intersection_update(self, iterable: Iterable[Any]) -> None:
self._members = self.intersection(iterable)._members
- def __iand__(self, other):
+ def __iand__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.intersection_update(other)
return self
- def symmetric_difference(self, iterable):
+ def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
if isinstance(iterable, self.__class__):
other = iterable._members
)
return result
- def __xor__(self, other):
+ def __xor__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.symmetric_difference(other)
- def symmetric_difference_update(self, iterable):
+ def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
self._members = self.symmetric_difference(iterable)._members
- def __ixor__(self, other):
+ def __ixor__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.symmetric_difference_update(other)
return self
- def copy(self):
+ def copy(self) -> IdentitySet:
result = self.__new__(self.__class__)
result._members = self._members.copy()
return result
__copy__ = copy
- def __len__(self):
+ def __len__(self) -> int:
return len(self._members)
- def __iter__(self):
+ def __iter__(self) -> Iterator[Any]:
return iter(self._members.values())
- def __hash__(self):
+ def __hash__(self) -> NoReturn:
raise TypeError("set objects are unhashable")
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(%r)" % (type(self).__name__, list(self._members.values()))
from __future__ import annotations
import collections
+import enum
from functools import update_wrapper
import hashlib
import inspect
def create_proxy_methods(
- target_cls,
- target_cls_sphinx_name,
- proxy_cls_sphinx_name,
- classmethods=(),
- methods=(),
- attributes=(),
-):
+ target_cls: Type[Any],
+ target_cls_sphinx_name: str,
+ proxy_cls_sphinx_name: str,
+ classmethods: Sequence[str] = (),
+ methods: Sequence[str] = (),
+ attributes: Sequence[str] = (),
+) -> Callable[[_T], _T]:
"""A class decorator indicating attributes should refer to a proxy
class.
return self
-class _symbol(int):
+class symbol(int):
+ """A constant symbol.
+
+ >>> symbol('foo') is symbol('foo')
+ True
+ >>> symbol('foo')
+ symbol('foo')
+
+ A slight refinement of the MAGICCOOKIE=object() pattern. The primary
+ advantage of symbol() is its repr(). They are also singletons.
+
+ Repeated calls of symbol('name') will all return the same instance.
+
+ In SQLAlchemy 2.0, symbol() is used for the implementation of
+ ``_FastIntFlag``, but otherwise should be mostly replaced by
+ ``enum.Enum`` and variants.
+
+
+ """
+
name: str
+ symbols: Dict[str, symbol] = {}
+ _lock = threading.Lock()
+
def __new__(
cls,
name: str,
doc: Optional[str] = None,
canonical: Optional[int] = None,
- ) -> "_symbol":
- """Construct a new named symbol."""
- assert isinstance(name, str)
- if canonical is None:
- canonical = hash(name)
- v = int.__new__(_symbol, canonical)
- v.name = name
- if doc:
- v.__doc__ = doc
- return v
+ ) -> symbol:
+ with cls._lock:
+ sym = cls.symbols.get(name)
+ if sym is None:
+ assert isinstance(name, str)
+ if canonical is None:
+ canonical = hash(name)
+ sym = int.__new__(symbol, canonical)
+ sym.name = name
+ if doc:
+ sym.__doc__ = doc
+
+ cls.symbols[name] = sym
+ return sym
def __reduce__(self):
return symbol, (self.name, "x", int(self))
return repr(self)
def __repr__(self):
- return "symbol(%r)" % self.name
+ return f"symbol({self.name!r})"
-_symbol.__name__ = "symbol"
+class _IntFlagMeta(type):
+ def __init__(
+ cls,
+ classname: str,
+ bases: Tuple[Type[Any], ...],
+ dict_: Dict[str, Any],
+ **kw: Any,
+ ) -> None:
+ items: List[symbol]
+ cls._items = items = []
+ for k, v in dict_.items():
+ if isinstance(v, int):
+ sym = symbol(k, canonical=v)
+ elif not k.startswith("_"):
+ raise TypeError("Expected integer values for IntFlag")
+ else:
+ continue
+ setattr(cls, k, sym)
+ items.append(sym)
+ def __iter__(self) -> Iterator[symbol]:
+ return iter(self._items)
-class symbol:
- """A constant symbol.
- >>> symbol('foo') is symbol('foo')
- True
- >>> symbol('foo')
- <symbol 'foo>
+class _FastIntFlag(metaclass=_IntFlagMeta):
+ """An 'IntFlag' copycat that isn't slow when performing bitwise
+ operations.
- A slight refinement of the MAGICCOOKIE=object() pattern. The primary
- advantage of symbol() is its repr(). They are also singletons.
+ The public ``FastIntFlag`` name resolves to ``enum.IntFlag`` under
+ TYPE_CHECKING and to ``_FastIntFlag`` otherwise.
- Repeated calls of symbol('name') will all return the same instance.
+ """
- The optional ``doc`` argument assigns to ``__doc__``. This
- is strictly so that Sphinx autoattr picks up the docstring we want
- (it doesn't appear to pick up the in-module docstring if the datamember
- is in a different module - autoattribute also blows up completely).
- If Sphinx fixes/improves this then we would no longer need
- ``doc`` here.
- """
+if TYPE_CHECKING:
+ from enum import IntFlag
- symbols: Dict[str, "_symbol"] = {}
- _lock = threading.Lock()
+ FastIntFlag = IntFlag
+else:
+ FastIntFlag = _FastIntFlag
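A minimal sketch of how the ``FastIntFlag`` base is meant to be used at runtime, mirroring the test added later in this patch; the ``Opts`` names and values are illustrative only.

from sqlalchemy.util import FastIntFlag


class Opts(FastIntFlag):
    OPT_ONE = 1
    OPT_TWO = 2
    OPT_FOUR = 4


# members are int subclasses (symbol instances), so bitwise math stays
# plain integer arithmetic instead of enum.IntFlag's slower operator path
combined = Opts.OPT_ONE | Opts.OPT_TWO
assert combined & Opts.OPT_TWO
assert not combined & Opts.OPT_FOUR
assert Opts.OPT_ONE.name == "OPT_ONE"
assert list(Opts) == [Opts.OPT_ONE, Opts.OPT_TWO, Opts.OPT_FOUR]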
- def __new__( # type: ignore[misc]
- cls,
- name: str,
- doc: Optional[str] = None,
- canonical: Optional[int] = None,
- ) -> _symbol:
- with cls._lock:
- sym = cls.symbols.get(name)
- if sym is None:
- cls.symbols[name] = sym = _symbol(name, doc, canonical)
- return sym
- @classmethod
- def parse_user_argument(
- cls, arg, choices, name, resolve_symbol_names=False
- ):
- """Given a user parameter, parse the parameter into a chosen symbol.
-
- The user argument can be a string name that matches the name of a
- symbol, or the symbol object itself, or any number of alternate choices
- such as True/False/ None etc.
-
- :param arg: the user argument.
- :param choices: dictionary of symbol object to list of possible
- entries.
- :param name: name of the argument. Used in an :class:`.ArgumentError`
- that is raised if the parameter doesn't match any available argument.
- :param resolve_symbol_names: include the name of each symbol as a valid
- entry.
-
- """
- # note using hash lookup is tricky here because symbol's `__hash__`
- # is its int value which we don't want included in the lookup
- # explicitly, so we iterate and compare each.
- for sym, choice in choices.items():
- if arg is sym:
- return sym
- elif resolve_symbol_names and arg == sym.name:
- return sym
- elif arg in choice:
- return sym
-
- if arg is None:
- return None
-
- raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+_E = TypeVar("_E", bound=enum.Enum)
def parse_user_argument_for_enum(
arg: Any,
- choices: Dict[_T, List[Any]],
+ choices: Dict[_E, List[Any]],
name: str,
-) -> Optional[_T]:
+ resolve_symbol_names: bool = False,
+) -> Optional[_E]:
"""Given a user parameter, parse the parameter into a chosen value
from a list of choice objects, typically Enum values.
that is raised if the parameter doesn't match any available argument.
"""
- # TODO: use whatever built in thing Enum provides for this,
- # if applicable
for enum_value, choice in choices.items():
if arg is enum_value:
return enum_value
+ elif resolve_symbol_names and arg == enum_value.name:
+ return enum_value
elif arg in choice:
return enum_value
if arg is None:
return None
- raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+ raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
_creation_order = 1
if TYPE_CHECKING:
from sqlalchemy.engine import default as engine_default
+ from sqlalchemy.orm import session as orm_session
+ from sqlalchemy.orm import util as orm_util
from sqlalchemy.sql import dml as sql_dml
from sqlalchemy.sql import util as sql_util
import sys
import typing
from typing import Any
+from typing import Callable
from typing import cast
from typing import Dict
from typing import ForwardRef
from typing import Iterable
+from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar
def de_stringify_annotation(
- cls: Type[Any], annotation: Union[str, Type[Any]]
+ cls: Type[Any],
+ annotation: Union[str, Type[Any]],
+ str_cleanup_fn: Optional[Callable[[str], str]] = None,
) -> Union[str, Type[Any]]:
"""Resolve annotations that may be string based into real objects.
annotation = cast(ForwardRef, annotation).__forward_arg__
if isinstance(annotation, str):
+ if str_cleanup_fn:
+ annotation = str_cleanup_fn(annotation)
+
base_globals: "Dict[str, Any]" = getattr(
sys.modules.get(cls.__module__, None), "__dict__", {}
)
+
try:
annotation = eval(annotation, base_globals, None)
except NameError:
"sqlalchemy.engine.*",
"sqlalchemy.pool.*",
+ "sqlalchemy.orm.scoping",
+ "sqlalchemy.orm.session",
+ "sqlalchemy.orm.state",
+
# modules
"sqlalchemy.events",
"sqlalchemy.exc",
from sqlalchemy.testing.util import picklers
from sqlalchemy.util import classproperty
from sqlalchemy.util import compat
+from sqlalchemy.util import FastIntFlag
from sqlalchemy.util import get_callable_argspec
from sqlalchemy.util import langhelpers
from sqlalchemy.util import preloaded
assert sym1 is not sym3
assert sym1 != sym3
+ def test_fast_int_flag(self):
+ class Enum(FastIntFlag):
+ sym1 = 1
+ sym2 = 2
+
+ sym3 = 3
+
+ assert Enum.sym1 is not Enum.sym3
+ assert Enum.sym1 != Enum.sym3
+
+ assert Enum.sym1.name == "sym1"
+
+ eq_(list(Enum), [Enum.sym1, Enum.sym2, Enum.sym3])
+
def test_pickle(self):
sym1 = util.symbol("foo")
sym2 = util.symbol("foo")
assert (sym1 | sym2) & (sym2 | sym4)
def test_parser(self):
- sym1 = util.symbol("sym1", canonical=1)
- sym2 = util.symbol("sym2", canonical=2)
- sym3 = util.symbol("sym3", canonical=4)
- sym4 = util.symbol("sym4", canonical=8)
+ class MyEnum(FastIntFlag):
+ sym1 = 1
+ sym2 = 2
+ sym3 = 4
+ sym4 = 8
+ sym1, sym2, sym3, sym4 = tuple(MyEnum)
lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]}
lookup_two = {sym1: [], sym2: [True], sym3: [False]}
lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []}
is_(
- util.symbol.parse_user_argument(
+ langhelpers.parse_user_argument_for_enum(
"sym2", lookup_one, "some_name", resolve_symbol_names=True
),
sym2,
assert_raises_message(
exc.ArgumentError,
"Invalid value for 'some_name': 'sym2'",
- util.symbol.parse_user_argument,
+ langhelpers.parse_user_argument_for_enum,
"sym2",
lookup_one,
"some_name",
)
is_(
- util.symbol.parse_user_argument(
+ langhelpers.parse_user_argument_for_enum(
True, lookup_one, "some_name", resolve_symbol_names=False
),
sym2,
)
is_(
- util.symbol.parse_user_argument(sym2, lookup_one, "some_name"),
+ langhelpers.parse_user_argument_for_enum(
+ sym2, lookup_one, "some_name"
+ ),
sym2,
)
is_(
- util.symbol.parse_user_argument(None, lookup_one, "some_name"),
+ langhelpers.parse_user_argument_for_enum(
+ None, lookup_one, "some_name"
+ ),
sym4,
)
is_(
- util.symbol.parse_user_argument(None, lookup_two, "some_name"),
+ langhelpers.parse_user_argument_for_enum(
+ None, lookup_two, "some_name"
+ ),
None,
)
is_(
- util.symbol.parse_user_argument(
+ langhelpers.parse_user_argument_for_enum(
"symbol2", lookup_three, "some_name"
),
sym2,
assert_raises_message(
exc.ArgumentError,
"Invalid value for 'some_name': 'foo'",
- util.symbol.parse_user_argument,
+ langhelpers.parse_user_argument_for_enum,
"foo",
lookup_three,
"some_name",
--- /dev/null
+from __future__ import annotations
+
+from typing import List
+from typing import Sequence
+
+from sqlalchemy import create_engine
+from sqlalchemy import ForeignKey
+from sqlalchemy import select
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class User(Base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ addresses: Mapped[List[Address]] = relationship(back_populates="user")
+
+
+class Address(Base):
+ __tablename__ = "address"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ user_id = mapped_column(ForeignKey("user.id"))
+ email: Mapped[str]
+
+ user: Mapped[User] = relationship(back_populates="addresses")
+
+
+e = create_engine("sqlite://")
+Base.metadata.create_all(e)
+
+with Session(e) as sess:
+ u1 = User(name="u1")
+ sess.add(u1)
+ sess.add_all([Address(user=u1, email="e1"), Address(user=u1, email="e2")])
+ sess.commit()
+
+with Session(e) as sess:
+ users: Sequence[User] = sess.scalars(
+ select(User), execution_options={"stream_results": False}
+ ).all()
from __future__ import annotations
+from typing import List
+
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.testing import is_
from .test_typed_mapping import MappedColumnTest # noqa
-from .test_typed_mapping import RelationshipLHSTest # noqa
+from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
"""runs the annotation-sensitive tests from test_typed_mappings while
having ``from __future__ import annotations`` in effect.
"""
+
+
+class RelationshipLHSTest(_RelationshipLHSTest):
+ def test_bidirectional_literal_annotations(self, decl_base):
+ """test the 'string cleanup' function in orm/util.py, where
+ we receive a string annotation like::
+
+ "Mapped[List[B]]"
+
+ This fails to evaluate because we don't have "B" yet.
+ The annotation is converted on the fly to::
+
+ 'Mapped[List["B"]]'
+
+ so that when we evaluate it, we get ``Mapped[List["B"]]`` and
+ can extract "B" as a string.
+
+ """
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column()
+ bs: Mapped[List[B]] = relationship(back_populates="a")
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+ a: Mapped[A] = relationship(
+ back_populates="bs", primaryjoin=a_id == A.id
+ )
+
+ a1 = A(data="data")
+ b1 = B()
+ a1.bs.append(b1)
+ is_(a1, b1.a)
},
],
),
+ argnames="cols, expected",
)
def test_column_descriptions(self, cols, expected):
User, Address = self.classes("User", "Address")
)
stmt = select(*cols)
+
eq_(stmt.column_descriptions, expected)
+ if stmt._propagate_attrs:
+ stmt = select(*cols).from_statement(stmt)
+ eq_(stmt.column_descriptions, expected)
+
@testing.combinations(insert, update, delete, argnames="dml_construct")
@testing.combinations(
(
import sqlalchemy as sa
from sqlalchemy import delete
from sqlalchemy import event
+from sqlalchemy import exc as sa_exc
from sqlalchemy import ForeignKey
+from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal_column
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_not
+from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
),
)
+ def test_override_parameters_executesingle(self):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ return ctx.invoke_statement(params={"name": "overridden"})
+
+ orig_params = {"id": 18, "name": "original"}
+ with self.sql_execution_asserter() as asserter:
+ sess.execute(insert(User), orig_params)
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO users (id, name) VALUES (:id, :name)",
+ [{"id": 18, "name": "overridden"}],
+ )
+ )
+ # orig params weren't mutated
+ eq_(orig_params, {"id": 18, "name": "original"})
+
+ def test_override_parameters_executemany(self):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ return ctx.invoke_statement(
+ params=[{"name": "overridden1"}, {"name": "overridden2"}]
+ )
+
+ orig_params = [
+ {"id": 18, "name": "original1"},
+ {"id": 19, "name": "original2"},
+ ]
+ with self.sql_execution_asserter() as asserter:
+ sess.execute(insert(User), orig_params)
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO users (id, name) VALUES (:id, :name)",
+ [
+ {"id": 18, "name": "overridden1"},
+ {"id": 19, "name": "overridden2"},
+ ],
+ )
+ )
+ # orig params weren't mutated
+ eq_(
+ orig_params,
+ [{"id": 18, "name": "original1"}, {"id": 19, "name": "original2"}],
+ )
+
+ def test_override_parameters_executemany_mismatch(self):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ return ctx.invoke_statement(
+ params=[{"name": "overridden1"}, {"name": "overridden2"}]
+ )
+
+ orig_params = [
+ {"id": 18, "name": "original1"},
+ {"id": 19, "name": "original2"},
+ {"id": 20, "name": "original3"},
+ ]
+ with expect_raises_message(
+ sa_exc.InvalidRequestError,
+ r"Can't apply executemany parameters to statement; number "
+ r"of parameter sets passed to Session.execute\(\) \(3\) does "
+ r"not match number of parameter sets given to "
+ r"ORMExecuteState.invoke_statement\(\) \(2\)",
+ ):
+ sess.execute(insert(User), orig_params)
+
def test_chained_events_one(self):
sess = Session(testing.db, future=True)
from sqlalchemy.orm import attributes
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import exc as orm_exc
-from sqlalchemy.orm import instrumentation
from sqlalchemy.orm import lazyload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import state as sa_state
u2 = loads(dumps(u1))
eq_(u1, u2)
- def test_09_pickle(self):
- users = self.tables.users
- self.mapper_registry.map_imperatively(User, users)
- sess = fixture_session()
- sess.add(User(id=1, name="ed"))
- sess.commit()
- sess.close()
-
- inst = User(id=1, name="ed")
- del inst._sa_instance_state
-
- state = sa_state.InstanceState.__new__(sa_state.InstanceState)
- state_09 = {
- "class_": User,
- "modified": False,
- "committed_state": {},
- "instance": inst,
- "callables": {"name": state, "id": state},
- "key": (User, (1,)),
- "expired": True,
- }
- manager = instrumentation._SerializeManager.__new__(
- instrumentation._SerializeManager
- )
- manager.class_ = User
- state_09["manager"] = manager
- state.__setstate__(state_09)
- eq_(state.expired_attributes, {"name", "id"})
-
- sess = fixture_session()
- sess.add(inst)
- eq_(inst.name, "ed")
- # test identity_token expansion
- eq_(sa.inspect(inst).key, (User, (1,), None))
-
- def test_11_pickle(self):
- users = self.tables.users
- self.mapper_registry.map_imperatively(User, users)
- sess = fixture_session()
- u1 = User(id=1, name="ed")
- sess.add(u1)
- sess.commit()
-
- sess.close()
-
- manager = instrumentation._SerializeManager.__new__(
- instrumentation._SerializeManager
- )
- manager.class_ = User
-
- state_11 = {
- "class_": User,
- "modified": False,
- "committed_state": {},
- "instance": u1,
- "manager": manager,
- "key": (User, (1,)),
- "expired_attributes": set(),
- "expired": True,
- }
-
- state = sa_state.InstanceState.__new__(sa_state.InstanceState)
- state.__setstate__(state_11)
-
- eq_(state.identity_token, None)
- eq_(state.identity_key, (User, (1,), None))
-
def test_state_info_pickle(self):
users = self.tables.users
self.mapper_registry.map_imperatively(User, users)
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import util
from sqlalchemy.orm import query
from sqlalchemy.orm import relationship
from sqlalchemy.orm import scoped_session
populate_existing=False,
with_for_update=None,
identity_token=None,
- execution_options=None,
+ execution_options=util.EMPTY_DICT,
),
],
)
iscoroutine = inspect.iscoroutinefunction(fn)
- if spec.defaults:
- new_defaults = tuple(
- _repr_sym("util.EMPTY_DICT") if df is util.EMPTY_DICT else df
- for df in spec.defaults
- )
+ if spec.defaults or spec.kwonlydefaults:
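+ # a FullArgSpec is a named tuple in which index 3 is "defaults" and
+ # index 5 is "kwonlydefaults"; any default that is util.EMPTY_DICT is
+ # swapped for _repr_sym("util.EMPTY_DICT") so the formatted argspec
+ # renders the symbol name rather than the repr of the dict itself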
elem = list(spec)
- elem[3] = tuple(new_defaults)
+
+ if spec.defaults:
+ new_defaults = tuple(
+ _repr_sym("util.EMPTY_DICT")
+ if df is util.EMPTY_DICT
+ else df
+ for df in spec.defaults
+ )
+ elem[3] = new_defaults
+
+ if spec.kwonlydefaults:
+ new_kwonlydefaults = {
+ name: _repr_sym("util.EMPTY_DICT")
+ if df is util.EMPTY_DICT
+ else df
+ for name, df in spec.kwonlydefaults.items()
+ }
+ elem[5] = new_kwonlydefaults
+
spec = compat.FullArgSpec(*elem)
caller_argspec = format_argspec_plus(spec, grouped=False)
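
As an aside on the index arithmetic above: assuming ``compat.FullArgSpec``
mirrors the standard library's field order (as the surrounding code implies),
the positions written through ``elem[3]`` and ``elem[5]`` correspond to the
``defaults`` and ``kwonlydefaults`` fields of ``inspect.FullArgSpec``. A quick
standalone check::

    import inspect


    def fn(a, b=1, *, c=2):
        pass


    spec = inspect.getfullargspec(fn)
    print(spec._fields)
    # ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
    #  'kwonlydefaults', 'annotations')
    print(list(spec)[3])    # (1,)
    print(list(spec)[5])    # {'c': 2}
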